
Merge branch 'ft-fw' into ft_fw_s

Suraj Subramanian 1 month ago
parent
commit
b8cc256c0a

+ 0 - 1
src/finetune_pipeline/data/data_loader.py

@@ -107,7 +107,6 @@ def load_dataset(
 
     if not data_path:
         raise ValueError("data_path must be provided")
-
     dataset = None
     if is_local:
         # Load from local disk
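
The only change to data_loader.py is dropping a blank line inside load_dataset, which validates data_path before branching on is_local. A hedged usage sketch, assuming data_path and is_local are keyword parameters (the full signature is not shown in this diff, and the path is hypothetical):

    from finetune_pipeline.data.data_loader import load_dataset

    # Raises ValueError("data_path must be provided") when data_path is empty.
    dataset = load_dataset(data_path="datasets/train.jsonl", is_local=True)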

+ 4 - 5
src/finetune_pipeline/inference/run_inference.py

@@ -1,8 +1,8 @@
 import argparse
 import json
 import logging
-from typing import Any, Dict, List, Optional, TypedDict, Union
 import os
+from typing import Any, Dict, List, Optional, TypedDict, Union
 
 import requests
 from tqdm import tqdm
@@ -17,6 +17,8 @@ from ..data.data_loader import (
     save_formatted_data,
 )
 
+logging.basicConfig(level=logging.INFO)
+
 # Set up logging
 logger = logging.getLogger(__name__)
 
@@ -94,7 +96,6 @@ def load_inference_data(
         try:
             logger.info("Loading raw data...")
             data = load_data(path, is_local, **dataset_kwargs)
-
             # Apply sample limit if specified
             if max_samples and hasattr(data, "__len__") and len(data) > max_samples:
                 logger.info(
@@ -131,7 +132,6 @@ def load_inference_data(
                 logger.debug(
                     f"Sample conversation: {sample_conv.messages[:2] if hasattr(sample_conv, 'messages') else sample_conv}"
                 )
-
         except Exception as e:
             logger.error(f"Failed to convert data to conversations: {e}")
             logger.error(f"Column mapping: {column_mapping}")
@@ -405,7 +405,7 @@ def main():
 
     config = read_config(args.config)
     inference_config = config.get("inference", {})
-    formatter_config = config.get("formatter", {})
+    formatter_config = config.get("data", {})
 
     # Model parameters
     model_path = inference_config.get("model_path", None)
@@ -434,7 +434,6 @@ def main():
     inference_data_kwargs = inference_config.get("inference_data_kwargs", {})
 
     inference_data = load_inference_data(inference_data_kwargs, formatter_config)
-
     results = run_vllm_batch_inference_on_dataset(
         inference_data,
         model_path,

+ 163 - 162
src/finetune_pipeline/run_pipeline.py

@@ -36,11 +36,12 @@ logger = logging.getLogger(__name__)
 from finetune_pipeline.data.data_loader import load_and_format_data, read_config
 from finetune_pipeline.finetuning.run_finetuning import run_torch_tune
 
-from finetune_pipeline.inference.run_inference import (
-    run_vllm_batch_inference_on_dataset,
-    save_inference_results,
-)
-from finetune_pipeline.inference.start_vllm_server import start_vllm_server
+# from finetune_pipeline.inference.run_inference import (
+#     run_vllm_batch_inference_on_dataset,
+#     save_inference_results,
+# )
+
+# from finetune_pipeline.inference.start_vllm_server import start_vllm_server
 
 
 def run_data_loading(config_path: str) -> Tuple[List[str], List[str]]:
@@ -57,7 +58,7 @@ def run_data_loading(config_path: str) -> Tuple[List[str], List[str]]:
 
     # Read the configuration
     config = read_config(config_path)
-    formatter_config = config.get("formatter", {})
+    data_config = config.get("data", {})
     output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
 
     # Create the output directory if it doesn't exist
@@ -66,7 +67,7 @@ def run_data_loading(config_path: str) -> Tuple[List[str], List[str]]:
     # Load and format the data
     try:
         formatted_data_paths, conversation_data_paths = load_and_format_data(
-            formatter_config, output_dir
+            data_config, output_dir
         )
         logger.info(f"Data loading and formatting complete. Saved to {output_dir}")
         logger.info(f"Formatted data paths: {formatted_data_paths}")
@@ -133,149 +134,149 @@ def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
         raise
 
 
-def run_vllm_server(config_path: str, model_path: str) -> str:
-    """
-    Start the vLLM server.
-
-    Args:
-        config_path: Path to the configuration file
-        model_path: Path to the fine-tuned model
-
-    Returns:
-        URL of the vLLM server
-    """
-    logger.info("=== Step 3: Starting vLLM Server ===")
-
-    # Read the configuration
-    config = read_config(config_path)
-    inference_config = config.get("inference", {})
-
-    model_path = inference_config.get(
-        "model_path", "/home/ubuntu/yash-workspace/medgemma-4b-it"
-    )
-
-    # # Update the model path in the inference config
-    # inference_config["model_path"] = model_path
-
-    # Extract server parameters
-    port = inference_config.get("port", 8000)
-    host = inference_config.get("host", "0.0.0.0")
-    tensor_parallel_size = inference_config.get("tensor_parallel_size", 1)
-    max_model_len = inference_config.get("max_model_len", 4096)
-    max_num_seqs = inference_config.get("max_num_seqs", 256)
-    quantization = inference_config.get("quantization")
-    gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.9)
-    enforce_eager = inference_config.get("enforce_eager", False)
-
-    # Start the server in a separate process
-    try:
-        logger.info(f"Starting vLLM server with model {model_path}")
-        result = start_vllm_server(
-            model_path,
-            port,
-            host,
-            tensor_parallel_size,
-            max_model_len,
-            max_num_seqs,
-            quantization,
-            gpu_memory_utilization,
-            enforce_eager,
-        )
-        if result.returncode == 0:
-            server_url = f"http://{host}:{port}/v1"
-            logger.info(f"vLLM server started at {server_url}")
-            return server_url
-        else:
-            logger.error(f"vLLM server failed to start")
-            raise RuntimeError("vLLM server failed to start")
-    except Exception as e:
-        logger.error(f"Error starting vLLM server: {e}")
-        raise
-
-
-def run_inference(
-    config_path: str, formatted_data_paths: List[str], model_path: str = ""
-) -> str:
-    """
-    Run inference on the fine-tuned model.
-
-    Args:
-        config_path: Path to the configuration file
-        formatted_data_paths: Paths to the formatted data (for compatibility)
-
-    Returns:
-        Path to the inference results
-    """
-    logger.info("=== Step 4: Running Inference ===")
-
-    config = read_config(config_path)
-    inference_config = config.get("inference", {})
-    formatter_config = config.get("formatter", {})
-    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
-
-    # Model parameters
-    if model_path == "":
-        model_path = inference_config.get("model_path", None)
-    if model_path is None:
-        raise ValueError("model_path must be specified in the config")
-
-    # Get inference data configuration
-    inference_data_kwargs = inference_config.get("inference_data_kwargs", {})
-    if not inference_data_kwargs or not inference_data_kwargs.get("data_path"):
-        raise ValueError(
-            "inference_data_kwargs with data_path must be specified in config"
-        )
-
-    # Performance parameters
-    gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.95)
-    max_model_len = inference_config.get("max_model_len", 512)
-
-    # Generation parameters
-    max_tokens = inference_config.get("max_tokens", 100)
-    temperature = inference_config.get("temperature", 0.0)
-    top_p = inference_config.get("top_p", 1.0)
-    seed = inference_config.get("seed")
-    structured = inference_config.get("structured", False)
-
-    # Load inference data using the new function
-    try:
-        logger.info("Loading inference data...")
-        from finetune_pipeline.inference.run_inference import load_inference_data
-
-        inference_data = load_inference_data(
-            inference_data_kwargs=inference_data_kwargs,
-            formatter_config=formatter_config,
-        )
-        logger.info(f"Loaded {len(inference_data)} samples for inference")
-
-    except Exception as e:
-        logger.error(f"Failed to load inference data: {e}")
-        raise
-
-    # Run inference
-    try:
-        logger.info(f"Running inference with model: {model_path}")
-        results = run_vllm_batch_inference_on_dataset(
-            inference_data=inference_data,
-            model_path=model_path,
-            temperature=temperature,
-            top_p=top_p,
-            max_tokens=max_tokens,
-            seed=seed,
-            structured=structured,
-            gpu_memory_utilization=gpu_memory_utilization,
-            max_model_len=max_model_len,
-        )
-
-        # Save the results
-        results_path = os.path.join(output_dir, "inference_results.json")
-        save_inference_results(results, results_path)
-
-        logger.info(f"Inference complete. Results saved to {results_path}")
-        return results_path
-    except Exception as e:
-        logger.error(f"Error during inference: {e}")
-        raise
+# def run_vllm_server(config_path: str, model_path: str) -> str:
+#     """
+#     Start the vLLM server.
+
+#     Args:
+#         config_path: Path to the configuration file
+#         model_path: Path to the fine-tuned model
+
+#     Returns:
+#         URL of the vLLM server
+#     """
+#     logger.info("=== Step 3: Starting vLLM Server ===")
+
+#     # Read the configuration
+#     config = read_config(config_path)
+#     inference_config = config.get("inference", {})
+
+#     model_path = inference_config.get(
+#         "model_path", "/home/ubuntu/yash-workspace/medgemma-4b-it"
+#     )
+
+#     # # Update the model path in the inference config
+#     # inference_config["model_path"] = model_path
+
+#     # Extract server parameters
+#     port = inference_config.get("port", 8000)
+#     host = inference_config.get("host", "0.0.0.0")
+#     tensor_parallel_size = inference_config.get("tensor_parallel_size", 1)
+#     max_model_len = inference_config.get("max_model_len", 4096)
+#     max_num_seqs = inference_config.get("max_num_seqs", 256)
+#     quantization = inference_config.get("quantization")
+#     gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.9)
+#     enforce_eager = inference_config.get("enforce_eager", False)
+
+#     # Start the server in a separate process
+#     try:
+#         logger.info(f"Starting vLLM server with model {model_path}")
+#         result = start_vllm_server(
+#             model_path,
+#             port,
+#             host,
+#             tensor_parallel_size,
+#             max_model_len,
+#             max_num_seqs,
+#             quantization,
+#             gpu_memory_utilization,
+#             enforce_eager,
+#         )
+#         if result.returncode == 0:
+#             server_url = f"http://{host}:{port}/v1"
+#             logger.info(f"vLLM server started at {server_url}")
+#             return server_url
+#         else:
+#             logger.error(f"vLLM server failed to start")
+#             raise RuntimeError("vLLM server failed to start")
+#     except Exception as e:
+#         logger.error(f"Error starting vLLM server: {e}")
+#         raise
+
+
+# def run_inference(
+#     config_path: str, formatted_data_paths: List[str], model_path: str = ""
+# ) -> str:
+#     """
+#     Run inference on the fine-tuned model.
+
+#     Args:
+#         config_path: Path to the configuration file
+#         formatted_data_paths: Paths to the formatted data (for compatibility)
+
+#     Returns:
+#         Path to the inference results
+#     """
+#     logger.info("=== Step 4: Running Inference ===")
+
+#     config = read_config(config_path)
+#     inference_config = config.get("inference", {})
+#     formatter_config = config.get("formatter", {})
+#     output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
+
+#     # Model parameters
+#     if model_path == "":
+#         model_path = inference_config.get("model_path", None)
+#     if model_path is None:
+#         raise ValueError("model_path must be specified in the config")
+
+#     # Get inference data configuration
+#     inference_data_kwargs = inference_config.get("inference_data_kwargs", {})
+#     if not inference_data_kwargs or not inference_data_kwargs.get("data_path"):
+#         raise ValueError(
+#             "inference_data_kwargs with data_path must be specified in config"
+#         )
+
+#     # Performance parameters
+#     gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.95)
+#     max_model_len = inference_config.get("max_model_len", 512)
+
+#     # Generation parameters
+#     max_tokens = inference_config.get("max_tokens", 100)
+#     temperature = inference_config.get("temperature", 0.0)
+#     top_p = inference_config.get("top_p", 1.0)
+#     seed = inference_config.get("seed")
+#     structured = inference_config.get("structured", False)
+
+#     # Load inference data using the new function
+#     try:
+#         logger.info("Loading inference data...")
+#         from finetune_pipeline.inference.run_inference import load_inference_data
+
+#         inference_data = load_inference_data(
+#             inference_data_kwargs=inference_data_kwargs,
+#             formatter_config=formatter_config,
+#         )
+#         logger.info(f"Loaded {len(inference_data)} samples for inference")
+
+#     except Exception as e:
+#         logger.error(f"Failed to load inference data: {e}")
+#         raise
+
+#     # Run inference
+#     try:
+#         logger.info(f"Running inference with model: {model_path}")
+#         results = run_vllm_batch_inference_on_dataset(
+#             inference_data=inference_data,
+#             model_path=model_path,
+#             temperature=temperature,
+#             top_p=top_p,
+#             max_tokens=max_tokens,
+#             seed=seed,
+#             structured=structured,
+#             gpu_memory_utilization=gpu_memory_utilization,
+#             max_model_len=max_model_len,
+#         )
+
+#         # Save the results
+#         results_path = os.path.join(output_dir, "inference_results.json")
+#         save_inference_results(results, results_path)
+
+#         logger.info(f"Inference complete. Results saved to {results_path}")
+#         return results_path
+#     except Exception as e:
+#         logger.error(f"Error during inference: {e}")
+#         raise
 
 
 def run_pipeline(
@@ -366,18 +367,18 @@ def run_pipeline(
         model_path = os.path.join(output_dir, "finetuned_model")
 
     time.sleep(5)
-    # Step 3: Inference
-    if not skip_inference:
-        try:
-            results_path = run_inference(config_path, formatted_data_paths, model_path)
-            logger.info(
-                f"Pipeline completed successfully. Results saved to {results_path}"
-            )
-        except Exception as e:
-            logger.error(f"Pipeline failed at inference step: {e}")
-            sys.exit(1)
-    else:
-        logger.info("Skipping inference step")
+    # # Step 3: Inference
+    # if not skip_inference:
+    #     try:
+    #         results_path = run_inference(config_path, formatted_data_paths, model_path)
+    #         logger.info(
+    #             f"Pipeline completed successfully. Results saved to {results_path}"
+    #         )
+    #     except Exception as e:
+    #         logger.error(f"Pipeline failed at inference step: {e}")
+    #         sys.exit(1)
+    # else:
+    #     logger.info("Skipping inference step")
 
     logger.info("Pipeline execution complete")