# annotating.py

import argparse
import json
import os
from typing import Dict, List, Optional

import yaml
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from vllm import LLM, SamplingParams


def load_system_prompt(yaml_path: str) -> str:
    """Load system prompt from a YAML file."""
    with open(yaml_path, "r") as f:
        config = yaml.safe_load(f)
    return config["system_prompt"]
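
# The YAML file is expected to carry the prompt under a `system_prompt` key.
# Illustrative shape (an assumption, not a file from this repo):
#
#   system_prompt: |
#     You are a careful annotator. Label each conversation...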


def setup_llm(
    model_name: str,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.9,
    max_model_len: int = 128000,
    gpu_ids: Optional[List[int]] = None,
) -> LLM:
    """Initialize the vLLM LLM with specified parameters for multi-GPU support."""
    # If specific GPUs are requested, set CUDA_VISIBLE_DEVICES.
    # Note: this only takes effect if CUDA has not yet been initialized
    # in this process.
    if gpu_ids is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
    )
    return llm
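
# Illustrative usage (a sketch; the model name, GPU ids, and sampling values
# are assumptions, not this script's actual settings). SamplingParams is
# imported above and pairs with the LLM returned here:
#
#   llm = setup_llm("meta-llama/Meta-Llama-3-8B-Instruct",
#                   tensor_parallel_size=2, gpu_ids=[0, 1])
#   params = SamplingParams(temperature=0.0, max_tokens=512)
#   outputs = llm.generate([prompt], params)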


def create_messages(system_prompt: str, conversation: str) -> List[Dict[str, str]]:
    """Create the messages list for the model input."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": conversation},
    ]


def format_prompt(system_prompt: str, conversation: str) -> str:
    """Format the system prompt and conversation into the chat template format."""
    # Llama-3-style template: close the user turn with <|eot_id|> and open the
    # assistant turn so generation starts from the assistant's reply.
    return (
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{system_prompt}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>{conversation}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>"
    )
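
# Illustrative end-to-end flow (a sketch under assumptions: the YAML path and
# model name below are placeholders, not values taken from this file):
#
#   system_prompt = load_system_prompt("prompts/annotation.yaml")
#   llm = setup_llm("meta-llama/Meta-Llama-3-8B-Instruct")
#   prompt = format_prompt(system_prompt, "user: hi\nassistant: hello")
#   outputs = llm.generate([prompt], SamplingParams(max_tokens=256))
#   print(outputs[0].outputs[0].text)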