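The following script merges the PEFT adapter weights back into the Llama 3.1 base model and saves the result as a standalone checkpoint that vLLM can serve directly: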
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

peft_model_path = "../fine-tuning/final_test/llama31-8b-text2sql-peft-nonquantized-cot"
output_dir = (
    "../fine-tuning/final_test/llama31-8b-text2sql-peft-nonquantized-cot_merged"
)

# === Load Base Model and Tokenizer ===
print("Loading base model and tokenizer...")
base_model_id = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Configure 4-bit quantization if needed (this adapter was trained non-quantized,
# so the flag stays off here)
quantization_config = None
use_quantized = False
if use_quantized:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
# Match the vocabulary size used during fine-tuning
# (Llama 3.1's 128,256 tokens plus one added token, e.g. a pad token)
base_model.resize_token_embeddings(128257)

# === Load PEFT Adapter and Merge ===
print("Loading PEFT adapter and merging...")
model = PeftModel.from_pretrained(base_model, peft_model_path)
model = model.merge_and_unload()  # Merge the adapter weights into the base model

# === Save the Merged Model ===
print(f"Saving merged model to {output_dir} ...")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print("Done! The merged model is ready for vLLM serving.")
```
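After the merge, the output directory is an ordinary Hugging Face checkpoint, so vLLM can load it without any adapter handling. Below is a minimal smoke-test sketch, assuming vLLM is installed; the schema and prompt text are illustrative placeholders, not part of the original pipeline:

```python
from vllm import LLM, SamplingParams

# Load the merged checkpoint directly; no PEFT adapter handling is needed.
llm = LLM(
    model="../fine-tuning/final_test/llama31-8b-text2sql-peft-nonquantized-cot_merged"
)

# Greedy decoding for a quick sanity check (prompt below is illustrative only).
params = SamplingParams(temperature=0.0, max_tokens=256)
outputs = llm.generate(
    [
        "Schema: customers(id, name), orders(id, customer_id, order_date)\n"
        "Question: List the names of customers who placed an order in 2023."
    ],
    params,
)
print(outputs[0].outputs[0].text)
```

In recent vLLM releases, the same directory can also be passed to `vllm serve` to expose an OpenAI-compatible endpoint instead of running inference in-process.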