run_pipeline.py

#!/usr/bin/env python
"""
End-to-end pipeline for data loading, fine-tuning, and inference.

This script integrates all the modules in the finetune_pipeline package:
1. Data loading and formatting
2. Model fine-tuning
3. vLLM server startup
4. Inference on the fine-tuned model

Example usage:
    python run_pipeline.py --config config.yaml
    python run_pipeline.py --config config.yaml --skip-finetuning --skip-server
    python run_pipeline.py --config config.yaml --only-inference
"""
import argparse
import json
import logging
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Import modules from the finetune_pipeline package
from finetune_pipeline.data.data_loader import load_and_format_data, read_config
from finetune_pipeline.finetuning.run_finetuning import run_torch_tune
from finetune_pipeline.inference.run_inference import run_inference_on_eval_data
from finetune_pipeline.inference.start_vllm_server import start_vllm_server

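# A rough, illustrative sketch of the config.yaml keys this script reads; the
# authoritative schema lives in read_config and the finetune_pipeline package,
# so treat the values below as placeholders rather than documented defaults:
#
#   output_dir: /tmp/finetune-pipeline/
#   formatter:
#     type: torchtune
#   finetuning: {}             # passed through to run_torch_tune
#   inference:
#     model_path: /path/to/model
#     eval_data: /path/to/eval.json
#     host: 0.0.0.0
#     port: 8000
#     tensor_parallel_size: 1
#     max_model_len: 4096
#     max_num_seqs: 256
#     gpu_memory_utilization: 0.9
#     enforce_eager: false
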
def run_data_loading(config_path: str) -> Tuple[List[str], List[str]]:
    """
    Run the data loading and formatting step.

    Args:
        config_path: Path to the configuration file

    Returns:
        Tuple containing lists of paths to the formatted data and conversation data
    """
    logger.info("=== Step 1: Data Loading and Formatting ===")

    # Read the configuration
    config = read_config(config_path)
    formatter_config = config.get("formatter", {})
    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")

    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Load and format the data
    try:
        formatted_data_paths, conversation_data_paths = load_and_format_data(
            formatter_config, output_dir
        )
        logger.info(f"Data loading and formatting complete. Saved to {output_dir}")
        logger.info(f"Formatted data paths: {formatted_data_paths}")
        logger.info(f"Conversation data paths: {conversation_data_paths}")
        return formatted_data_paths, conversation_data_paths
    except Exception as e:
        logger.error(f"Error during data loading and formatting: {e}")
        raise

def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
    """
    Run the fine-tuning step.

    Args:
        config_path: Path to the configuration file
        formatted_data_paths: Paths to the formatted data

    Returns:
        Path to the fine-tuned model
    """
    logger.info("=== Step 2: Model Fine-tuning ===")

    # Read the configuration
    config = read_config(config_path)
    finetuning_config = config.get("finetuning", {})
    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")

    # Get the path to the formatted data for the train split
    train_data_path = None
    for path in formatted_data_paths:
        if "train_" in path:
            train_data_path = path
            break
    if not train_data_path:
        logger.warning("No train split found in formatted data. Using the first file.")
        train_data_path = formatted_data_paths[0]

    # Prepare additional key=value overrides for the fine-tuning run
    kwargs = (
        "dataset=finetune_pipeline.finetuning.custom_sft_dataset "
        "dataset.train_on_input=True "
        f"dataset.dataset_path={train_data_path}"
    )

    # run_torch_tune expects an object with a `kwargs` attribute (as argparse would produce)
    args = argparse.Namespace(kwargs=kwargs)

    # Run the fine-tuning
    try:
        logger.info(f"Starting fine-tuning with data from {train_data_path}")
        run_torch_tune(finetuning_config, args=args)

        # Get the path to the fine-tuned model
        # This is a simplification; in reality, you'd need to extract the actual path
        # from the fine-tuning output
        model_path = os.path.join(output_dir, "finetuned_model")
        logger.info(f"Fine-tuning complete. Model saved to {model_path}")
        return model_path
    except Exception as e:
        logger.error(f"Error during fine-tuning: {e}")
        raise

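# Note on run_finetuning: the `kwargs` string it builds uses torchtune-style
# `key=value` overrides. Assuming run_torch_tune ultimately forwards them to the
# torchtune CLI, they would roughly correspond to extra arguments on a command like
# `tune run <recipe> --config <config> dataset=... dataset.dataset_path=...`;
# the recipe and base config come from the `finetuning` section of the config file.
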
def run_vllm_server(config_path: str, model_path: str) -> str:
    """
    Start the vLLM server.

    Args:
        config_path: Path to the configuration file
        model_path: Path to the fine-tuned model

    Returns:
        URL of the vLLM server
    """
    logger.info("=== Step 3: Starting vLLM Server ===")

    # Read the configuration
    config = read_config(config_path)
    inference_config = config.get("inference", {})

    # Prefer an explicit model_path from the inference config; otherwise serve
    # the fine-tuned model passed in by the pipeline
    model_path = inference_config.get("model_path", model_path)

    # Extract server parameters
    port = inference_config.get("port", 8000)
    host = inference_config.get("host", "0.0.0.0")
    tensor_parallel_size = inference_config.get("tensor_parallel_size", 1)
    max_model_len = inference_config.get("max_model_len", 4096)
    max_num_seqs = inference_config.get("max_num_seqs", 256)
    quantization = inference_config.get("quantization")
    gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.9)
    enforce_eager = inference_config.get("enforce_eager", False)

    # Start the server in a separate process
    try:
        logger.info(f"Starting vLLM server with model {model_path}")
        result = start_vllm_server(
            model_path,
            port,
            host,
            tensor_parallel_size,
            max_model_len,
            max_num_seqs,
            quantization,
            gpu_memory_utilization,
            enforce_eager,
        )
        if result.returncode == 0:
            server_url = f"http://{host}:{port}/v1"
            logger.info(f"vLLM server started at {server_url}")
            return server_url
        else:
            logger.error("vLLM server failed to start")
            raise RuntimeError("vLLM server failed to start")
    except Exception as e:
        logger.error(f"Error starting vLLM server: {e}")
        raise

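# Illustrative helper (not called by the pipeline): one way to block until the
# vLLM OpenAI-compatible server is actually answering requests before running
# inference. The /v1/models endpoint and the timing values are assumptions here,
# not something start_vllm_server guarantees.
def _wait_for_server(server_url: str, timeout: float = 300.0, interval: float = 5.0) -> bool:
    """Poll `{server_url}/models` until it returns HTTP 200 or `timeout` seconds elapse."""
    import urllib.error
    import urllib.request

    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(f"{server_url}/models", timeout=5) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # Server not up yet; retry after a short pause
        time.sleep(interval)
    return False
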
def run_inference(
    config_path: str, server_url: str, formatted_data_paths: List[str]
) -> str:
    """
    Run inference on the fine-tuned model.

    Args:
        config_path: Path to the configuration file
        server_url: URL of the vLLM server
        formatted_data_paths: Paths to the formatted data

    Returns:
        Path to the inference results
    """
    logger.info("=== Step 4: Running Inference ===")

    # Read the configuration
    config = read_config(config_path)
    inference_config = config.get("inference", {})
    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")

    # Get the path to the formatted data for the validation or test split
    eval_data_path = inference_config.get("eval_data")
    if not eval_data_path:
        # Try to find a validation or test split in the formatted data
        for path in formatted_data_paths:
            if "validation_" in path or "test_" in path:
                eval_data_path = path
                break
        if not eval_data_path:
            logger.warning(
                "No validation or test split found in formatted data. Using the first file."
            )
            eval_data_path = formatted_data_paths[0]

    # Extract inference parameters
    model_name = inference_config.get("model_name", "default")
    temperature = inference_config.get("temperature", 0.0)
    top_p = inference_config.get("top_p", 1.0)
    max_tokens = inference_config.get("max_tokens", 100)
    seed = inference_config.get("seed")

    # Run inference
    try:
        logger.info(
            f"Running inference on {eval_data_path} using server at {server_url}"
        )
        results = run_inference_on_eval_data(
            eval_data_path=eval_data_path,
            server_url=server_url,
            is_local=True,  # Assuming the formatted data is local
            model_name=model_name,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            seed=seed,
        )

        # Save the results
        results_path = os.path.join(output_dir, "inference_results.json")
        with open(results_path, "w") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Inference complete. Results saved to {results_path}")
        return results_path
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise

def run_pipeline(
    config_path: str,
    skip_data_loading: bool = False,
    skip_finetuning: bool = False,
    skip_server: bool = False,
    skip_inference: bool = False,
    only_data_loading: bool = False,
    only_finetuning: bool = False,
    only_server: bool = False,
    only_inference: bool = False,
) -> None:
    """
    Run the end-to-end pipeline.

    Args:
        config_path: Path to the configuration file
        skip_data_loading: Whether to skip the data loading step
        skip_finetuning: Whether to skip the fine-tuning step
        skip_server: Whether to skip starting the vLLM server
        skip_inference: Whether to skip the inference step
        only_data_loading: Whether to run only the data loading step
        only_finetuning: Whether to run only the fine-tuning step
        only_server: Whether to run only the vLLM server step
        only_inference: Whether to run only the inference step
    """
    logger.info(f"Starting pipeline with config {config_path}")

    # Check if the config file exists
    if not os.path.exists(config_path):
        logger.error(f"Config file {config_path} does not exist")
        sys.exit(1)

    # Handle "only" flags
    if only_data_loading:
        skip_finetuning = True
        skip_server = True
        skip_inference = True
    elif only_finetuning:
        skip_data_loading = True
        skip_server = True
        skip_inference = True
    elif only_server:
        skip_data_loading = True
        skip_finetuning = True
        skip_inference = True
    elif only_inference:
        skip_data_loading = True
        skip_finetuning = True
        skip_server = True

    # Step 1: Data Loading and Formatting
    formatted_data_paths = []
    conversation_data_paths = []
    if not skip_data_loading:
        try:
            formatted_data_paths, conversation_data_paths = run_data_loading(
                config_path
            )
        except Exception as e:
            logger.error(f"Pipeline failed at data loading step: {e}")
            sys.exit(1)
    else:
        logger.info("Skipping data loading step")
        # Try to infer the paths from the config
        config = read_config(config_path)
        output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
        formatter_type = config.get("formatter", {}).get("type", "torchtune")
        # This is a simplification; in reality, you'd need to know the exact paths
        formatted_data_paths = [
            os.path.join(output_dir, f"train_{formatter_type}_formatted_data.json")
        ]

    # Step 2: Fine-tuning
    model_path = ""
    if not skip_finetuning:
        try:
            model_path = run_finetuning(config_path, formatted_data_paths)
        except Exception as e:
            logger.error(f"Pipeline failed at fine-tuning step: {e}")
            sys.exit(1)
    else:
        logger.info("Skipping fine-tuning step")
        # Try to infer the model path from the config
        config = read_config(config_path)
        output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
        model_path = os.path.join(output_dir, "finetuned_model")

    # Step 3: Start vLLM Server
    server_url = ""
    server_process = None
    if not skip_server:
        try:
            server_url = run_vllm_server(config_path, model_path)
        except Exception as e:
            logger.error(f"Pipeline failed at vLLM server step: {e}")
            sys.exit(1)
    else:
        logger.info("Skipping vLLM server step")
        # Try to infer the server URL from the config
        config = read_config(config_path)
        inference_config = config.get("inference", {})
        host = inference_config.get("host", "0.0.0.0")
        port = inference_config.get("port", 8000)
        server_url = f"http://{host}:{port}/v1"

    # Step 4: Inference
    if not skip_inference:
        try:
            results_path = run_inference(config_path, server_url, formatted_data_paths)
            logger.info(
                f"Pipeline completed successfully. Results saved to {results_path}"
            )
        except Exception as e:
            logger.error(f"Pipeline failed at inference step: {e}")
            sys.exit(1)
    else:
        logger.info("Skipping inference step")

    logger.info("Pipeline execution complete")

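# run_pipeline can also be invoked programmatically instead of via the CLI, e.g.
# (illustrative):
#
#   from run_pipeline import run_pipeline
#   run_pipeline("config.yaml", only_data_loading=True)
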
def main():
    """Main function."""
    parser = argparse.ArgumentParser(description="Run the end-to-end pipeline")

    # Configuration
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the configuration file",
    )

    # Skip flags
    parser.add_argument(
        "--skip-data-loading",
        action="store_true",
        help="Skip the data loading step",
    )
    parser.add_argument(
        "--skip-finetuning",
        action="store_true",
        help="Skip the fine-tuning step",
    )
    parser.add_argument(
        "--skip-server",
        action="store_true",
        help="Skip starting the vLLM server",
    )
    parser.add_argument(
        "--skip-inference",
        action="store_true",
        help="Skip the inference step",
    )

    # Only flags
    parser.add_argument(
        "--only-data-loading",
        action="store_true",
        help="Run only the data loading step",
    )
    parser.add_argument(
        "--only-finetuning",
        action="store_true",
        help="Run only the fine-tuning step",
    )
    parser.add_argument(
        "--only-server",
        action="store_true",
        help="Run only the vLLM server step",
    )
    parser.add_argument(
        "--only-inference",
        action="store_true",
        help="Run only the inference step",
    )

    args = parser.parse_args()

    # Run the pipeline
    run_pipeline(
        config_path=args.config,
        skip_data_loading=args.skip_data_loading,
        skip_finetuning=args.skip_finetuning,
        skip_server=args.skip_server,
        skip_inference=args.skip_inference,
        only_data_loading=args.only_data_loading,
        only_finetuning=args.only_finetuning,
        only_server=args.only_server,
        only_inference=args.only_inference,
    )


if __name__ == "__main__":
    main()