trl_sft.py

# Source: https://www.philschmid.de/fine-tune-llms-in-2024-with-trl
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer
# Load the local text-to-SQL SFT dataset from JSON
dataset = load_dataset(
    "json", data_files="train_text2sql_sft_dataset.json", split="train"
)
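# NOTE (assumption, the file itself is not shown in the source): each record is
# expected to provide either a ready-made "text" field or a chat-style
# "messages" list; depending on the TRL version, SFTTrainer either consumes the
# text directly or renders the messages with the tokenizer's chat template.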
model_id = "meta-llama/Llama-3.1-8B-Instruct"

# 4-bit NF4 quantization (QLoRA-style) so the 8B model fits into limited GPU memory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "right"  # pad on the right during training
if tokenizer.pad_token is None:
    # Llama 3.1 ships without a pad token; add one and resize the embeddings to match
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))
# LoRA adapter configuration applied to all linear layers
peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=256,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
args = TrainingArguments(
    output_dir="llama31-8b-text2sql-epochs-20",  # directory to save to and repository id
    num_train_epochs=20,                  # number of training epochs
    per_device_train_batch_size=3,        # batch size per device during training
    gradient_accumulation_steps=2,        # number of steps before performing a backward/update pass
    gradient_checkpointing=True,          # use gradient checkpointing to save memory
    optim="adamw_torch_fused",            # use fused adamw optimizer
    logging_steps=10,                     # log every 10 steps
    save_strategy="epoch",                # save checkpoint every epoch
    learning_rate=2e-4,                   # learning rate, based on QLoRA paper
    bf16=True,                            # use bfloat16 precision
    tf32=True,                            # use tf32 precision
    max_grad_norm=0.3,                    # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                    # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",         # use constant learning rate scheduler
    push_to_hub=True,                     # push model to hub
    report_to="tensorboard",              # report metrics to tensorboard
)
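# Effective batch size: 3 (per-device batch) x 2 (gradient accumulation) = 6
# sequences per optimizer step on each device.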
max_seq_length = 4096

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    peft_config=peft_config,
    packing=True,
)
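# NOTE on library versions (assumption): the call above matches the TRL releases
# used in the source post, where SFTTrainer still accepts tokenizer,
# max_seq_length, and packing directly; newer TRL releases move max_seq_length
# and packing into SFTConfig and take the tokenizer via processing_class.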
trainer.train()       # run training; checkpoints are written every epoch
trainer.save_model()  # save the final LoRA adapter (and tokenizer) to output_dir
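
# ---------------------------------------------------------------------------
# Optional follow-up, not part of the source script: a minimal sketch of how
# the saved adapter could be smoke-tested (best run in a fresh process so the
# quantized training model is not kept in memory). The path matches output_dir
# above; the prompt is a purely illustrative assumption.
# ---------------------------------------------------------------------------
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

ft_model = AutoPeftModelForCausalLM.from_pretrained(
    "llama31-8b-text2sql-epochs-20",  # output_dir of the training run above
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
ft_tokenizer = AutoTokenizer.from_pretrained("llama31-8b-text2sql-epochs-20")

# Illustrative prompt only; real evaluation should use held-out text-to-SQL examples.
messages = [{"role": "user", "content": "Return the names of all customers in the customers table."}]
input_ids = ft_tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(ft_model.device)
generated = ft_model.generate(input_ids, max_new_tokens=256)
print(ft_tokenizer.decode(generated[0][input_ids.shape[-1]:], skip_special_tokens=True))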