# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass

import torch
from transformers import BitsAndBytesConfig


@dataclass
class quantization_config:
    # Quantization data type for 4-bit weights: "fp4" or "nf4".
    quant_type: str = "fp4"
    # Dtype used for matmul computations on the dequantized weights.
    compute_dtype: torch.dtype = torch.bfloat16
    # Whether to quantize the quantization constants a second time (nested quantization).
    use_double_quant: bool = False
    # Dtype used to store the packed 4-bit weights.
    quant_storage: torch.dtype = torch.bfloat16

    def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig:
        """Build a bitsandbytes config for 4-bit or 8-bit model loading."""
        if quantization not in {"4bit", "8bit"}:
            raise ValueError("quantization must be either '4bit' or '8bit'")
        if quantization == "4bit":
            config_params = {
                "bnb_4bit_quant_type": self.quant_type,
                "bnb_4bit_compute_dtype": self.compute_dtype,
                "bnb_4bit_use_double_quant": self.use_double_quant,
                "bnb_4bit_quant_storage": self.quant_storage,
            }
            return BitsAndBytesConfig(load_in_4bit=True, **config_params)
        else:
            # 8-bit loading takes no extra parameters from this config.
            return BitsAndBytesConfig(load_in_8bit=True)
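
# --- Illustrative usage sketch, not part of the original file ---
# Shows how a config built by create_bnb_config would typically be passed to
# transformers' from_pretrained. The model id below is a placeholder
# assumption; loading it requires access to the checkpoint, a CUDA device,
# and bitsandbytes installed.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    # NF4 with double quantization is a common QLoRA-style setup.
    q_config = quantization_config(quant_type="nf4", use_double_quant=True)
    bnb_config = q_config.create_bnb_config("4bit")

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder model id
        quantization_config=bnb_config,
    )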