@@ -54,3 +54,58 @@ def format_prompt(system_prompt: str, conversation: str) -> str:
     )
 
 
+def process_dataset(
+    dataset,
+    llm: LLM,
+    system_prompt: str,
+    output_file: str,
+    start_index: int = 0,
+    end_index: int = None,
+    max_new_tokens: int = 128000,
+) -> None:
+    """Process the dataset using vLLM."""
+    sampling_params = SamplingParams(
+        max_tokens=max_new_tokens,
+        temperature=0.7,
+        top_p=0.95,
+    )
+
+    # Handle end_index
+    if end_index is None:
+        end_index = len(dataset)
+    else:
+        end_index = min(end_index, len(dataset))
+
+    # Validate indices
+    if start_index < 0:
+        start_index = 0
+    if start_index >= len(dataset):
+        raise ValueError(
+            f"Start index {start_index} is larger than dataset size {len(dataset)}"
+        )
+    if start_index >= end_index:
+        raise ValueError(
+            f"Start index {start_index} must be less than end index {end_index}"
+        )
+
+    # Select the specified range
+    dataset_slice = dataset.select(range(start_index, end_index))
+
+    # Process examples one at a time
+    with open(output_file, "w") as f:
+        for item in tqdm(
+            dataset_slice, desc=f"Processing rows {start_index} to {end_index}"
+        ):
+            # Format the prompt as a single string
+            prompt = format_prompt(system_prompt, item["conversations"])
+
+            # Generate the response
+            output = llm.generate(prompt, sampling_params)[0]
+
+            print(output.outputs[0].text)
+            # Save the result
+            result = {
+                "id": item["id"],
+                "conversations": output.outputs[0].text,
+            }
+            f.write(json.dumps(result) + "\n")
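
For context, a minimal sketch of how the new process_dataset helper might be driven, assuming vLLM and the Hugging Face datasets library are installed; the dataset path, model name, system prompt, and output file below are illustrative placeholders, not values taken from the patch.

from datasets import load_dataset
from vllm import LLM

# Placeholder dataset and model; process_dataset expects rows with
# "id" and "conversations" fields, so any dataset passed in must provide them.
dataset = load_dataset("example/sharegpt-style-dataset", split="train")
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")

process_dataset(
    dataset,
    llm,
    system_prompt="You are a helpful assistant.",
    output_file="outputs.jsonl",
    start_index=0,
    end_index=100,
)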