"""Add chain-of-thought (CoT) reasoning to function-calling conversations.

Each conversation in the input dataset is serialized to JSON, sent to a
vLLM-hosted model together with a system prompt, and the model's rewritten
conversation is stored alongside the original in a "cot_conversations" column.

Usage:
    VLLM_WORKER_MULTIPROC_METHOD=spawn python scripts/add_cot_vllm.py \
        --model_id meta-llama/Llama-3.3-70B-Instruct \
        --dataset-path datasets/2_ready_for_CoT/func-calling-multi-turn-final/ \
        --config configs/config.yaml \
        --output-path "datasets/3_CoT_added/func-calling-multi-turn-final/" \
        --batch-size 96 --max-seq-len 16000
"""
import argparse
import json
import os
import random
import re
from typing import Any, Dict, List

import torch
import yaml
from datasets import Dataset, load_from_disk
from tqdm import tqdm
from transformers import AutoProcessor
from vllm import LLM, SamplingParams


class LLM_Singleton:
    """Process-wide singleton wrapping a vLLM engine and its chat processor.

    The model is loaded once; subsequent constructor calls return the
    already-initialized instance (their arguments are ignored).
    """

    _instance = None

    def __new__(
        cls,
        model_id,
        max_model_len=64000,
        max_num_seqs=16,
        enforce_eager=True,
        debug=False,
    ):
        if cls._instance is None:
            cls._instance = super(LLM_Singleton, cls).__new__(cls)
            cls._instance._initialize(
                model_id,
                tensor_parallel_size=torch.cuda.device_count(),
                max_model_len=max_model_len,
                max_num_seqs=max_num_seqs,
                enforce_eager=enforce_eager,
                debug=debug,
            )
        return cls._instance

    def _initialize(
        self,
        model_id,
        tensor_parallel_size=1,
        max_model_len=64000,
        max_num_seqs=16,
        enforce_eager=True,
        debug=False,
    ):
        if debug:
            print(
                f"Initializing LLM with params: {model_id}, "
                f"{tensor_parallel_size}, {max_model_len}"
            )
        self.llm = LLM(
            model_id,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            max_num_seqs=max_num_seqs,
            enforce_eager=enforce_eager,
            gpu_memory_utilization=0.95,
        )
        # Tokenizer/processor used to render chat templates into prompt strings.
        self.processor = AutoProcessor.from_pretrained(model_id)
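

# A minimal usage sketch: repeated construction returns the same engine, so the
# model from the usage line at the top of this file is only loaded once per
# process.
#   llm_a = LLM_Singleton("meta-llama/Llama-3.3-70B-Instruct", max_model_len=16000)
#   llm_b = LLM_Singleton("meta-llama/Llama-3.3-70B-Instruct")
#   assert llm_a is llm_b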


def load_system_prompt(yaml_path):
    """Read the system prompt from the `system_prompt` key of a YAML config."""
    with open(yaml_path, "r") as f:
        config = yaml.safe_load(f)
    return config["system_prompt"]


def create_chat_message(system_prompt, conversation):
    """Wrap a serialized conversation in a system + user chat message pair."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": conversation},
    ]
    return messages
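

# The returned structure is the two-message list that apply_chat_template
# expects, e.g. (with illustrative prompt text):
#   create_chat_message("You are a CoT annotator.", '[{"from": "human", ...}]')
#   -> [{"role": "system", "content": "You are a CoT annotator."},
#       {"role": "user", "content": '[{"from": "human", ...}]'}]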


def parse_json_output(output_text):
    """Extract and parse the JSON array from a model completion.

    Returns the parsed list with `"tool": None` entries dropped, or None if
    the completion does not contain valid JSON.
    """
    output_text = output_text.strip()
    # The model may wrap the array in prose; grab the outermost [...] span.
    json_match = re.search(r"\[.*\]", output_text, re.DOTALL)
    if json_match:
        output_text = json_match.group(0)
    try:
        # Unwrap a double-encoded payload (a JSON array serialized as a string).
        if output_text.startswith('"') and output_text.endswith('"'):
            output_text = json.loads(output_text)
        result = json.loads(output_text)
        # Drop 'tool': None entries; keep every other key untouched.
        cleaned_result = []
        for item in result:
            cleaned_item = {
                k: v for k, v in item.items() if k != "tool" or v is not None
            }
            cleaned_result.append(cleaned_item)
        return cleaned_result
    except json.JSONDecodeError as e:
        print(f"Error parsing output: {e}")
        return None
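

# Example of the cleanup behavior (a sketch; the model's output schema is
# assumed to mirror the input "conversations" format):
#   parse_json_output('Sure: [{"from": "gpt", "value": "ok", "tool": null}]')
#   -> [{"from": "gpt", "value": "ok"}]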


def process_dataset(
    dataset,
    system_prompt: str,
    start_index: int = 0,
    end_index: int = None,
    n_samples: int = 0,
    model_instance: Any = None,
    batch_size: int = 16,
    max_seq_len: int = 64000,
) -> List[Dict]:
    """Generate CoT-enhanced conversations for a slice of the dataset."""
    if end_index is None:
        end_index = len(dataset)
    else:
        end_index = min(end_index, len(dataset))

    # Optionally draw a random sample from the requested range.
    dataset_size = end_index - start_index
    if n_samples > 0:
        n_samples = min(n_samples, dataset_size)
        indices = random.sample(range(start_index, end_index), n_samples)
        dataset_slice = dataset.select(indices)
    else:
        dataset_slice = dataset.select(range(start_index, end_index))

    results = []
    for i in tqdm(range(0, len(dataset_slice), batch_size), desc="Processing batches"):
        batch_slice = dataset_slice.select(
            range(i, min(i + batch_size, len(dataset_slice)))
        )
        try:
            # Render each conversation into a chat-templated prompt string.
            batch_inputs = []
            for item in batch_slice:
                conversation_str = json.dumps(
                    item["conversations"], ensure_ascii=False, indent=2
                )
                messages = create_chat_message(system_prompt, conversation_str)
                input_text = model_instance.processor.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=False
                )
                batch_inputs.append(input_text)
            # Low temperature for near-deterministic rewriting.
            sampling_params = SamplingParams(
                max_tokens=max_seq_len, temperature=0.1, top_p=0.95
            )
            outputs = model_instance.llm.generate(batch_inputs, sampling_params)
            for item, output in zip(batch_slice, outputs):
                enhanced_convos = parse_json_output(output.outputs[0].text.strip())
                if enhanced_convos is None:
                    # Fall back to the original conversation if parsing failed.
                    print(
                        f"Warning: Failed to parse output for item {item.get('id', 'unknown')}"
                    )
                    enhanced_convos = item["conversations"]
                results.append(
                    {
                        "id": item["id"],
                        "conversations": item["conversations"],
                        "cot_conversations": enhanced_convos,
                    }
                )
        except Exception as e:
            print(f"Error processing batch starting at {i}: {str(e)}")
            continue
    return results
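

# Note: process_dataset assumes each row carries "id" and "conversations"
# columns, and a failed batch is skipped entirely, so the output may contain
# fewer rows than the requested range.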


def main():
    parser = argparse.ArgumentParser(
        description="Process dataset to enhance conversations with CoT reasoning"
    )
    parser.add_argument(
        "--model_id", type=str, required=True, help="Model name or path"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config with system prompt",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="Output dataset directory path",
    )
    parser.add_argument(
        "--dataset-path", type=str, required=True, help="Input dataset path"
    )
    parser.add_argument(
        "--start-index", type=int, default=0, help="Starting index (inclusive)"
    )
    parser.add_argument("--end-index", type=int, help="Ending index (exclusive)")
    parser.add_argument(
        "--n-samples",
        type=int,
        default=0,
        help="Number of random samples to process. If 0, process all samples in range",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Batch size for processing"
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=64000,
        help="Maximum model context length; also used as the generation token budget",
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=16,
        help="Maximum number of sequences in batch",
    )
    parser.add_argument(
        "--enforce-eager",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enforce eager execution (pass --no-enforce-eager to allow CUDA graphs)",
    )
    args = parser.parse_args()

    # Use the spawn method so vLLM workers do not inherit CUDA state via fork.
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    system_prompt = load_system_prompt(args.config)
    dataset = load_from_disk(args.dataset_path)
    if isinstance(dataset, dict):
        # load_from_disk returned a DatasetDict; use its train split.
        dataset = dataset["train"]

    # Initialize the vLLM engine, wiring the CLI flags through.
    model_instance = LLM_Singleton(
        model_id=args.model_id,
        max_model_len=args.max_seq_len,
        max_num_seqs=args.max_num_seqs,
        enforce_eager=args.enforce_eager,
        debug=args.debug,
    )

    results = process_dataset(
        dataset=dataset,
        system_prompt=system_prompt,
        start_index=args.start_index,
        end_index=args.end_index,
        n_samples=args.n_samples,
        model_instance=model_instance,
        batch_size=args.batch_size,
        max_seq_len=args.max_seq_len,
    )

    # Rebuild a Dataset from the per-item results and persist it.
    output_dataset = Dataset.from_dict(
        {
            "id": [r["id"] for r in results],
            "conversations": [r["conversations"] for r in results],
            "cot_conversations": [r["cot_conversations"] for r in results],
        }
    )
    output_dataset.save_to_disk(args.output_path)


if __name__ == "__main__":
    main()
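

# To inspect the saved output afterwards (a sketch, using the output path from
# the usage line at the top of this file):
#   from datasets import load_from_disk
#   ds = load_from_disk("datasets/3_CoT_added/func-calling-multi-turn-final/")
#   print(ds[0]["cot_conversations"])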