trl_sft.py

# Source: https://www.philschmid.de/fine-tune-llms-in-2024-with-trl
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

FT_DATASET = "train_text2sql_sft_dataset.json"
# uncomment to use the reasoning dataset created by "create_reasoning_dataset.py"
# FT_DATASET = "train_text2sql_cot_dataset.json"
dataset = load_dataset("json", data_files=FT_DATASET, split="train")

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# 4-bit (QLoRA-style) quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    # add a pad token and resize the embedding matrix to match the new vocab size
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))

# LoRA adapter configuration
peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=256,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

args = TrainingArguments(
    output_dir="llama31-8b-text2sql-fine-tuned",  # directory to save and repository id
    num_train_epochs=3,              # number of training epochs
    per_device_train_batch_size=3,   # batch size per device during training
    gradient_accumulation_steps=2,   # number of steps before performing a backward/update pass
    gradient_checkpointing=True,     # use gradient checkpointing to save memory
    optim="adamw_torch_fused",       # use fused adamw optimizer
    logging_steps=10,                # log every 10 steps
    save_strategy="epoch",           # save checkpoint every epoch
    learning_rate=2e-4,              # learning rate, based on QLoRA paper
    bf16=True,                       # use bfloat16 precision
    tf32=True,                       # use tf32 precision
    max_grad_norm=0.3,               # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,               # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",    # use constant learning rate scheduler
    push_to_hub=True,                # push model to hub
    report_to="tensorboard",         # report metrics to tensorboard
)

max_seq_length = 4096

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    peft_config=peft_config,
    packing=True,
)

trainer.train()
trainer.save_model()
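
After training, the adapter saved to llama31-8b-text2sql-fine-tuned can be sanity-checked with a short generation pass. The snippet below is a minimal sketch, not part of trl_sft.py: it assumes the adapter and tokenizer were written to the output directory above and uses PEFT's AutoPeftModelForCausalLM to attach the adapter to the base model; the schema and question in the prompt are made-up placeholders.

# inference_check.py -- hedged sketch, separate from trl_sft.py
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_dir = "llama31-8b-text2sql-fine-tuned"  # output_dir used during training

# Load the base model with the LoRA adapter attached
model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

# Hypothetical text-to-SQL prompt; adjust to match the format used in the training dataset
messages = [
    {"role": "system", "content": "Translate the question into a SQL query for the given schema."},
    {"role": "user", "content": "Schema: employees(id, name, salary). Question: Who earns more than 100000?"},
]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    outputs = model.generate(inputs, max_new_tokens=128, do_sample=False)

# Print only the newly generated tokens (the model's SQL answer)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))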