finetune_grid.py
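
"""Run a grid of torchtune fine-tuning jobs for one experiment.

Builds full fine-tuning (FFT) and LoRA `tune run` commands from the experiment's
config.yaml and executes them sequentially, writing checkpoints under
<experiment_dir>/finetuned_checkpoints.
"""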

import os
import subprocess
from pathlib import Path

from ..utils import load_config


def get_general_finetune_args(finetuning_config, output_dir):
    """Build the torchtune CLI overrides shared by every fine-tuning job."""
    experiment_dir = Path(output_dir).parent
    model_path = finetuning_config["model_path"]
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model path {model_path} does not exist")
    tokenizer_path = finetuning_config["tokenizer_path"]
    # TODO: Change "task1" to task name defined in config
    dataset_path = (
        experiment_dir / "formatted_datasets" / "task1" / "train_conversation_data.json"
    )
    return [
        f"dataset.dataset_path={dataset_path}",
        f"checkpointer.checkpoint_dir={model_path}",
        f"tokenizer.path={tokenizer_path}",
        f"epochs={finetuning_config['epochs']}",
        f"batch_size={finetuning_config['batch_size']}",
        f"metric_logger.log_dir={experiment_dir}/finetune_logs",
    ]
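
# Illustrative sketch of the `finetuning` section of config.yaml that these
# helpers expect. The keys below are exactly the ones read in this module;
# the values are placeholders, not the real experiment settings.
#
#   finetuning:
#     model_path: /path/to/base/checkpoints
#     tokenizer_path: /path/to/tokenizer.model
#     epochs: 1
#     batch_size: 4
#     ngpu: 1
#     distributed: false
#     fft_torchtune_config: fft_config.yaml
#     lora_torchtune_config: lora_config.yaml
#     fusion: true
#     fusion+encoder: false
#     fusion+decoder: false
#     fusion+encoder+decoder: false
#     lora_ranks: [8, 16]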


def build_fft_jobs(config, output_dir):
    """Build FFT (Full Fine-Tuning) jobs based on config."""
    jobs = []
    finetuning_config = config["finetuning"]
    recipe = (
        "full_finetune_distributed"
        if finetuning_config.get("distributed")
        else "full_finetune_single_device"
    )
    torchtune_config = finetuning_config.get("fft_torchtune_config")
    base_cmd = [
        "tune",
        "run",
        "--nproc_per_node",
        str(finetuning_config["ngpu"]),
        recipe,
        "--config",
        torchtune_config,
    ]
    base_cmd += get_general_finetune_args(finetuning_config, output_dir)

    # Build list of modules to train based on config
    modules_to_train = []
    if finetuning_config.get("fusion", False):
        modules_to_train.append("fusion")
    if finetuning_config.get("fusion+encoder", False):
        modules_to_train.append("fusion+encoder")
    if finetuning_config.get("fusion+decoder", False):
        modules_to_train.append("fusion+decoder")
    if finetuning_config.get("fusion+encoder+decoder", False):
        modules_to_train.append("fusion+encoder+decoder")

    for modules in modules_to_train:
        op_path = f"{output_dir}/full_{modules}"
        if os.path.exists(op_path):
            print(f"Skipping {op_path} as it already exists")
            continue
        module_opts = [f"model.{mod}_trainable=True" for mod in modules.split("+")]
        jobs.append(base_cmd + [f"output_dir={op_path}"] + module_opts)
    return jobs
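
# For illustration, with the single-device recipe and only `fusion` enabled, each
# FFT job is a command of roughly this shape (paths and values are placeholders):
#
#   tune run --nproc_per_node 1 full_finetune_single_device --config fft_config.yaml \
#       dataset.dataset_path=<experiment_dir>/formatted_datasets/task1/train_conversation_data.json \
#       checkpointer.checkpoint_dir=<model_path> tokenizer.path=<tokenizer_path> \
#       epochs=1 batch_size=4 metric_logger.log_dir=<experiment_dir>/finetune_logs \
#       output_dir=<output_dir>/full_fusion model.fusion_trainable=True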


def build_lora_jobs(config, output_dir):
    """Build LoRA jobs based on config."""
    jobs = []
    finetuning_config = config["finetuning"]
    if not finetuning_config.get("lora_ranks"):
        return jobs
    recipe = (
        "lora_finetune_distributed"
        if finetuning_config.get("distributed")
        else "lora_finetune_single_device"
    )
    torchtune_config = finetuning_config.get("lora_torchtune_config")
    base_cmd = [
        "tune",
        "run",
        "--nproc_per_node",
        str(finetuning_config["ngpu"]),
        recipe,
        "--config",
        torchtune_config,
    ]
    base_cmd += get_general_finetune_args(finetuning_config, output_dir)

    for rank in finetuning_config["lora_ranks"]:
        op_path = f"{output_dir}/lora_{rank}"
        if os.path.exists(op_path):
            print(f"Skipping {op_path} as it already exists")
            continue
        jobs.append(
            base_cmd
            + [
                f"output_dir={op_path}",
                f"model.lora_rank={rank}",
                f"model.lora_alpha={int(rank)*2}",
            ]
        )
    return jobs
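
# For illustration, a rank-8 LoRA job appends overrides like the following to the
# shared base command (alpha is always twice the rank here):
#
#   output_dir=<output_dir>/lora_8 model.lora_rank=8 model.lora_alpha=16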


def run_finetune_grid(experiment_dir: str):
    print("🚀 Starting fine-tuning grid execution...")
    print(f"📁 Experiment directory: {experiment_dir}")

    # Locate the project root (three levels up from this file) and the config path
    script_dir = Path(__file__).parent.parent.parent
    config_path = script_dir / "config.yaml"
    print(f"📝 Loading configuration from: {config_path}")

    # Load configuration
    config = load_config(config_path)
    print("✅ Configuration loaded successfully")

    # Set output directory
    output_dir = Path(experiment_dir) / "finetuned_checkpoints"
    print(f"💾 Output directory: {output_dir}")

    # Create output directory if it doesn't exist
    output_dir.mkdir(parents=True, exist_ok=True)
    print("📂 Output directory created/verified")

    # Build all jobs
    all_jobs = []
    print("\n🔧 Building fine-tuning jobs...")

    # Check if we should run FFT jobs (if any fusion settings are enabled)
    finetuning_config = config["finetuning"]
    if any(
        [
            finetuning_config.get("fusion", False),
            finetuning_config.get("fusion+encoder", False),
            finetuning_config.get("fusion+decoder", False),
            finetuning_config.get("fusion+encoder+decoder", False),
        ]
    ):
        print("🔄 Building Full Fine-Tuning (FFT) jobs...")
        fft_jobs = build_fft_jobs(config, output_dir)
        all_jobs.extend(fft_jobs)
        print(f"✅ Built {len(fft_jobs)} FFT jobs")

        # Print details of FFT jobs
        for i, job in enumerate(fft_jobs, 1):
            modules = [arg for arg in job if "trainable=True" in str(arg)]
            if modules:
                module_info = ", ".join(
                    [
                        mod.split(".")[1].replace("_trainable=True", "")
                        for mod in modules
                    ]
                )
                print(f" 📋 FFT Job {i}: {module_info}")

    # Check if we should run LoRA jobs
    if finetuning_config.get("lora_ranks"):
        print("🔄 Building LoRA fine-tuning jobs...")
        lora_jobs = build_lora_jobs(config, output_dir)
        all_jobs.extend(lora_jobs)
        lora_count = len(lora_jobs)
        print(f"✅ Built {lora_count} LoRA jobs")

        # Print details of LoRA jobs
        ranks = finetuning_config.get("lora_ranks", [])
        for i, rank in enumerate(ranks, 1):
            print(f" 📋 LoRA Job {i}: rank={rank}, alpha={rank*2}")

    total_jobs = len(all_jobs)
    print(f"\n📊 Total jobs to execute: {total_jobs}")

    # Run all jobs
    print(f"\n🎯 Executing {total_jobs} fine-tuning jobs...")
    print("=" * 60)

    for job_idx, job in enumerate(all_jobs, 1):
        print(f"\n📈 Job {job_idx}/{total_jobs} - Starting...")

        # Extract job type and details for better logging
        job_type = "LoRA" if "lora_finetune" in " ".join(job) else "FFT"
        output_path = next(
            (arg.split("=")[1] for arg in job if arg.startswith("output_dir=")),
            "unknown",
        )
        job_name = Path(output_path).name if output_path != "unknown" else "unknown"

        print(f"🔧 Type: {job_type}")
        print(f"📁 Output: {job_name}")
        # print(f"⚡ Command: {' '.join(map(str, job))}")
        print("-" * 40)

        try:
            print(f"⏳ Executing job {job_idx}/{total_jobs}...")
            subprocess.run(job, check=True, capture_output=False)
            print(f"✅ Job {job_idx}/{total_jobs} completed successfully!")
        except subprocess.CalledProcessError as e:
            print(
                f"❌ Job {job_idx}/{total_jobs} failed with return code {e.returncode}"
            )
            print(f"💥 Error: {e}")
            raise
        except Exception as e:
            print(f"❌ Job {job_idx}/{total_jobs} failed with unexpected error: {e}")
            raise

    print("\n" + "=" * 60)
    print("🎉 All fine-tuning jobs completed successfully!")
    print(f"📁 Results saved to: {output_dir}")
    print("🏁 Fine-tuning grid execution finished.")
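
# Note: `from ..utils import load_config` is a relative import, so this file is
# intended to be run as a module inside its package (e.g. `python -m
# <package>.finetune_grid`, where <package> is a placeholder for the actual
# package path), rather than executed directly as a standalone script.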


if __name__ == "__main__":
    run_finetune_grid("experiments/w2_ocr")