# add_cot.py

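"""Generate chain-of-thought (CoT) responses for a dataset using vLLM or a
HuggingFace pipeline, with multi-GPU support, and write the results to a
JSONL file alongside the original conversations."""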

import argparse
import json
import os
from typing import Any, Dict, List, Optional

import torch
import yaml
from datasets import load_from_disk
from tqdm import tqdm

def load_system_prompt(yaml_path: str) -> str:
    """Load system prompt from a YAML file."""
    with open(yaml_path, "r") as f:
        config = yaml.safe_load(f)
    return config["system_prompt"]
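
# The config file is expected to provide a top-level `system_prompt` key;
# an illustrative layout (the prompt text itself is a placeholder):
#
#   system_prompt: |
#     You are a helpful assistant. Think step by step before answering.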

def setup_llm(
    model_name: str,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.9,
    max_model_len: int = 128000,
    gpu_ids: Optional[List[int]] = None,
):
    """Initialize the vLLM LLM with specified parameters for multi-GPU support."""
    # Restrict visible devices before vLLM initializes CUDA.
    if gpu_ids is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))

    from vllm import LLM, SamplingParams

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
    )
    return llm, SamplingParams

def setup_hf_pipeline(
    model_name: str,
    gpu_ids: Optional[List[int]] = None,
):
    """Initialize the HuggingFace text-generation pipeline."""
    # Restrict visible devices before the pipeline places the model.
    if gpu_ids is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))

    import transformers

    pipeline = transformers.pipeline(
        "text-generation",
        model=model_name,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
    return pipeline

def create_messages(system_prompt: str, conversation: str) -> List[Dict[str, str]]:
    """Create the messages list for the model input."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": conversation},
    ]

def format_prompt(system_prompt: str, conversation: str) -> str:
    """Format the system prompt and conversation into the Llama 3 chat template."""
    # The trailing assistant header cues the model to generate the
    # assistant turn.
    return (
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n"
        f"{conversation}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

def process_with_vllm(
    item: Dict,
    llm: Any,
    system_prompt: str,
    sampling_params: Any,
) -> str:
    """Process a single item using vLLM."""
    prompt = format_prompt(system_prompt, item["conversations"])
    output = llm.generate(prompt, sampling_params)[0]
    return output.outputs[0].text

def process_with_hf(
    item: Dict,
    pipeline: Any,
    system_prompt: str,
    max_new_tokens: int,
) -> str:
    """Process a single item using the HuggingFace pipeline."""
    messages = create_messages(system_prompt, item["conversations"])
    outputs = pipeline(
        messages,
        max_new_tokens=max_new_tokens,
    )
    # The pipeline returns the full chat history; the last message is the
    # newly generated assistant turn.
    return outputs[0]["generated_text"][-1]["content"]

def process_dataset(
    dataset,
    system_prompt: str,
    output_file: str,
    start_index: int = 0,
    end_index: Optional[int] = None,
    max_new_tokens: int = 128000,
    use_hf: bool = False,
    model_instance: Any = None,
    sampling_params: Any = None,
) -> None:
    """Process the dataset using either vLLM or the HuggingFace pipeline."""
    # Handle end_index
    if end_index is None:
        end_index = len(dataset)
    else:
        end_index = min(end_index, len(dataset))

    # Validate indices
    if start_index < 0:
        start_index = 0
    if start_index >= len(dataset):
        raise ValueError(
            f"Start index {start_index} is larger than dataset size {len(dataset)}"
        )
    if start_index >= end_index:
        raise ValueError(
            f"Start index {start_index} must be less than end index {end_index}"
        )

    # Select the specified range
    dataset_slice = dataset.select(range(start_index, end_index))

    # Process examples one at a time
    with open(output_file, "w") as f:
        for item in tqdm(
            dataset_slice, desc=f"Processing rows {start_index} to {end_index}"
        ):
            # Generate the response using the appropriate backend
            if use_hf:
                cot_response = process_with_hf(
                    item, model_instance, system_prompt, max_new_tokens
                )
            else:
                cot_response = process_with_vllm(
                    item, model_instance, system_prompt, sampling_params
                )

            # Save the result with both original and CoT conversations
            result = {
                "id": item["id"],
                "conversations": item["conversations"],  # Keep original conversations
                "cot_conversations": cot_response,  # Add new CoT conversations
            }
            f.write(json.dumps(result) + "\n")
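
# Each line written to the output file is a JSON record shaped like this
# (values are illustrative):
#
#   {"id": "abc123",
#    "conversations": "...original conversation...",
#    "cot_conversations": "...generated CoT response..."}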

def main():
    parser = argparse.ArgumentParser(
        description="Process dataset using vLLM or HuggingFace pipeline with multi-GPU support"
    )
    parser.add_argument(
        "--model", type=str, required=True, help="Name or path of the model to use"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config file containing system prompt",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="processed_outputs.jsonl",
        help="Output file path",
    )
    parser.add_argument(
        "--dataset-path", type=str, required=True, help="Path to the dataset"
    )
    parser.add_argument(
        "--gpu-ids",
        type=str,
        help="Comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=1,
        help="Number of GPUs to use for tensor parallelism",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.9,
        help="Target GPU memory utilization (0.0 to 1.0)",
    )
    parser.add_argument(
        "--start-index",
        type=int,
        default=0,
        help="Starting index in the dataset (inclusive)",
    )
    parser.add_argument(
        "--end-index",
        type=int,
        help="Ending index in the dataset (exclusive). If not specified, processes until the end.",
    )
    parser.add_argument(
        "--use-hf",
        action="store_true",
        help="Use HuggingFace pipeline instead of vLLM",
    )
    args = parser.parse_args()

    # Parse GPU IDs if provided
    gpu_ids = None
    if args.gpu_ids:
        gpu_ids = [int(gpu_id) for gpu_id in args.gpu_ids.split(",")]

    # Load system prompt from YAML
    system_prompt = load_system_prompt(args.config)

    # Load dataset
    dataset = load_from_disk(args.dataset_path)

    # Initialize the appropriate model instance based on mode
    sampling_params = None
    if args.use_hf:
        model_instance = setup_hf_pipeline(
            model_name=args.model,
            gpu_ids=gpu_ids,
        )
    else:
        # setup_llm returns the SamplingParams class; instantiate it here.
        model_instance, sampling_params_cls = setup_llm(
            model_name=args.model,
            tensor_parallel_size=args.tensor_parallel_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            gpu_ids=gpu_ids,
        )
        sampling_params = sampling_params_cls(
            max_tokens=128000,
            temperature=0.7,
            top_p=0.95,
        )

    # Process dataset
    process_dataset(
        dataset=dataset,
        system_prompt=system_prompt,
        output_file=args.output_file,
        start_index=args.start_index,
        end_index=args.end_index,
        use_hf=args.use_hf,
        model_instance=model_instance,
        sampling_params=sampling_params,
    )

if __name__ == "__main__":
    main()
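
# Example invocation (model name, paths, and GPU IDs are illustrative):
#
#   python add_cot.py \
#       --model meta-llama/Meta-Llama-3-70B-Instruct \
#       --config config.yaml \
#       --dataset-path ./my_dataset \
#       --output-file processed_outputs.jsonl \
#       --gpu-ids 0,1 \
#       --tensor-parallel-size 2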