|
@@ -1,5 +1,8 @@
|
|
|
-from dataclasses import dataclass, field
|
|
|
-from typing import List
|
|
|
+# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
|
+
|
|
|
+from dataclasses import dataclass
|
|
|
+from typing import Optional
|
|
|
import torch
|
|
|
from transformers import BitsAndBytesConfig
|
|
|
|
|
@@ -10,22 +13,18 @@ class quantizatio_config:
|
|
|
use_double_quant: bool = False
|
|
|
quant_storage: torch.dtype = torch.bfloat16
|
|
|
|
|
|
- def create_bnb_config(self):
|
|
|
+ def create_bnb_config(self) -> BitsAndBytesConfig:
|
|
|
+ if self.quant_type not in {"int4", "int8"}:
|
|
|
+ raise ValueError("quant_type must be either 'int4' or 'int8'")
|
|
|
+
|
|
|
+ config_params = {
|
|
|
+ "bnb_4bit_quant_type" if self.quant_type == "int4" else "bnb_8bit_quant_type": self.quant_type,
|
|
|
+ "bnb_4bit_compute_dtype" if self.quant_type == "int4" else "bnb_8bit_compute_dtype": self.compute_dtype,
|
|
|
+ "bnb_4bit_use_double_quant" if self.quant_type == "int4" else "bnb_8bit_use_double_quant": self.use_double_quant,
|
|
|
+ "bnb_4bit_quant_storage" if self.quant_type == "int4" else "bnb_8bit_quant_storage": self.quant_storage,
|
|
|
+ }
|
|
|
+
|
|
|
if self.quant_type == "int4":
|
|
|
- return BitsAndBytesConfig(
|
|
|
- load_in_4bit=True,
|
|
|
- bnb_4bit_quant_type="nf4",
|
|
|
- bnb_4bit_compute_dtype=self.compute_dtype,
|
|
|
- bnb_4bit_use_double_quant=self.use_double_quant,
|
|
|
- bnb_4bit_quant_storage=self.quant_storage
|
|
|
- )
|
|
|
- elif self.quant_type == "int8":
|
|
|
- return BitsAndBytesConfig(
|
|
|
- load_in_8bit=True,
|
|
|
- bnb_8bit_quant_type="int8",
|
|
|
- bnb_8bit_compute_dtype=self.compute_dtype,
|
|
|
- bnb_8bit_use_double_quant=self.use_double_quant,
|
|
|
- bnb_8bit_quant_storage=self.quant_storage
|
|
|
- )
|
|
|
+ return BitsAndBytesConfig(load_in_4bit=True, **config_params)
|
|
|
else:
|
|
|
- raise ValueError("quant_type must be either 'int4' or 'int8'")
|
|
|
+ return BitsAndBytesConfig(load_in_8bit=True, **config_params)
|