add_cot_vllm.py 8.9 KB

# VLLM_WORKER_MULTIPROC_METHOD=spawn python scripts/add_cot_vllm.py --model_id meta-llama/Llama-3.3-70B-Instruct --dataset-path datasets/2_ready_for_CoT/func-calling-multi-turn-final/ --config configs/config.yaml --output-path "datasets/3_CoT_added/func-calling-multi-turn-final/" --batch-size 96 --max-seq-len 16000
import argparse
import json
import os
import random
import re
from typing import Any, Dict, List, Optional

import torch
import yaml
from datasets import Dataset, load_from_disk
from tqdm import tqdm
from transformers import AutoProcessor
from vllm import LLM, SamplingParams

class LLM_Singleton:
    """Process-wide singleton wrapping a vLLM engine and its chat-template processor."""

    _instance = None

    def __new__(
        cls,
        model_id,
        max_model_len=64000,
        max_num_seqs=16,
        enforce_eager=True,
        debug=False,
    ):
        if cls._instance is None:
            cls._instance = super(LLM_Singleton, cls).__new__(cls)
            cls._instance._initialize(
                model_id,
                tensor_parallel_size=torch.cuda.device_count(),
                max_model_len=max_model_len,
                max_num_seqs=max_num_seqs,
                enforce_eager=enforce_eager,
                debug=debug,
            )
        return cls._instance

    def _initialize(
        self,
        model_id,
        tensor_parallel_size=1,
        max_model_len=64000,
        max_num_seqs=16,
        enforce_eager=True,
        debug=False,
    ):
        if debug:
            print(
                f"Initializing LLM with params: {model_id}, {tensor_parallel_size}, {max_model_len}"
            )
        self.llm = LLM(
            model_id,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            max_num_seqs=max_num_seqs,
            enforce_eager=enforce_eager,
            gpu_memory_utilization=0.95,
        )
        self.processor = AutoProcessor.from_pretrained(model_id)

def load_system_prompt(yaml_path: str) -> str:
    """Load system prompt from YAML config."""
    with open(yaml_path, "r") as f:
        config = yaml.safe_load(f)
    return config["system_prompt"]
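
# The config file only needs to expose a "system_prompt" key (see
# load_system_prompt above). A minimal configs/config.yaml might look like the
# sketch below; the prompt text itself is illustrative, not the prompt actually
# used for this dataset:
#
#   system_prompt: |
#     You will be given a JSON list of conversation turns. Rewrite it so that
#     each assistant turn includes explicit chain-of-thought reasoning, and
#     return the result as a JSON list with the same structure.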

def create_chat_message(system_prompt: str, conversation: str) -> List[Dict[str, Any]]:
    """Create properly formatted chat messages."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": conversation},
    ]
    return messages

def parse_json_output(output_text: str) -> Optional[List[Dict[str, str]]]:
    """Parse and clean model output into valid JSON; return None if parsing fails."""
    output_text = output_text.strip()
    # Keep only the outermost JSON array in case the model adds surrounding prose
    json_match = re.search(r"\[.*\]", output_text, re.DOTALL)
    if json_match:
        output_text = json_match.group(0)
    try:
        # Handle double-encoded JSON (an array serialized inside a JSON string)
        if output_text.startswith('"') and output_text.endswith('"'):
            output_text = json.loads(output_text)
        result = json.loads(output_text)
        # Clean the result to remove 'tool': None entries
        cleaned_result = []
        for item in result:
            cleaned_item = {
                k: v for k, v in item.items() if k != "tool" or v is not None
            }
            cleaned_result.append(cleaned_item)
        return cleaned_result
    except json.JSONDecodeError as e:
        print(f"Error parsing output: {e}")
        return None
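
# For illustration (the completion and field names below are invented, not taken
# from a real generation): given model output such as
#
#   Here is the enhanced conversation:
#   [{"role": "user", "content": "...", "tool": null}]
#
# parse_json_output extracts the bracketed array, parses it, and drops the
# null "tool" entry, returning [{"role": "user", "content": "..."}].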

def process_dataset(
    dataset,
    system_prompt: str,
    start_index: int = 0,
    end_index: Optional[int] = None,
    n_samples: int = 0,
    model_instance: Any = None,
    batch_size: int = 16,
    max_seq_len: int = 64000,
) -> List[Dict]:
    """Process the dataset in batches and return results."""
    if end_index is None:
        end_index = len(dataset)
    else:
        end_index = min(end_index, len(dataset))
    # Handle random sampling
    dataset_size = end_index - start_index
    if n_samples > 0:
        # If n_samples is larger than the range, fall back to the full range size
        n_samples = min(n_samples, dataset_size)
        # Draw random indices within the specified range
        indices = random.sample(range(start_index, end_index), n_samples)
        dataset_slice = dataset.select(indices)
    else:
        # No sampling requested: use the full range
        dataset_slice = dataset.select(range(start_index, end_index))
    results = []
    for i in tqdm(range(0, len(dataset_slice), batch_size), desc="Processing batches"):
        batch_slice = dataset_slice.select(
            range(i, min(i + batch_size, len(dataset_slice)))
        )
        try:
            batch_inputs = []
            for item in batch_slice:
                conversation_str = json.dumps(
                    item["conversations"], ensure_ascii=False, indent=2
                )
                messages = create_chat_message(system_prompt, conversation_str)
                input_text = model_instance.processor.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=False
                )
                batch_inputs.append(input_text)
            # max_tokens is the per-sequence cap on generated tokens
            sampling_params = SamplingParams(
                max_tokens=max_seq_len, temperature=0.1, top_p=0.95
            )
            outputs = model_instance.llm.generate(batch_inputs, sampling_params)
            for item, output in zip(batch_slice, outputs):
                enhanced_convos = parse_json_output(output.outputs[0].text.strip())
                if enhanced_convos is None:
                    print(
                        f"Warning: Failed to parse output for item {item.get('id', 'unknown')}"
                    )
                    # Fall back to the original conversation when parsing fails
                    enhanced_convos = item["conversations"]
                results.append(
                    {
                        "id": item["id"],
                        "conversations": item["conversations"],
                        "cot_conversations": enhanced_convos,
                    }
                )
        except Exception as e:
            print(f"Error processing batch starting at {i}: {str(e)}")
            continue
    return results

def main():
    parser = argparse.ArgumentParser(
        description="Process dataset to enhance conversations with CoT reasoning"
    )
    parser.add_argument(
        "--model_id", type=str, required=True, help="Model name or path"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config with system prompt",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="Output dataset directory path",
    )
    parser.add_argument(
        "--dataset-path", type=str, required=True, help="Input dataset path"
    )
    parser.add_argument(
        "--start-index", type=int, default=0, help="Starting index (inclusive)"
    )
    parser.add_argument("--end-index", type=int, help="Ending index (exclusive)")
    parser.add_argument(
        "--n-samples",
        type=int,
        default=0,
        help="Number of random samples to process. If 0, process all samples in range",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Batch size for processing"
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=64000,
        help="Maximum model context length, also used as the per-sequence generation token cap",
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=16,
        help="Maximum number of sequences in a batch",
    )
    parser.add_argument(
        "--enforce-eager",
        action="store_true",
        help="Whether to enforce eager execution",
    )
    args = parser.parse_args()
    # Set spawn method for vLLM worker multiprocessing (also settable in the
    # shell, as in the launch command at the top of this file)
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    # Load system prompt and dataset
    system_prompt = load_system_prompt(args.config)
    dataset = load_from_disk(args.dataset_path)
    if isinstance(dataset, dict):
        dataset = dataset["train"]
    # Initialize the vLLM instance, wiring through the CLI options
    model_instance = LLM_Singleton(
        model_id=args.model_id,
        max_model_len=args.max_seq_len,
        max_num_seqs=args.max_num_seqs,
        enforce_eager=args.enforce_eager,
        debug=args.debug,
    )
    # Process dataset and get results
    results = process_dataset(
        dataset=dataset,
        system_prompt=system_prompt,
        start_index=args.start_index,
        end_index=args.end_index,
        n_samples=args.n_samples,
        model_instance=model_instance,
        batch_size=args.batch_size,
        max_seq_len=args.max_seq_len,
    )
    # Convert results to a HuggingFace dataset
    output_dataset = Dataset.from_dict(
        {
            "id": [r["id"] for r in results],
            "conversations": [r["conversations"] for r in results],
            "cot_conversations": [r["cot_conversations"] for r in results],
        }
    )
    # Save dataset
    output_dataset.save_to_disk(args.output_path)


if __name__ == "__main__":
    main()
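
# A quick way to sanity-check the result, assuming the output path from the
# launch command at the top of this file (the column names match what this
# script saves):
#
#   from datasets import load_from_disk
#   ds = load_from_disk("datasets/3_CoT_added/func-calling-multi-turn-final/")
#   print(ds[0]["id"])
#   print(ds[0]["cot_conversations"][:1])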