quantization.py 1.3 KB

12345678910111213141516171819202122232425262728293031
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. from dataclasses import dataclass
  4. from typing import Optional
  5. import torch
  6. from transformers import BitsAndBytesConfig
  7. @dataclass
  8. class quantizatio_config:
  9. quant_type: str # "int4" or "int8"
  10. compute_dtype: torch.dtype
  11. use_double_quant: bool = False
  12. quant_storage: torch.dtype = torch.bfloat16
  13. def create_bnb_config(self) -> BitsAndBytesConfig:
  14. if self.quant_type not in {"int4", "int8"}:
  15. raise ValueError("quant_type must be either 'int4' or 'int8'")
  16. config_params = {
  17. "bnb_4bit_quant_type" if self.quant_type == "int4" else "bnb_8bit_quant_type": self.quant_type,
  18. "bnb_4bit_compute_dtype" if self.quant_type == "int4" else "bnb_8bit_compute_dtype": self.compute_dtype,
  19. "bnb_4bit_use_double_quant" if self.quant_type == "int4" else "bnb_8bit_use_double_quant": self.use_double_quant,
  20. "bnb_4bit_quant_storage" if self.quant_type == "int4" else "bnb_8bit_quant_storage": self.quant_storage,
  21. }
  22. if self.quant_type == "int4":
  23. return BitsAndBytesConfig(load_in_4bit=True, **config_params)
  24. else:
  25. return BitsAndBytesConfig(load_in_8bit=True, **config_params)