Browse Source

Update annotating.py

Sanyam Bhutani 5 months ago
parent
commit
2a88f91c2a
1 changed files with 51 additions and 0 deletions
  1. 51 0
      end-to-end-use-cases/data-tool/dataprep-scripts/annotating.py

+ 51 - 0
end-to-end-use-cases/data-tool/dataprep-scripts/annotating.py

@@ -3,3 +3,54 @@ import json
 import os
 from typing import Dict, List
 
+import yaml
+from datasets import load_dataset, load_from_disk
+from tqdm import tqdm
+from vllm import LLM, SamplingParams
+
+
+def load_system_prompt(yaml_path: str) -> str:
+    """Load system prompt from a YAML file."""
+    with open(yaml_path, "r") as f:
+        config = yaml.safe_load(f)
+    return config["system_prompt"]
+
+
+def setup_llm(
+    model_name: str,
+    tensor_parallel_size: int = 1,
+    gpu_memory_utilization: float = 0.9,
+    max_model_len: int = 128000,
+    gpu_ids: List[int] = None,
+) -> LLM:
+    """Initialize the vLLM LLM with specified parameters for multi-GPU support."""
+
+    # If specific GPUs are requested, set CUDA_VISIBLE_DEVICES
+    if gpu_ids is not None:
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
+
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=tensor_parallel_size,
+        gpu_memory_utilization=gpu_memory_utilization,
+        max_model_len=max_model_len,
+    )
+    return llm
+
+
+def create_messages(system_prompt: str, conversation: str) -> List[Dict[str, str]]:
+    """Create the messages list for the model input."""
+    return [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": conversation},
+    ]
+
+
+def format_prompt(system_prompt: str, conversation: str) -> str:
+    """Format the system prompt and conversation into the specific chat template format."""
+    return (
+        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{system_prompt}<|eot_id|>"
+        f"<|start_header_id|>user<|end_header_id|>{conversation}"
+    )
+
+