add_cot.py 6.3 KB

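# Module overview and a sketch of a typical invocation; the model name and paths
# in the example are placeholders, not values taken from this repository.
"""Add chain-of-thought (CoT) responses to an existing conversation dataset.

The script loads a Hugging Face dataset from disk, prompts a model (via vLLM or a
transformers text-generation pipeline) with a system prompt read from a YAML config,
and writes one JSONL record per row containing the original conversations plus the
generated CoT response.

Example (placeholder paths and model name):

    python add_cot.py \
        --model meta-llama/Llama-3.1-8B-Instruct \
        --config config.yaml \
        --dataset-path ./my_dataset \
        --output-file processed_outputs.jsonl \
        --gpu-ids 0,1 \
        --tensor-parallel-size 2
"""
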
import argparse
import json
import os
from typing import List, Optional

import torch
import transformers
import yaml
from datasets import load_from_disk
from tqdm import tqdm
from vllm import LLM, SamplingParams

def load_system_prompt(yaml_path):
    with open(yaml_path, "r") as f:
        config = yaml.safe_load(f)
    return config["system_prompt"]

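# The YAML config consumed by load_system_prompt() only needs a "system_prompt"
# key; a minimal sketch (the prompt text itself is illustrative):
#
#   system_prompt: |
#     You are a helpful assistant. Think through the problem step by step
#     before giving your final answer.
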
def setup_llm(
    model_name: str,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.9,
    max_model_len: int = 128000,
    gpu_ids: Optional[List[int]] = None,
):
    if gpu_ids is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
    )
    # Return the SamplingParams class (not an instance) so the caller can build
    # its own sampling settings.
    return llm, SamplingParams

def setup_hf_pipeline(model_name, gpu_ids):
    if gpu_ids is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_name,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
    return pipeline

def create_messages(system_prompt, conversation):
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": conversation},
    ]

def format_prompt(system_prompt, conversation):
    # Llama 3 chat-template special tokens for the system and user turns.
    return (
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{system_prompt}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>{conversation}"
    )

def process_with_vllm(item, llm, system_prompt, sampling_params):
    prompt = format_prompt(system_prompt, item["conversations"])
    output = llm.generate(prompt, sampling_params)[0]
    return output.outputs[0].text

def process_with_hf(item, pipeline, system_prompt, max_new_tokens):
    messages = create_messages(system_prompt, item["conversations"])
    outputs = pipeline(
        messages,
        max_new_tokens=max_new_tokens,
    )
    return outputs[0]["generated_text"][-1]["content"]

def process_dataset(
    dataset,
    system_prompt,
    output_file,
    start_index,
    end_index,
    use_hf,
    model_instance,
    sampling_params,
    max_new_tokens=8192,  # default assumed; main() does not pass this argument
):
    # Handle end_index
    if end_index is None:
        end_index = len(dataset)
    else:
        end_index = min(end_index, len(dataset))
    # Validate indices
    if start_index < 0:
        start_index = 0
    if start_index >= len(dataset):
        raise ValueError(
            f"Start index {start_index} is larger than dataset size {len(dataset)}"
        )
    if start_index >= end_index:
        raise ValueError(
            f"Start index {start_index} must be less than end index {end_index}"
        )
    dataset_slice = dataset.select(range(start_index, end_index))
    with open(output_file, "w") as f:
        for item in tqdm(
            dataset_slice, desc=f"Processing rows {start_index} to {end_index}"
        ):
            # Generate the CoT response with the selected backend
            if use_hf:
                cot_response = process_with_hf(
                    item, model_instance, system_prompt, max_new_tokens
                )
            else:
                cot_response = process_with_vllm(
                    item, model_instance, system_prompt, sampling_params
                )
            result = {
                "id": item["id"],
                "conversations": item["conversations"],  # Keep original conversations
                "cot_conversations": cot_response,  # Add the generated CoT response
            }
            f.write(json.dumps(result) + "\n")

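# Illustrative shape of one JSONL line written by process_dataset(); the values
# are made up, the keys mirror the `result` dict above:
#
#   {"id": "0001",
#    "conversations": [...original conversation turns...],
#    "cot_conversations": "<model-generated chain-of-thought text>"}
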
def main():
    parser = argparse.ArgumentParser(
        description="Process dataset using vLLM or HuggingFace pipeline with multi-GPU support"
    )
    parser.add_argument(
        "--model", type=str, required=True, help="Name or path of the model to use"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config file containing system prompt",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="processed_outputs.jsonl",
        help="Output file path",
    )
    parser.add_argument(
        "--dataset-path", type=str, required=True, help="Path to the dataset"
    )
    parser.add_argument(
        "--gpu-ids",
        type=str,
        help="Comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=1,
        help="Number of GPUs to use for tensor parallelism",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.9,
        help="Target GPU memory utilization (0.0 to 1.0)",
    )
    parser.add_argument(
        "--start-index",
        type=int,
        default=0,
        help="Starting index in the dataset (inclusive)",
    )
    parser.add_argument(
        "--end-index",
        type=int,
        help="Ending index in the dataset (exclusive). If not specified, processes until the end.",
    )
    parser.add_argument(
        "--use-hf",
        action="store_true",
        help="Use HuggingFace pipeline instead of vLLM",
    )
    args = parser.parse_args()

    gpu_ids = None
    if args.gpu_ids:
        gpu_ids = [int(gpu_id) for gpu_id in args.gpu_ids.split(",")]

    system_prompt = load_system_prompt(args.config)
    dataset = load_from_disk(args.dataset_path)

    sampling_params = None
    if args.use_hf:
        model_instance = setup_hf_pipeline(
            model_name=args.model,
            gpu_ids=gpu_ids,
        )
    else:
        model_instance, sampling_params_cls = setup_llm(
            model_name=args.model,
            tensor_parallel_size=args.tensor_parallel_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            gpu_ids=gpu_ids,
        )
        sampling_params = sampling_params_cls(
            max_tokens=128000,
            temperature=0.7,
            top_p=0.95,
        )

    process_dataset(
        dataset=dataset,
        system_prompt=system_prompt,
        output_file=args.output_file,
        start_index=args.start_index,
        end_index=args.end_index,
        use_hf=args.use_hf,
        model_instance=model_instance,
        sampling_params=sampling_params,
    )


if __name__ == "__main__":
    main()