# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass

import torch
from transformers import BitsAndBytesConfig

@dataclass
class quantization_config:
    # 4-bit quantization data type: "fp4" or "nf4" (NormalFloat4, typically a
    # better fit for normally distributed pretrained weights).
    quant_type: str = "fp4"
    # dtype used for matrix multiplications on the dequantized weights.
    compute_dtype: torch.dtype = torch.bfloat16
    # Nested (double) quantization: also quantize the quantization constants,
    # saving roughly 0.4 bits per parameter.
    use_double_quant: bool = False
    # dtype used to pack and store the quantized weights.
    quant_storage: torch.dtype = torch.bfloat16
    def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig:
        """Build a bitsandbytes config for 4-bit or 8-bit model loading."""
        if quantization not in {"4bit", "8bit"}:
            raise ValueError("quantization must be either '4bit' or '8bit'")
        if quantization == "4bit":
            config_params = {
                "bnb_4bit_quant_type": self.quant_type,
                "bnb_4bit_compute_dtype": self.compute_dtype,
                "bnb_4bit_use_double_quant": self.use_double_quant,
                "bnb_4bit_quant_storage": self.quant_storage,
            }
            return BitsAndBytesConfig(load_in_4bit=True, **config_params)
        # The 8-bit path takes no extra parameters from this config.
        return BitsAndBytesConfig(load_in_8bit=True)
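

# Usage sketch (illustrative, not part of the upstream module): build an NF4
# double-quantized 4-bit config and pass it to transformers' from_pretrained.
# The model id "meta-llama/Llama-2-7b-hf" is an assumption for the example.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    cfg = quantization_config(quant_type="nf4", use_double_quant=True)
    bnb_config = cfg.create_bnb_config("4bit")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        quantization_config=bnb_config,
        device_map="auto",
    )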