# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass

import torch
from transformers import BitsAndBytesConfig


@dataclass
class quantization_config:
    quant_type: str = "fp4"  # "fp4" or "nf4"
    compute_dtype: torch.dtype = torch.bfloat16
    use_double_quant: bool = False
    quant_storage: torch.dtype = torch.bfloat16

    def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig:
        """Build a BitsAndBytesConfig for 4-bit or 8-bit model loading."""
        if quantization not in {"4bit", "8bit"}:
            raise ValueError("quantization must be either '4bit' or '8bit'")
        if quantization == "4bit":
            # The 4-bit settings come from the dataclass fields above.
            config_params = {
                "bnb_4bit_quant_type": self.quant_type,
                "bnb_4bit_compute_dtype": self.compute_dtype,
                "bnb_4bit_use_double_quant": self.use_double_quant,
                "bnb_4bit_quant_storage": self.quant_storage,
            }
            return BitsAndBytesConfig(load_in_4bit=True, **config_params)
        else:
            # 8-bit quantization takes no extra parameters here.
            return BitsAndBytesConfig(load_in_8bit=True)
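
# Usage sketch (illustrative, not part of this module): the checkpoint id
# "meta-llama/Llama-2-7b-hf" is an assumed example, and running this requires
# the bitsandbytes package and a CUDA-capable GPU.
#
#     from transformers import AutoModelForCausalLM
#
#     bnb_config = quantization_config(quant_type="nf4").create_bnb_config("4bit")
#     model = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-2-7b-hf",
#         quantization_config=bnb_config,
#     )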
 |