12345678910111213141516171819202122232425262728293031 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
- from dataclasses import dataclass
- from typing import Optional
- import torch
- from transformers import BitsAndBytesConfig
@dataclass
class quantizatio_config:
    """Settings used to build a ``transformers.BitsAndBytesConfig``.

    ``quant_type`` selects the bit width ("int4" or "int8"). The remaining
    fields are forwarded only for 4-bit loading, because ``BitsAndBytesConfig``
    exposes these options only as ``bnb_4bit_*`` parameters.

    NOTE: the class name is misspelled ("quantizatio"); it is kept unchanged so
    existing imports keep working.
    """

    quant_type: str  # bit width: "int4" or "int8"
    compute_dtype: torch.dtype  # passed as bnb_4bit_compute_dtype (4-bit only)
    use_double_quant: bool = False  # passed as bnb_4bit_use_double_quant (4-bit only)
    quant_storage: torch.dtype = torch.bfloat16  # passed as bnb_4bit_quant_storage (4-bit only)
    # bitsandbytes 4-bit data type, "fp4" or "nf4" (4-bit only). It is separate
    # from quant_type because "int4" names a bit width, not a bnb data type.
    # The default mirrors BitsAndBytesConfig's own default.
    bnb_4bit_quant_type: str = "fp4"

    def _bnb_kwargs(self) -> dict:
        """Return the keyword arguments to pass to ``BitsAndBytesConfig``.

        Raises:
            ValueError: if ``quant_type`` is not "int4" or "int8".
        """
        if self.quant_type not in {"int4", "int8"}:
            raise ValueError("quant_type must be either 'int4' or 'int8'")
        if self.quant_type == "int8":
            # BitsAndBytesConfig defines no bnb_8bit_* options; such keys would
            # only land in its **kwargs and be ignored, so pass just the flag.
            return {"load_in_8bit": True}
        return {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": self.bnb_4bit_quant_type,
            "bnb_4bit_compute_dtype": self.compute_dtype,
            "bnb_4bit_use_double_quant": self.use_double_quant,
            "bnb_4bit_quant_storage": self.quant_storage,
        }

    def create_bnb_config(self) -> "BitsAndBytesConfig":
        """Build the ``BitsAndBytesConfig`` described by this object.

        Returns:
            A ``BitsAndBytesConfig`` with ``load_in_4bit=True`` plus the
            ``bnb_4bit_*`` options for "int4", or ``load_in_8bit=True`` for "int8".

        Raises:
            ValueError: if ``quant_type`` is not "int4" or "int8".
        """
        return BitsAndBytesConfig(**self._bnb_kwargs())
|