@@ -1,12 +1,12 @@
 import argparse
 import json
 import os
-from typing import Dict, List
+from typing import Any, Dict, List, Union
 
+import torch
 import yaml
 from datasets import load_dataset, load_from_disk
 from tqdm import tqdm
-from vllm import LLM, SamplingParams
 
 
 def load_system_prompt(yaml_path: str) -> str:
@@ -22,10 +22,11 @@ def setup_llm(
     gpu_memory_utilization: float = 0.9,
     max_model_len: int = 128000,
     gpu_ids: List[int] = None,
-) -> LLM:
+):
     """Initialize the vLLM LLM with specified parameters for multi-GPU support."""
+    from vllm import LLM, SamplingParams
 
-    # If specific GPUs are requested, set CUDA_VISIBLE_DEVICES
     if gpu_ids is not None:
         os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
 
@@ -35,7 +36,27 @@ def setup_llm(
         gpu_memory_utilization=gpu_memory_utilization,
         max_model_len=max_model_len,
     )
-    return llm
+    return llm, SamplingParams
+
+
+def setup_hf_pipeline(
+    model_name: str,
+    gpu_ids: List[int] = None,
+):
+    """Initialize the HuggingFace pipeline."""
+    import transformers
+
+    if gpu_ids is not None:
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
+
+    pipeline = transformers.pipeline(
+        "text-generation",
+        model=model_name,
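+        # bfloat16 halves weight memory relative to float32; assumes the GPUs support it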
+        model_kwargs={"torch_dtype": torch.bfloat16},
+        device_map="auto",
+    )
+    return pipeline
 
 
 def create_messages(system_prompt: str, conversation: str) -> List[Dict[str, str]]:
@@ -54,22 +75,47 @@ def format_prompt(system_prompt: str, conversation: str) -> str:
     )
 
 
+def process_with_vllm(
+    item: Dict,
+    llm: Any,
+    system_prompt: str,
+    sampling_params: Any,
+) -> str:
+    """Process a single item using vLLM."""
+    prompt = format_prompt(system_prompt, item["conversations"])
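+    # generate() returns one RequestOutput per prompt; unwrap the single result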
+    output = llm.generate(prompt, sampling_params)[0]
+    return output.outputs[0].text
+
+
+def process_with_hf(
+    item: Dict,
+    pipeline: Any,
+    system_prompt: str,
+    max_new_tokens: int,
+) -> str:
+    """Process a single item using HuggingFace pipeline."""
+    messages = create_messages(system_prompt, item["conversations"])
+    outputs = pipeline(
+        messages,
+        max_new_tokens=max_new_tokens,
+    )
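+    # Chat pipelines return the full message list; keep only the final assistant turn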
+    return outputs[0]["generated_text"][-1]["content"]
+
+
 def process_dataset(
     dataset,
-    llm: LLM,
     system_prompt: str,
     output_file: str,
     start_index: int = 0,
     end_index: int = None,
     max_new_tokens: int = 128000,
+    use_hf: bool = False,
+    model_instance: Any = None,
+    sampling_params: Any = None,
 ) -> None:
-    """Process the dataset using vLLM."""
-    sampling_params = SamplingParams(
-        max_tokens=max_new_tokens,
-        temperature=0.7,
-        top_p=0.95,
-    )
-
+    """Process the dataset using either vLLM or HuggingFace pipeline."""
     # Handle end_index
     if end_index is None:
         end_index = len(dataset)
@@ -96,24 +142,28 @@ def process_dataset(
         for item in tqdm(
             dataset_slice, desc=f"Processing rows {start_index} to {end_index}"
         ):
-            # Format the prompt as a single string
-            prompt = format_prompt(system_prompt, item["conversations"])
-
-            # Generate the response
-            output = llm.generate(prompt, sampling_params)[0]
+            # Generate the response using the appropriate method
+            if use_hf:
+                cot_response = process_with_hf(
+                    item, model_instance, system_prompt, max_new_tokens
+                )
+            else:
+                cot_response = process_with_vllm(
+                    item, model_instance, system_prompt, sampling_params
+                )
 
-            print(output.outputs[0].text)
-            # Save the result
+            # Save the result with both original and CoT conversations
            result = {
                 "id": item["id"],
-                "conversations": output.outputs[0].text,
+                "conversations": item["conversations"],  # Keep original conversations
+                "cot_conversations": cot_response,  # Add new CoT conversations
             }
             f.write(json.dumps(result) + "\n")
 
 
 def main():
     parser = argparse.ArgumentParser(
-        description="Process dataset using vLLM with multi-GPU support"
+        description="Process dataset using vLLM or HuggingFace pipeline with multi-GPU support"
     )
     parser.add_argument(
         "--model", type=str, required=True, help="Name or path of the model to use"
@@ -150,7 +200,6 @@ def main():
         default=0.9,
         help="Target GPU memory utilization (0.0 to 1.0)",
     )
-    # Add new arguments for range specification
     parser.add_argument(
         "--start-index",
         type=int,
@@ -162,6 +211,11 @@ def main():
         type=int,
         help="Ending index in the dataset (exclusive). If not specified, processes until the end.",
     )
+    parser.add_argument(
+        "--use-hf",
+        action="store_true",
+        help="Use HuggingFace pipeline instead of vLLM",
+    )
     args = parser.parse_args()
 
     # Parse GPU IDs if provided
@@ -174,24 +228,38 @@ def main():
 
     # Load dataset
     dataset = load_from_disk(args.dataset_path)
-    dataset = dataset.select(range(0, 2000))
 
-    # Initialize vLLM with multi-GPU support
-    llm = setup_llm(
-        model_name=args.model,
-        tensor_parallel_size=args.tensor_parallel_size,
-        gpu_memory_utilization=args.gpu_memory_utilization,
-        gpu_ids=gpu_ids,
-    )
+    # Initialize the appropriate model instance based on mode
+    sampling_params = None
+    if args.use_hf:
+        model_instance = setup_hf_pipeline(
+            model_name=args.model,
+            gpu_ids=gpu_ids,
+        )
+    else:
+        model_instance, sampling_params = setup_llm(
+            model_name=args.model,
+            tensor_parallel_size=args.tensor_parallel_size,
+            gpu_memory_utilization=args.gpu_memory_utilization,
+            gpu_ids=gpu_ids,
+        )
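+        # setup_llm returns the SamplingParams class; instantiate it here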
+        sampling_params = sampling_params(
+            max_tokens=128000,
+            temperature=0.7,
+            top_p=0.95,
+        )
 
     # Process dataset
     process_dataset(
         dataset=dataset,
-        llm=llm,
         system_prompt=system_prompt,
         output_file=args.output_file,
         start_index=args.start_index,
         end_index=args.end_index,
+        use_hf=args.use_hf,
+        model_instance=model_instance,
+        sampling_params=sampling_params,
     )
 
 