|
@@ -47,6 +47,7 @@ from llama_recipes.utils.train_utils import (
|
|
|
clear_gpu_cache,
|
|
|
print_model_size,
|
|
|
get_policies,
|
|
|
+ set_quantization_settings
|
|
|
)
|
|
|
from accelerate.utils import is_xpu_available
|
|
|
|
|
@@ -101,6 +102,9 @@ def main(**kwargs):
|
|
|
|
|
|
# Load the pre-trained model and setup its configuration
|
|
|
use_cache = False if train_config.enable_fsdp else None
|
|
|
+ if train_config.quantization:
|
|
|
+ bnb_config = set_quantization_settings(train_config)
|
|
|
+
|
|
|
if train_config.enable_fsdp and train_config.low_cpu_fsdp:
|
|
|
"""
|
|
|
for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
|
|
@@ -108,15 +112,15 @@ def main(**kwargs):
|
|
|
model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
|
|
|
overhead and currently requires latest nightly.
|
|
|
"""
|
|
|
- v = packaging.version.parse(torch.__version__)
|
|
|
- verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
|
|
|
- if not verify_latest_nightly:
|
|
|
- raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
|
|
|
- "please install latest nightly.")
|
|
|
+ # v = packaging.version.parse(torch.__version__)
|
|
|
+ # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
|
|
|
+ # if not verify_latest_nightly:
|
|
|
+ # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
|
|
|
+ # "please install latest nightly.")
|
|
|
if rank == 0:
|
|
|
model = LlamaForCausalLM.from_pretrained(
|
|
|
train_config.model_name,
|
|
|
- load_in_8bit=True if train_config.quantization else None,
|
|
|
+ quantization_config=bnb_config,
|
|
|
device_map="auto" if train_config.quantization else None,
|
|
|
use_cache=use_cache,
|
|
|
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
|
|
@@ -130,7 +134,7 @@ def main(**kwargs):
|
|
|
else:
|
|
|
model = LlamaForCausalLM.from_pretrained(
|
|
|
train_config.model_name,
|
|
|
- load_in_8bit=True if train_config.quantization else None,
|
|
|
+ quantization_config=bnb_config,
|
|
|
device_map="auto" if train_config.quantization else None,
|
|
|
use_cache=use_cache,
|
|
|
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
|