run_finetuning.py

#!/usr/bin/env python
"""
Fine-tuning script for language models using torchtune.
Reads parameters from a config file and runs the torchtune command.
"""
import argparse
import json
import logging
import subprocess
import sys
from pathlib import Path
from typing import Dict

# PyYAML is an optional dependency, only needed for YAML config files
try:
    import yaml

    HAS_YAML = True
except ImportError:
    HAS_YAML = False

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


## Will import from dataloader eventually
def read_config(config_path: str) -> Dict:
    """
    Read the configuration file (supports both JSON and YAML formats).

    Args:
        config_path: Path to the configuration file

    Returns:
        dict: Configuration parameters

    Raises:
        ValueError: If the file format is not supported
        ImportError: If the required package for the file format is not installed
    """
    file_extension = Path(config_path).suffix.lower()
    with open(config_path, "r") as f:
        if file_extension == ".json":
            config = json.load(f)
        elif file_extension in (".yaml", ".yml"):
            if not HAS_YAML:
                raise ImportError(
                    "The 'pyyaml' package is required to load YAML files. "
                    "Please install it with 'pip install pyyaml'."
                )
            config = yaml.safe_load(f)
        else:
            raise ValueError(
                f"Unsupported config file format: {file_extension}. "
                f"Supported formats are: .json, .yaml, .yml"
            )
    return config


def run_torch_tune(training_config: Dict, args=None):
    """
    Run the torchtune command with parameters from the config file.

    Args:
        training_config: The "finetuning" section of the configuration
            (as returned by read_config)
        args: Command line arguments that may include additional kwargs to pass to the command
    """
    # Determine the torchtune recipe based on strategy and distributed mode
    strategy = training_config.get("strategy")
    if strategy not in ("lora", "fft"):
        raise ValueError(f"Invalid strategy: {strategy}")

    if training_config.get("distributed"):
        recipe = (
            "lora_finetune_distributed"
            if strategy == "lora"
            else "full_finetune_distributed"
        )
        base_cmd = [
            "tune",
            "run",
            "--nproc_per_node",
            str(training_config.get("num_processes_per_node", 1)),
            recipe,
            "--config",
            training_config.get("torchtune_config"),
        ]
    else:
        recipe = (
            "lora_finetune_single_device"
            if strategy == "lora"
            else "full_finetune_single_device"
        )
        base_cmd = [
            "tune",
            "run",
            recipe,
            "--config",
            training_config.get("torchtune_config"),
        ]
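
    # For illustration: strategy="lora" with distributed=True and
    # num_processes_per_node=2 assembles the command
    #   tune run --nproc_per_node 2 lora_finetune_distributed --config <torchtune_config>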

    # Add any additional kwargs if provided
    if args and args.kwargs:
        # Split the kwargs string by spaces to get individual key=value pairs
        kwargs_list = args.kwargs.split()
        base_cmd.extend(kwargs_list)
        logger.info(f"Added additional kwargs: {kwargs_list}")

    # Log the command
    logger.info(f"Running command: {' '.join(base_cmd)}")

    # Run the command
    try:
        subprocess.run(base_cmd, check=True)
        logger.info("Training complete!")
    except subprocess.CalledProcessError as e:
        logger.error(f"Training failed with error: {e}")
        sys.exit(1)


def main():
    """Main function."""
    parser = argparse.ArgumentParser(
        description="Fine-tune a language model using torchtune"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the configuration file (JSON or YAML)",
    )
    parser.add_argument(
        "--kwargs",
        type=str,
        default=None,
        help=(
            "Additional key-value pairs to pass to the command "
            "(space-separated, e.g., 'dataset=module.function dataset.param=value')"
        ),
    )
    args = parser.parse_args()

    config = read_config(args.config)
    finetuning_config = config.get("finetuning", {})
    run_torch_tune(finetuning_config, args=args)


if __name__ == "__main__":
    main()
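
# Example invocation (the config path and kwargs values are illustrative):
#   python run_finetuning.py --config configs/finetune.yaml \
#       --kwargs "dataset=module.function dataset.param=value"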