# model_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM


# Function to load the main model for text generation
def load_model(model_name, quantization, use_fast_kernels):
    print(f"use_fast_kernels {use_fast_kernels}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        return_dict=True,
        load_in_8bit=quantization,
        device_map="auto",
        low_cpu_mem_usage=True,
        # Use PyTorch scaled dot-product attention when fast kernels are requested
        attn_implementation="sdpa" if use_fast_kernels else None,
    )
    return model


# Function to wrap the base model with a trained PEFT (e.g., LoRA) adapter
def load_peft_model(model, peft_model):
    peft_model = PeftModel.from_pretrained(model, peft_model)
    return peft_model


# Build a Llama model from its config alone, so FSDP checkpoints can be loaded into it
def load_llama_from_config(config_path):
    model_config = LlamaConfig.from_pretrained(config_path)
    model = LlamaForCausalLM(config=model_config)
    return model
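

# A minimal usage sketch (not part of the original file): it shows how the two
# loading helpers above fit together. The model id "meta-llama/Llama-2-7b-hf"
# and the adapter directory "path/to/peft_adapter" are placeholder assumptions;
# substitute your own checkpoint and adapter locations.
if __name__ == "__main__":
    # Load the base model in 8-bit with SDPA attention enabled.
    base_model = load_model(
        model_name="meta-llama/Llama-2-7b-hf",  # hypothetical model id
        quantization=True,
        use_fast_kernels=True,
    )
    # Attach a fine-tuned PEFT (e.g., LoRA) adapter on top of the base model.
    tuned_model = load_peft_model(base_model, "path/to/peft_adapter")  # hypothetical path
    tuned_model.eval()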