add_cot.py

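"""Generate model completions for dataset conversations with vLLM.

Reads a system prompt from a YAML config, formats each row's
"conversations" field into a Llama-3-style chat prompt, generates a
completion per row, and writes one {"id", "conversations"} record per
line to a JSONL output file.
"""
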
import argparse
import json
import os
from typing import Dict, List, Optional

import yaml
from datasets import load_from_disk
from tqdm import tqdm
from vllm import LLM, SamplingParams


def load_system_prompt(yaml_path: str) -> str:
    """Load the system prompt from a YAML file."""
    with open(yaml_path, "r") as f:
        config = yaml.safe_load(f)
    return config["system_prompt"]
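

# Expected YAML layout for --config (a minimal sketch; only the
# "system_prompt" key is read, the prompt text is a placeholder):
#
#   system_prompt: |
#     You are a helpful assistant. Think step by step before answering.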


def setup_llm(
    model_name: str,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.9,
    max_model_len: int = 128000,
    gpu_ids: Optional[List[int]] = None,
) -> LLM:
    """Initialize the vLLM engine with specified parameters for multi-GPU support."""
    # If specific GPUs are requested, set CUDA_VISIBLE_DEVICES before
    # the engine initializes CUDA
    if gpu_ids is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
    )
    return llm


def create_messages(system_prompt: str, conversation: str) -> List[Dict[str, str]]:
    """Create the messages list for the model input.

    Currently unused: `format_prompt` builds the raw prompt string instead.
    """
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": conversation},
    ]


def format_prompt(system_prompt: str, conversation: str) -> str:
    """Format the system prompt and conversation into the Llama 3 chat template."""
    # Close the user turn with <|eot_id|> and open an assistant header so the
    # model generates a reply instead of continuing the user's message.
    return (
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n"
        f"{conversation}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n"
    )


def process_dataset(
    dataset,
    llm: LLM,
    system_prompt: str,
    output_file: str,
    start_index: int = 0,
    end_index: Optional[int] = None,
    max_new_tokens: int = 128000,
) -> None:
    """Process the dataset using vLLM."""
    sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.95,
    )
    # Handle end_index
    if end_index is None:
        end_index = len(dataset)
    else:
        end_index = min(end_index, len(dataset))
    # Validate indices
    if start_index < 0:
        start_index = 0
    if start_index >= len(dataset):
        raise ValueError(
            f"Start index {start_index} is larger than dataset size {len(dataset)}"
        )
    if start_index >= end_index:
        raise ValueError(
            f"Start index {start_index} must be less than end index {end_index}"
        )
    # Select the specified range
    dataset_slice = dataset.select(range(start_index, end_index))
    # Process examples one at a time. Note: "w" truncates any existing
    # output file, so resumed runs should write to a fresh path.
    with open(output_file, "w") as f:
        for item in tqdm(
            dataset_slice, desc=f"Processing rows {start_index} to {end_index}"
        ):
            # Format the prompt as a single string
            prompt = format_prompt(system_prompt, item["conversations"])
            # Generate the response
            output = llm.generate(prompt, sampling_params)[0]
            # Echo the generation for quick inspection
            print(output.outputs[0].text)
            # Save the result as one JSON object per line
            result = {
                "id": item["id"],
                "conversations": output.outputs[0].text,
            }
            f.write(json.dumps(result) + "\n")


def main():
    parser = argparse.ArgumentParser(
        description="Process dataset using vLLM with multi-GPU support"
    )
    parser.add_argument(
        "--model", type=str, required=True, help="Name or path of the model to use"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config file containing the system prompt",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="processed_outputs.jsonl",
        help="Output file path",
    )
    parser.add_argument(
        "--dataset-path", type=str, required=True, help="Path to the dataset"
    )
    parser.add_argument(
        "--gpu-ids",
        type=str,
        help="Comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=1,
        help="Number of GPUs to use for tensor parallelism",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.9,
        help="Target GPU memory utilization (0.0 to 1.0)",
    )
    # Arguments for range specification
    parser.add_argument(
        "--start-index",
        type=int,
        default=0,
        help="Starting index in the dataset (inclusive)",
    )
    parser.add_argument(
        "--end-index",
        type=int,
        help="Ending index in the dataset (exclusive). If not specified, processes until the end.",
    )
    args = parser.parse_args()
    # Parse GPU IDs if provided
    gpu_ids = None
    if args.gpu_ids:
        gpu_ids = [int(gpu_id) for gpu_id in args.gpu_ids.split(",")]
    # Load system prompt from YAML
    system_prompt = load_system_prompt(args.config)
    # Load dataset. NOTE: the run is hard-capped at the first 2000 rows;
    # min() guards against datasets smaller than the cap.
    dataset = load_from_disk(args.dataset_path)
    dataset = dataset.select(range(min(2000, len(dataset))))
    # Initialize vLLM with multi-GPU support
    llm = setup_llm(
        model_name=args.model,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        gpu_ids=gpu_ids,
    )
    # Process dataset
    process_dataset(
        dataset=dataset,
        llm=llm,
        system_prompt=system_prompt,
        output_file=args.output_file,
        start_index=args.start_index,
        end_index=args.end_index,
    )


if __name__ == "__main__":
    main()
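

# Example invocation (a sketch; the model ID and paths below are placeholders):
#
#   python add_cot.py \
#       --model meta-llama/Meta-Llama-3-8B-Instruct \
#       --config config.yaml \
#       --dataset-path ./my_dataset \
#       --output-file processed_outputs.jsonl \
#       --gpu-ids 0,1 --tensor-parallel-size 2 \
#       --start-index 0 --end-index 1000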