merge_peft.py

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

peft_model_path = "../fine-tuning/final_test/llama31-8b-text2sql-peft-nonquantized-cot"
output_dir = (
    "../fine-tuning/final_test/llama31-8b-text2sql-peft-nonquantized-cot_merged"
)

# === Load Base Model and Tokenizer ===
print("Loading base model and tokenizer...")
base_model_id = "meta-llama/Llama-3.1-8B-Instruct"
# NOTE: if the fine-tune added tokens (see the embedding resize below), the tokenizer
# saved alongside the adapter should be loaded instead so the two stay in sync.
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Configure 4-bit quantization only if the adapter was trained on a quantized base
quantization_config = None
use_quantized = False
if use_quantized:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
# Resize embeddings to the fine-tuned checkpoint's vocabulary size
# (128257 = Llama 3.1's 128256 tokens plus, most likely, one token added during training)
base_model.resize_token_embeddings(128257)

# === Load PEFT Adapter and Merge ===
print("Loading PEFT adapter and merging...")
# peft_config = PeftConfig.from_pretrained(peft_model_path)
model = PeftModel.from_pretrained(base_model, peft_model_path)
model = model.merge_and_unload()  # Merges the adapter weights into the base model

# === Save the Merged Model ===
print(f"Saving merged model to {output_dir} ...")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print("Done! The merged model is ready for vLLM serving.")
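
As a quick sanity check that the merged checkpoint actually loads, it can be pointed at vLLM's offline Python API. This is a minimal sketch, assuming vLLM is installed and the merged directory above; the prompt is a placeholder, not the project's actual text-to-SQL prompt format.

from vllm import LLM, SamplingParams

# Load the merged model directly from the output directory written by merge_peft.py
llm = LLM(model="../fine-tuning/final_test/llama31-8b-text2sql-peft-nonquantized-cot_merged")

# Greedy decoding for a quick smoke test
params = SamplingParams(temperature=0.0, max_tokens=256)
outputs = llm.generate(["List all customers who placed an order in 2023."], params)
print(outputs[0].outputs[0].text)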