@@ -1,9 +1,10 @@
 from dataclasses import dataclass, field
 from typing import List
-
+import torch
+from transformers import BitsAndBytesConfig
 
 @dataclass
 class QuantizationConfig:
     quant_type: str  # "int4" or "int8"
     compute_dtype: torch.dtype
     use_double_quant: bool = False
@@ -11,16 +12,20 @@ class QuantizationConfig:
 
     def create_bnb_config(self):
         if self.quant_type == "int4":
-            quant_type_str = "nf4"
+            return BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=self.compute_dtype,
+                bnb_4bit_use_double_quant=self.use_double_quant,
+                bnb_4bit_quant_storage=self.quant_storage
+            )
         elif self.quant_type == "int8":
-            quant_type_str = "int8"
+            return BitsAndBytesConfig(
+                # 8-bit (LLM.int8()) loading takes no quant_type, compute_dtype,
+                # or double-quant arguments, so those fields are ignored here;
+                # llm_int8_threshold is the main 8-bit tuning knob
+                load_in_8bit=True,
+                llm_int8_threshold=6.0
+            )
         else:
             raise ValueError("quant_type must be either 'int4' or 'int8'")
-
-        return BitsAndBytesConfig(
-            load_in_4bit=(self.quant_type == "int4"),
-            bnb_4bit_quant_type=quant_type_str,
-            bnb_4bit_compute_dtype=self.compute_dtype,
-            bnb_4bit_use_double_quant=self.use_double_quant,
-            bnb_4bit_quant_storage=self.quant_storage
-        )
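
For reference, a minimal usage sketch of the new create_bnb_config path. It assumes the part of the class not shown in these hunks defines the quant_storage field that the method reads ("uint8" below is only a guess at its value), and the checkpoint name is just a placeholder:

    import torch
    from transformers import AutoModelForCausalLM

    # quant_storage is assumed to be a field declared outside these hunks;
    # "uint8" is the usual bitsandbytes packing dtype.
    cfg = QuantizationConfig(
        quant_type="int4",
        compute_dtype=torch.bfloat16,
        use_double_quant=True,
        quant_storage="uint8",
    )

    # Placeholder checkpoint; any causal LM on the Hub is loaded the same way.
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",
        quantization_config=cfg.create_bnb_config(),
    )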