# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass

import torch
from transformers import BitsAndBytesConfig


@dataclass
class quantization_config:
    quant_type: str = "fp4"  # "fp4" or "nf4"
    compute_dtype: torch.dtype = torch.bfloat16  # dtype used for computation on the dequantized weights
    use_double_quant: bool = False  # also quantize the quantization constants, saving extra memory
    quant_storage: torch.dtype = torch.bfloat16  # storage dtype used to pack the quantized 4-bit params

    def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig:
        """Build a BitsAndBytesConfig for 4-bit or 8-bit model loading."""
        if quantization not in {"4bit", "8bit"}:
            raise ValueError("quantization must be either '4bit' or '8bit'")
        if quantization == "4bit":
            config_params = {
                "bnb_4bit_quant_type": self.quant_type,
                "bnb_4bit_compute_dtype": self.compute_dtype,
                "bnb_4bit_use_double_quant": self.use_double_quant,
                "bnb_4bit_quant_storage": self.quant_storage,
            }
            return BitsAndBytesConfig(load_in_4bit=True, **config_params)
        else:
            # 8-bit loading takes no further parameters from this config.
            return BitsAndBytesConfig(load_in_8bit=True)
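# A minimal usage sketch (an assumption, not part of the original file): the
# BitsAndBytesConfig built above is typically passed to transformers'
# from_pretrained via its quantization_config argument. The checkpoint name
# below is a placeholder.
#
#     from transformers import AutoModelForCausalLM
#
#     bnb_config = quantization_config(quant_type="nf4").create_bnb_config("4bit")
#     model = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-2-7b-hf",
#         quantization_config=bnb_config,
#     )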