Sanyam Bhutani 2 months ago
parent
commit
b4efdd135b
1 changed file with 286 additions and 0 deletions
      end-to-end-use-cases/data-tool/scripts/add_cot_vllm.py

end-to-end-use-cases/data-tool/scripts/add_cot_vllm.py

@@ -0,0 +1,286 @@
+# VLLM_WORKER_MULTIPROC_METHOD=spawn python scripts/add_cot_vllm.py --model_id meta-llama/Llama-3.3-70B-Instruct --dataset-path datasets/2_ready_for_CoT/func-calling-multi-turn-final/ --config configs/config.yaml --output-path "datasets/3_CoT_added/func-calling-multi-turn-final/" --batch-size 96 --max-seq-len 16000
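+#
+# Batch-annotates a conversation dataset with chain-of-thought (CoT)
+# reasoning using a vLLM-served instruct model. The spawn worker method in
+# the example above avoids CUDA re-initialization errors in forked workers.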
+
+
+import argparse
+import json
+import os
+import random
+import re
+from typing import Any, Dict, List, Optional
+
+import torch
+import yaml
+from datasets import Dataset, load_from_disk
+from tqdm import tqdm
+from transformers import AutoProcessor
+from vllm import LLM, SamplingParams
+
+
+class LLM_Singleton:
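+    """Process-wide singleton around a vLLM engine so the (expensive) model
+    load happens at most once; later instantiations return the cached
+    instance and their constructor arguments are ignored."""
+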
+    _instance = None
+
+    def __new__(
+        cls,
+        model_id,
+        max_model_len=64000,
+        max_num_seqs=16,
+        enforce_eager=True,
+        debug=False,
+    ):
+        if cls._instance is None:
+            cls._instance = super(LLM_Singleton, cls).__new__(cls)
+            cls._instance._initialize(
+                model_id,
+                tensor_parallel_size=torch.cuda.device_count(),
+                max_model_len=max_model_len,
+                max_num_seqs=max_num_seqs,
+                enforce_eager=enforce_eager,
+                debug=debug,
+            )
+        return cls._instance
+
+    def _initialize(
+        self,
+        model_id,
+        tensor_parallel_size=1,
+        max_model_len=64000,
+        max_num_seqs=16,
+        enforce_eager=True,
+        debug=False,
+    ):
+        if debug:
+            print(
+                f"Initializing LLM with params: {model_id}, {tensor_parallel_size}, {max_model_len}"
+            )
+
+        self.llm = LLM(
+            model_id,
+            tensor_parallel_size=tensor_parallel_size,
+            max_model_len=max_model_len,
+            max_num_seqs=max_num_seqs,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.95,
+        )
+        self.processor = AutoProcessor.from_pretrained(model_id)
+
+
+def load_system_prompt(yaml_path: str) -> str:
+    """Load system prompt from YAML config."""
+    with open(yaml_path, "r") as f:
+        config = yaml.safe_load(f)
+    return config["system_prompt"]
+
+
+def create_chat_message(system_prompt: str, conversation: str) -> List[Dict[str, Any]]:
+    """Create properly formatted chat messages."""
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": conversation},
+    ]
+    return messages
+
+
+def parse_json_output(output_text: str) -> Optional[List[Dict[str, str]]]:
+    """Parse and clean model output to ensure valid JSON."""
+    output_text = output_text.strip()
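+    # Greedily capture the outermost [...] span so any prose the model emits
+    # before or after the JSON array is discarded.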
+    json_match = re.search(r"\[.*\]", output_text, re.DOTALL)
+
+    if json_match:
+        output_text = json_match.group(0)
+
+    try:
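+        # Some models return the array as a JSON-encoded string (wrapped in
+        # quotes); decode once to unwrap it before parsing the array itself.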
+        if output_text.startswith('"') and output_text.endswith('"'):
+            output_text = json.loads(output_text)
+        result = json.loads(output_text)
+        if not isinstance(result, list):
+            return None
+
+        # Clean the result to remove 'tool': None entries
+        cleaned_result = []
+        for item in result:
+            cleaned_item = {
+                k: v for k, v in item.items() if k != "tool" or v is not None
+            }
+            cleaned_result.append(cleaned_item)
+
+        return cleaned_result
+    except json.JSONDecodeError as e:
+        print(f"Error parsing output: {e}")
+        return None
+
+
+def process_dataset(
+    dataset,
+    system_prompt: str,
+    start_index: int = 0,
+    end_index: Optional[int] = None,
+    n_samples: int = 0,
+    model_instance: Any = None,
+    batch_size: int = 16,
+    max_seq_len: int = 64000,
+) -> List[Dict]:
+    """Process the dataset in parallel batches and return results."""
+    if end_index is None:
+        end_index = len(dataset)
+    else:
+        end_index = min(end_index, len(dataset))
+
+    # Handle random sampling
+    dataset_size = end_index - start_index
+    if n_samples > 0:
+        # If n_samples is larger than dataset size, use full dataset size
+        n_samples = min(n_samples, dataset_size)
+        # Generate random indices within the specified range
+        indices = random.sample(range(start_index, end_index), n_samples)
+        dataset_slice = dataset.select(indices)
+    else:
+        # If no sampling requested, use the full range
+        dataset_slice = dataset.select(range(start_index, end_index))
+
+    results = []
+
+    for i in tqdm(range(0, len(dataset_slice), batch_size), desc="Processing batches"):
+        batch_slice = dataset_slice.select(
+            range(i, min(i + batch_size, len(dataset_slice)))
+        )
+
+        try:
+            batch_inputs = []
+            for item in batch_slice:
+                conversation_str = json.dumps(
+                    item["conversations"], ensure_ascii=False, indent=2
+                )
+                messages = create_chat_message(system_prompt, conversation_str)
+                input_text = model_instance.processor.apply_chat_template(
+                    messages, add_generation_prompt=True, tokenize=False
+                )
+                batch_inputs.append(input_text)
+
+            # max_tokens caps generated tokens per sequence (not per batch)
+            sampling_params = SamplingParams(
+                max_tokens=max_seq_len, temperature=0.1, top_p=0.95
+            )
+
+            outputs = model_instance.llm.generate(batch_inputs, sampling_params)
+
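+            # vLLM returns one output per prompt, in input order, so zipping
+            # with the batch slice keeps rows and generations aligned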
+            for item, output in zip(batch_slice, outputs):
+                enhanced_convos = parse_json_output(output.outputs[0].text.strip())
+                if enhanced_convos is None:
+                    print(
+                        f"Warning: Failed to parse output for item {item.get('id', 'unknown')}"
+                    )
+                    enhanced_convos = item["conversations"]
+
+                results.append(
+                    {
+                        "id": item["id"],
+                        "conversations": item["conversations"],
+                        "cot_conversations": enhanced_convos,
+                    }
+                )
+
+        except Exception as e:
+            print(f"Error processing batch starting at {i}: {str(e)}")
+            continue
+
+    return results
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Process dataset to enhance conversations with CoT reasoning"
+    )
+    parser.add_argument(
+        "--model_id", type=str, required=True, help="Model name or path"
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="Path to YAML config with system prompt",
+    )
+    parser.add_argument(
+        "--output-path",
+        type=str,
+        required=True,
+        help="Output dataset directory path",
+    )
+    parser.add_argument(
+        "--dataset-path", type=str, required=True, help="Input dataset path"
+    )
+    parser.add_argument(
+        "--start-index", type=int, default=0, help="Starting index (inclusive)"
+    )
+    parser.add_argument("--end-index", type=int, help="Ending index (exclusive)")
+    parser.add_argument(
+        "--n-samples",
+        type=int,
+        default=0,
+        help="Number of random samples to process. If 0, process all samples in range",
+    )
+    parser.add_argument(
+        "--batch-size", type=int, default=16, help="Batch size for processing"
+    )
+    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
+    parser.add_argument(
+        "--max-seq-len",
+        type=int,
+        default=64000,
+        help="Maximum sequence length for generation per batch",
+    )
+    parser.add_argument(
+        "--max-num-seqs",
+        type=int,
+        default=16,
+        help="Maximum number of sequences in batch",
+    )
+    parser.add_argument(
+        "--enforce-eager",
+        action="store_true",
+        help="Whether to enforce eager execution",
+    )
+    args = parser.parse_args()
+
+    # Set spawn method for multiprocessing
+    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+    # Load system prompt and dataset
+    system_prompt = load_system_prompt(args.config)
+    dataset = load_from_disk(args.dataset_path)
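+    # load_from_disk returns a DatasetDict (a dict subclass) when the saved
+    # dataset has named splits; default to the train split in that case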
+    if isinstance(dataset, dict):
+        dataset = dataset["train"]
+
+    # Initialize VLLM instance
+    model_instance = LLM_Singleton(
+        model_id=args.model_id,
+        max_model_len=args.max_seq_len,
+        max_num_seqs=args.max_num_seqs,
+        enforce_eager=args.enforce_eager,
+        debug=args.debug,
+    )
+
+    # Process dataset and get results
+    results = process_dataset(
+        dataset=dataset,
+        system_prompt=system_prompt,
+        start_index=args.start_index,
+        end_index=args.end_index,
+        n_samples=args.n_samples,
+        model_instance=model_instance,
+        batch_size=args.batch_size,
+        max_seq_len=args.max_seq_len,
+    )
+
+    # Convert results to HuggingFace dataset
+    output_dataset = Dataset.from_dict(
+        {
+            "id": [r["id"] for r in results],
+            "conversations": [r["conversations"] for r in results],
+            "cot_conversations": [r["cot_conversations"] for r in results],
+        }
+    )
+
+    # Save dataset
+    output_dataset.save_to_disk(args.output_path)
+
+
+if __name__ == "__main__":
+    main()