|
@@ -1,85 +1,222 @@
|
|
|
-# Source: https://www.philschmid.de/fine-tune-llms-in-2024-with-trl
|
|
|
+# Unified script supporting multiple fine-tuning configurations
|
|
|
+
|
|
|
+import argparse
|
|
|
+import sys
|
|
|
|
|
|
import torch
|
|
|
from datasets import load_dataset
|
|
|
-from peft import LoraConfig
|
|
|
-from transformers import (
|
|
|
- AutoModelForCausalLM,
|
|
|
- AutoTokenizer,
|
|
|
- BitsAndBytesConfig,
|
|
|
- TrainingArguments,
|
|
|
+from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
+
|
|
|
+# Parse command line arguments
|
|
|
+parser = argparse.ArgumentParser(
|
|
|
+ description="Unified fine-tuning script for Llama 3.1 8B"
|
|
|
+)
|
|
|
+parser.add_argument(
|
|
|
+ "--quantized",
|
|
|
+ type=str,
|
|
|
+ choices=["true", "false"],
|
|
|
+ required=True,
|
|
|
+ help="Whether to use quantization (true/false)",
|
|
|
+)
|
|
|
+parser.add_argument(
|
|
|
+ "--peft",
|
|
|
+ type=str,
|
|
|
+ choices=["true", "false"],
|
|
|
+ required=True,
|
|
|
+ help="Whether to use PEFT (true/false)",
|
|
|
+)
|
|
|
+parser.add_argument(
|
|
|
+ "--cot",
|
|
|
+ type=str,
|
|
|
+ choices=["true", "false"],
|
|
|
+ required=True,
|
|
|
+ help="Whether to use Chain-of-Thought dataset (true/false)",
|
|
|
)
|
|
|
-from trl import setup_chat_format, SFTTrainer
|
|
|
+args = parser.parse_args()
|
|
|
+
|
|
|
+# Convert string arguments to boolean
|
|
|
+use_quantized = args.quantized.lower() == "true"
|
|
|
+use_peft = args.peft.lower() == "true"
|
|
|
+use_cot = args.cot.lower() == "true"
|
|
|
+
|
|
|
+# Check for unsupported combination
|
|
|
+if not use_peft and use_quantized:
|
|
|
+ print(
|
|
|
+ "ERROR: Full Fine-Tuning (peft=false) with Quantization (quantized=true) is NOT RECOMMENDED!"
|
|
|
+ )
|
|
|
+ print("This combination can lead to:")
|
|
|
+ print("- Gradient precision loss due to quantization")
|
|
|
+ print("- Training instability")
|
|
|
+ print("- Suboptimal convergence")
|
|
|
+ print("\nRecommended combinations:")
|
|
|
+ print(
|
|
|
+ "1. --peft=true --quantized=true (PEFT + Quantized - Most memory efficient)"
|
|
|
+ )
|
|
|
+ print("2. --peft=true --quantized=false (PEFT + Non-quantized - Good balance)")
|
|
|
+ print(
|
|
|
+ "3. --peft=false --quantized=false (FFT + Non-quantized - Maximum performance)"
|
|
|
+ )
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+print(f"Configuration: PEFT={use_peft}, Quantized={use_quantized}, CoT={use_cot}")
|
|
|
+
|
|
|
+# Import additional modules based on configuration
|
|
|
+if use_quantized:
|
|
|
+ from transformers import BitsAndBytesConfig
|
|
|
+if use_peft:
|
|
|
+ from peft import LoraConfig
|
|
|
+
|
|
|
+from trl import setup_chat_format, SFTConfig, SFTTrainer
|
|
|
|
|
|
-FT_DATASET = "train_text2sql_sft_dataset.json"
|
|
|
-# uncomment to use the reasoning dataset created by "create_reasoning_dataset.py"
|
|
|
-# FT_DATASET = "train_text2sql_cot_dataset.json"
|
|
|
+# Dataset configuration based on CoT parameter
|
|
|
+if use_cot:
|
|
|
+ FT_DATASET = "train_text2sql_cot_dataset.json"
|
|
|
+ print("Using Chain-of-Thought reasoning dataset")
|
|
|
+else:
|
|
|
+ FT_DATASET = "train_text2sql_sft_dataset.json"
|
|
|
+ print("Using standard SFT dataset")
|
|
|
|
|
|
-dataset = load_dataset("json", data_files=SFT_DATASET, split="train")
|
|
|
+dataset = load_dataset("json", data_files=FT_DATASET, split="train")
|
|
|
|
|
|
model_id = "meta-llama/Llama-3.1-8B-Instruct"
|
|
|
|
|
|
-bnb_config = BitsAndBytesConfig(
|
|
|
- load_in_4bit=True,
|
|
|
- bnb_4bit_use_double_quant=True,
|
|
|
- bnb_4bit_quant_type="nf4",
|
|
|
- bnb_4bit_compute_dtype=torch.bfloat16,
|
|
|
-)
|
|
|
+# Configure quantization if needed
|
|
|
+quantization_config = None
|
|
|
+if use_quantized:
|
|
|
+ quantization_config = BitsAndBytesConfig(
|
|
|
+ load_in_4bit=True,
|
|
|
+ bnb_4bit_use_double_quant=True,
|
|
|
+ bnb_4bit_quant_type="nf4",
|
|
|
+ bnb_4bit_compute_dtype=torch.bfloat16,
|
|
|
+ )
|
|
|
|
|
|
+# Load model
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
model_id,
|
|
|
device_map="auto",
|
|
|
torch_dtype=torch.bfloat16,
|
|
|
- quantization_config=bnb_config,
|
|
|
+ quantization_config=quantization_config,
|
|
|
)
|
|
|
-tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
|
|
|
+# Load tokenizer
|
|
|
+tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
tokenizer.padding_side = "right"
|
|
|
|
|
|
if tokenizer.pad_token is None:
|
|
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
|
|
model.resize_token_embeddings(len(tokenizer))
|
|
|
|
|
|
-peft_config = LoraConfig(
|
|
|
- lora_alpha=128,
|
|
|
- lora_dropout=0.05,
|
|
|
- r=256,
|
|
|
- bias="none",
|
|
|
- target_modules="all-linear",
|
|
|
- task_type="CAUSAL_LM",
|
|
|
-)
|
|
|
+# Configure PEFT if needed
|
|
|
+peft_config = None
|
|
|
+if use_peft:
|
|
|
+ peft_config = LoraConfig(
|
|
|
+ lora_alpha=128,
|
|
|
+ lora_dropout=0.05,
|
|
|
+ r=256,
|
|
|
+ bias="none",
|
|
|
+ target_modules="all-linear",
|
|
|
+ task_type="CAUSAL_LM",
|
|
|
+ )
|
|
|
|
|
|
-args = TrainingArguments(
|
|
|
- output_dir="llama31-8b-text2sql-fine-tuned", # directory to save and repository id
|
|
|
- num_train_epochs=3, # number of training epochs
|
|
|
- per_device_train_batch_size=3, # batch size per device during training
|
|
|
- gradient_accumulation_steps=2, # number of steps before performing a backward/update pass
|
|
|
- gradient_checkpointing=True, # use gradient checkpointing to save memory
|
|
|
- optim="adamw_torch_fused", # use fused adamw optimizer
|
|
|
- logging_steps=10, # log every 10 steps
|
|
|
- save_strategy="epoch", # save checkpoint every epoch
|
|
|
- learning_rate=2e-4, # learning rate, based on QLoRA paper
|
|
|
- bf16=True, # use bfloat16 precision
|
|
|
- tf32=True, # use tf32 precision
|
|
|
- max_grad_norm=0.3, # max gradient norm based on QLoRA paper
|
|
|
- warmup_ratio=0.03, # warmup ratio based on QLoRA paper
|
|
|
- lr_scheduler_type="constant", # use constant learning rate scheduler
|
|
|
- push_to_hub=True, # push model to hub
|
|
|
- report_to="tensorboard", # report metrics to tensorboard
|
|
|
-)
|
|
|
+# Configure training arguments based on combination using newer TRL API
|
|
|
+cot_suffix = "cot" if use_cot else "nocot"
|
|
|
|
|
|
-max_seq_length = 4096
|
|
|
+if use_peft and use_quantized:
|
|
|
+ # PEFT + Quantized: Use SFTConfig (newer API)
|
|
|
+ print("Using PEFT + Quantized configuration")
|
|
|
+ args = SFTConfig(
|
|
|
+ output_dir=f"llama31-8b-text2sql-peft-quantized-{cot_suffix}",
|
|
|
+ num_train_epochs=3,
|
|
|
+ per_device_train_batch_size=3,
|
|
|
+ gradient_accumulation_steps=2,
|
|
|
+ gradient_checkpointing=True,
|
|
|
+ optim="adamw_torch_fused",
|
|
|
+ logging_steps=10,
|
|
|
+ save_strategy="epoch",
|
|
|
+ learning_rate=2e-4,
|
|
|
+ bf16=True,
|
|
|
+ tf32=True,
|
|
|
+ max_grad_norm=0.3,
|
|
|
+ warmup_ratio=0.03,
|
|
|
+ lr_scheduler_type="constant",
|
|
|
+ push_to_hub=True,
|
|
|
+ report_to="tensorboard",
|
|
|
+ max_seq_length=4096,
|
|
|
+ packing=True,
|
|
|
+ )
|
|
|
|
|
|
+elif use_peft and not use_quantized:
|
|
|
+ # PEFT + Non-quantized: Use SFTConfig (newer API)
|
|
|
+ print("Using PEFT + Non-quantized configuration")
|
|
|
+ args = SFTConfig(
|
|
|
+ output_dir=f"llama31-8b-text2sql-peft-nonquantized-{cot_suffix}",
|
|
|
+ num_train_epochs=3,
|
|
|
+ per_device_train_batch_size=2, # Slightly reduced for non-quantized
|
|
|
+ gradient_accumulation_steps=4, # Increased to maintain effective batch size
|
|
|
+ gradient_checkpointing=True,
|
|
|
+ optim="adamw_torch_fused",
|
|
|
+ logging_steps=10,
|
|
|
+ save_strategy="epoch",
|
|
|
+ learning_rate=2e-4,
|
|
|
+ bf16=True,
|
|
|
+ tf32=True,
|
|
|
+ max_grad_norm=0.3,
|
|
|
+ warmup_ratio=0.03,
|
|
|
+ lr_scheduler_type="constant",
|
|
|
+ push_to_hub=True,
|
|
|
+ report_to="tensorboard",
|
|
|
+ max_seq_length=4096,
|
|
|
+ packing=True,
|
|
|
+ )
|
|
|
+
|
|
|
+else: # not use_peft and not use_quantized
|
|
|
+ # FFT + Non-quantized: Use SFTConfig (newer API)
|
|
|
+ print("Using Full Fine-Tuning + Non-quantized configuration")
|
|
|
+ args = SFTConfig(
|
|
|
+ output_dir=f"llama31-8b-text2sql-fft-nonquantized-{cot_suffix}",
|
|
|
+ num_train_epochs=1, # Reduced epochs for full fine-tuning
|
|
|
+ per_device_train_batch_size=1, # Reduced batch size for full model training
|
|
|
+ gradient_accumulation_steps=8, # Increased to maintain effective batch size
|
|
|
+ gradient_checkpointing=True,
|
|
|
+ optim="adamw_torch_fused",
|
|
|
+ logging_steps=10,
|
|
|
+ save_strategy="epoch",
|
|
|
+ learning_rate=5e-6, # Lower learning rate for full fine-tuning
|
|
|
+ bf16=True,
|
|
|
+ tf32=True,
|
|
|
+ max_grad_norm=1.0, # Standard gradient clipping for full fine-tuning
|
|
|
+ warmup_ratio=0.1, # Warmup ratio for full fine-tuning
|
|
|
+ lr_scheduler_type="cosine", # Cosine scheduler for full fine-tuning
|
|
|
+ push_to_hub=True,
|
|
|
+ report_to="tensorboard",
|
|
|
+ dataloader_pin_memory=False, # Disable pin memory to save GPU memory
|
|
|
+ remove_unused_columns=False, # Keep all columns
|
|
|
+ max_seq_length=4096,
|
|
|
+ packing=True,
|
|
|
+ )
|
|
|
+
|
|
|
+# Create trainer with consistent newer API
|
|
|
trainer = SFTTrainer(
|
|
|
model=model,
|
|
|
args=args,
|
|
|
train_dataset=dataset,
|
|
|
- max_seq_length=max_seq_length,
|
|
|
- tokenizer=tokenizer,
|
|
|
+ processing_class=tokenizer,
|
|
|
peft_config=peft_config,
|
|
|
- packing=True,
|
|
|
)
|
|
|
|
|
|
+# Print memory requirements estimate
|
|
|
+print("\nEstimated GPU Memory Requirements:")
|
|
|
+if use_peft and use_quantized:
|
|
|
+ print("- PEFT + Quantized: ~12-16 GB")
|
|
|
+elif use_peft and not use_quantized:
|
|
|
+ print("- PEFT + Non-quantized: ~20-25 GB")
|
|
|
+else: # FFT + Non-quantized
|
|
|
+ print("- Full Fine-Tuning + Non-quantized: ~70-90 GB")
|
|
|
+
|
|
|
+print("\nStarting training...")
|
|
|
trainer.train()
|
|
|
|
|
|
+print("Training completed. Saving model...")
|
|
|
trainer.save_model()
|
|
|
+
|
|
|
+print("Model saved successfully!")
|