trl_sft.py

# Unified script supporting multiple fine-tuning configurations
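#
# Example invocation (assumed; adjust the flags to your setup):
#   python trl_sft.py --peft true --quantized true --cot false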
import argparse
import sys

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Parse command line arguments
parser = argparse.ArgumentParser(
    description="Unified fine-tuning script for Llama 3.1 8B"
)
parser.add_argument(
    "--quantized",
    type=str,
    choices=["true", "false"],
    required=True,
    help="Whether to use quantization (true/false)",
)
parser.add_argument(
    "--peft",
    type=str,
    choices=["true", "false"],
    required=True,
    help="Whether to use PEFT (true/false)",
)
parser.add_argument(
    "--cot",
    type=str,
    choices=["true", "false"],
    required=True,
    help="Whether to use Chain-of-Thought dataset (true/false)",
)
args = parser.parse_args()

# Convert string arguments to booleans
use_quantized = args.quantized.lower() == "true"
use_peft = args.peft.lower() == "true"
use_cot = args.cot.lower() == "true"
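# Note: the flags take explicit "true"/"false" strings because argparse's
# type=bool would treat any non-empty string (including "false") as True.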

# Reject the unsupported combination
if not use_peft and use_quantized:
    print(
        "ERROR: Full fine-tuning (peft=false) with quantization (quantized=true) is not recommended!"
    )
    print("This combination can lead to:")
    print("- Gradient precision loss due to quantization")
    print("- Training instability")
    print("- Suboptimal convergence")
    print("\nRecommended combinations:")
    print(
        "1. --peft=true --quantized=true   (PEFT + quantized: most memory efficient)"
    )
    print("2. --peft=true --quantized=false  (PEFT + non-quantized: good balance)")
    print(
        "3. --peft=false --quantized=false (FFT + non-quantized: maximum model quality)"
    )
    sys.exit(1)

print(f"Configuration: PEFT={use_peft}, Quantized={use_quantized}, CoT={use_cot}")

# Import additional modules based on the configuration
if use_quantized:
    from transformers import BitsAndBytesConfig
if use_peft:
    from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# Dataset configuration based on the CoT parameter
if use_cot:
    FT_DATASET = "train_text2sql_cot_dataset.json"
    print("Using Chain-of-Thought reasoning dataset")
else:
    FT_DATASET = "train_text2sql_sft_dataset.json"
    print("Using standard SFT dataset")

dataset = load_dataset("json", data_files=FT_DATASET, split="train")
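# The JSON files are assumed to hold conversational records (e.g., a "messages"
# list per example) that SFTTrainer can render with the model's chat template.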

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Configure quantization if needed
quantization_config = None
if use_quantized:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
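# NF4 stores the base weights in 4 bits (double quantization also compresses
# the quantization constants), shrinking the 8B model's weight footprint from
# ~16 GB in bf16 to roughly 5 GB; compute still happens in bf16.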

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))
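# Adding [PAD] grows the vocabulary, so the embedding matrix is resized to
# match; a common alternative (not used here) is to reuse the EOS token as
# the pad token.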

# Configure PEFT if needed
peft_config = None
if use_peft:
    peft_config = LoraConfig(
        lora_alpha=128,
        lora_dropout=0.05,
        r=256,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
    )
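# With r=256 and lora_alpha=128 the adapter updates are scaled by
# alpha/r = 0.5, and "all-linear" attaches LoRA to every linear projection
# in the transformer.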

# Configure training arguments per combination, using the newer TRL API
cot_suffix = "cot" if use_cot else "nocot"
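# All three branches share bf16 compute, gradient checkpointing, fused AdamW,
# TensorBoard logging, and sequence packing (TRL concatenates short examples
# into full 4096-token sequences); they differ in batch size, epochs,
# learning rate, and LR schedule.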
if use_peft and use_quantized:
    # PEFT + Quantized (QLoRA): most memory-efficient combination
    print("Using PEFT + Quantized configuration")
    training_args = SFTConfig(
        output_dir=f"llama31-8b-text2sql-peft-quantized-{cot_suffix}",
        num_train_epochs=3,
        per_device_train_batch_size=3,
        gradient_accumulation_steps=2,  # effective batch size: 3 x 2 = 6
        gradient_checkpointing=True,
        optim="adamw_torch_fused",
        logging_steps=10,
        save_strategy="epoch",
        learning_rate=2e-4,
        bf16=True,
        tf32=True,
        max_grad_norm=0.3,
        warmup_ratio=0.03,
        lr_scheduler_type="constant",
        push_to_hub=True,
        report_to="tensorboard",
        max_seq_length=4096,
        packing=True,
    )
elif use_peft and not use_quantized:
    # PEFT + Non-quantized: good balance of memory use and quality
    print("Using PEFT + Non-quantized configuration")
    training_args = SFTConfig(
        output_dir=f"llama31-8b-text2sql-peft-nonquantized-{cot_suffix}",
        num_train_epochs=3,
        per_device_train_batch_size=2,  # smaller batches: full bf16 weights use more memory
        gradient_accumulation_steps=4,  # effective batch size: 2 x 4 = 8
        gradient_checkpointing=True,
        optim="adamw_torch_fused",
        logging_steps=10,
        save_strategy="epoch",
        learning_rate=2e-4,
        bf16=True,
        tf32=True,
        max_grad_norm=0.3,
        warmup_ratio=0.03,
        lr_scheduler_type="constant",
        push_to_hub=True,
        report_to="tensorboard",
        max_seq_length=4096,
        packing=True,
    )
else:  # not use_peft and not use_quantized
    # Full fine-tuning + Non-quantized: every parameter is updated
    print("Using Full Fine-Tuning + Non-quantized configuration")
    training_args = SFTConfig(
        output_dir=f"llama31-8b-text2sql-fft-nonquantized-{cot_suffix}",
        num_train_epochs=1,  # fewer epochs: full fine-tuning overfits faster
        per_device_train_batch_size=1,  # minimal batch: all 8B weights are being trained
        gradient_accumulation_steps=8,  # effective batch size: 1 x 8 = 8
        gradient_checkpointing=True,
        optim="adamw_torch_fused",
        logging_steps=10,
        save_strategy="epoch",
        learning_rate=5e-6,  # much lower LR than LoRA, to avoid destabilizing the full model
        bf16=True,
        tf32=True,
        max_grad_norm=1.0,  # standard gradient clipping for full fine-tuning
        warmup_ratio=0.1,  # longer warmup for full fine-tuning
        lr_scheduler_type="cosine",  # cosine decay instead of a constant LR
        push_to_hub=True,
        report_to="tensorboard",
        dataloader_pin_memory=False,  # pinned memory is host RAM; disable to reduce host pressure
        remove_unused_columns=False,  # keep all dataset columns
        max_seq_length=4096,
        packing=True,
    )

# Create the trainer with the consistent newer API
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    peft_config=peft_config,
)
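# When peft_config is a LoraConfig, SFTTrainer wraps the model in PEFT
# adapters and trains only those; when it is None, the full model is trained.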

# Print a rough GPU memory estimate for the chosen configuration
print("\nEstimated GPU memory requirements:")
if use_peft and use_quantized:
    print("- PEFT + Quantized: ~12-16 GB")
elif use_peft and not use_quantized:
    print("- PEFT + Non-quantized: ~20-25 GB")
else:  # FFT + Non-quantized
    print("- Full Fine-Tuning + Non-quantized: ~70-90 GB")
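# These figures are rough estimates. For full fine-tuning, bf16 weights
# (~16 GB), bf16 gradients (~16 GB), and the two AdamW moment buffers
# (~32 GB if kept in bf16 alongside the parameters) dominate before
# activations; the PEFT variants keep optimizer state only for the adapters.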

print("\nStarting training...")
trainer.train()

print("Training completed. Saving model...")
trainer.save_model()
print("Model saved successfully!")