فهرست منبع

Update annotating.py

Sanyam Bhutani 7 ماه پیش
والد
کامیت
a77b738898
1فایلهای تغییر یافته به همراه55 افزوده شده و 0 حذف شده
  1. 55 0
      end-to-end-use-cases/data-tool/dataprep-scripts/annotating.py

+ 55 - 0
end-to-end-use-cases/data-tool/dataprep-scripts/annotating.py

@@ -54,3 +54,58 @@ def format_prompt(system_prompt: str, conversation: str) -> str:
     )
 
 
+def process_dataset(
+    dataset,
+    llm: LLM,
+    system_prompt: str,
+    output_file: str,
+    start_index: int = 0,
+    end_index: int = None,
+    max_new_tokens: int = 128000,
+) -> None:
+    """Process the dataset using vLLM."""
+    sampling_params = SamplingParams(
+        max_tokens=max_new_tokens,
+        temperature=0.7,
+        top_p=0.95,
+    )
+
+    # Handle end_index
+    if end_index is None:
+        end_index = len(dataset)
+    else:
+        end_index = min(end_index, len(dataset))
+
+    # Validate indices
+    if start_index < 0:
+        start_index = 0
+    if start_index >= len(dataset):
+        raise ValueError(
+            f"Start index {start_index} is larger than dataset size {len(dataset)}"
+        )
+    if start_index >= end_index:
+        raise ValueError(
+            f"Start index {start_index} must be less than end index {end_index}"
+        )
+
+    # Select the specified range
+    dataset_slice = dataset.select(range(start_index, end_index))
+
+    # Process examples one at a time
+    with open(output_file, "w") as f:
+        for item in tqdm(
+            dataset_slice, desc=f"Processing rows {start_index} to {end_index}"
+        ):
+            # Format the prompt as a single string
+            prompt = format_prompt(system_prompt, item["conversations"])
+
+            # Generate the response
+            output = llm.generate(prompt, sampling_params)[0]
+
+            print(output.outputs[0].text)
+            # Save the result
+            result = {
+                "id": item["id"],
+                "conversations": output.outputs[0].text,
+            }
+            f.write(json.dumps(result) + "\n")