
adding support for FSDP+QLoRA

HamidShojanazeri 1 year ago
parent
commit
a42e0c0bdf
1 changed file with 27 additions and 43 deletions
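The quantization_config helper imported in this diff is defined outside it (presumably in src/llama_recipes/configs/quantization.py). Below is a minimal sketch of what it likely provides, inferred from the call site in finetuning.py; the field defaults, the nf4 mapping, and the class layout are assumptions, not the repository's actual implementation.

from dataclasses import dataclass

import torch
from transformers import BitsAndBytesConfig


@dataclass
class quantization_config:
    # Field names mirror the call site in finetuning.py; defaults are assumed.
    quant_type: str = "nf4"                      # bitsandbytes 4-bit quant type ("nf4" or "fp4")
    compute_dtype: torch.dtype = torch.bfloat16  # dtype used for compute on dequantized weights
    use_double_quant: bool = True                # also quantize the quantization constants
    quant_storage: torch.dtype = torch.bfloat16  # storage dtype for the packed 4-bit weights

    def create_bnb_config(self) -> BitsAndBytesConfig:
        # bnb_4bit_quant_storage is what makes FSDP+QLoRA work: the packed 4-bit
        # weights are stored in the same dtype that FSDP flattens and shards.
        # Requires a recent transformers/bitsandbytes that expose this argument.
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=self.quant_type,
            bnb_4bit_compute_dtype=self.compute_dtype,
            bnb_4bit_use_double_quant=self.use_double_quant,
            bnb_4bit_quant_storage=self.quant_storage,
        )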
src/llama_recipes/finetuning.py  +27 -43

@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from collections import Counter
 import os
 
 import dataclasses
@@ -8,7 +9,7 @@ import fire
 import random
 import torch
 import torch.optim as optim
-from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
+from peft import get_peft_model, PeftModel
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
@@ -18,6 +19,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
     AutoTokenizer,
+    BitsAndBytesConfig,
     LlamaForCausalLM,
     LlamaConfig,
 )
@@ -25,6 +27,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.configs.quantization import quantization_config
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
@@ -66,7 +69,6 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     run.config.update(fsdp_config, allow_val_change=True)
     return run
 
-
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
@@ -97,38 +99,26 @@ def main(**kwargs):
     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
+
+    # Set up the QLoRA quantization config when quantization is enabled
+    quant_config = quantization_config(
+        quant_type=train_config.quantization,
+        compute_dtype=torch.bfloat16,
+        use_double_quant=True,
+        quant_storage=torch.bfloat16,
+    ) if train_config.quantization else None
+    bnb_config = quant_config.create_bnb_config() if quant_config else None
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
-        """
-        for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
-        this avoids cpu oom when loading large models like llama 70B, in which case
-        model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
-        overhead and currently requires latest nightly.
-        """
-        if rank == 0:
-            model = LlamaForCausalLM.from_pretrained(
-                train_config.model_name,
-                load_in_8bit=True if train_config.quantization else None,
-                device_map="auto" if train_config.quantization else None,
-                use_cache=use_cache,
-                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-            )
-        else:
-            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
-            llama_config.use_cache = use_cache
-            with torch.device("meta"):
-                model = LlamaForCausalLM(llama_config)
-
-    else:
-        model = LlamaForCausalLM.from_pretrained(
-            train_config.model_name,
-            load_in_8bit=True if train_config.quantization else None,
-            device_map="auto" if train_config.quantization else None,
-            use_cache=use_cache,
-            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-        )
+    model = LlamaForCausalLM.from_pretrained(
+        train_config.model_name,
+        quantization_config=bnb_config,
+        use_cache=use_cache,
+        attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+        device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+        torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+    )
 
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
@@ -142,14 +132,10 @@ def main(**kwargs):
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
-    # Prepare the model for int8 training if quantization is enabled
-    if train_config.quantization:
-        model = prepare_model_for_kbit_training(model)
-
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
-    if train_config.enable_fsdp and fsdp_config.pure_bf16:
+    if train_config.enable_fsdp and fsdp_config.pure_bf16 and not train_config.quantization:
         model.to(torch.bfloat16)
-
+
     if train_config.use_peft:
         # Load the pre-trained peft model checkpoint and setup its configuration
         if train_config.from_peft_checkpoint:
@@ -163,7 +149,6 @@ def main(**kwargs):
             wandb_run.config.update(peft_config)
         model.print_trainable_parameters()
 
-
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
         hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
@@ -182,7 +167,6 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
-
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -196,8 +180,10 @@ def main(**kwargs):
             param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
-        if fsdp_config.fsdp_activation_checkpointing:
-            apply_fsdp_checkpointing(model)
+        if fsdp_config.fsdp_activation_checkpointing:
+            model.enable_input_require_grads()
+            model.gradient_checkpointing_enable()
+            apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
         if is_xpu_available():
             model.to("xpu:0")
@@ -212,7 +198,6 @@ def main(**kwargs):
         dataset_config,
         split="train",
     )
-
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -272,7 +257,6 @@ def main(**kwargs):
             weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-
     # Start the training process
     results = train(
         model,
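For orientation, here is a condensed, illustrative sketch of the FSDP+QLoRA preparation flow this commit wires into finetuning.py: load the model 4-bit quantized with a storage dtype matching the sharding dtype, attach LoRA adapters, then wrap with FSDP. The model name, LoRA hyperparameters, and wrapping policy are placeholders, and the snippet assumes torch.distributed is already initialized (e.g. launched with torchrun); it is not the script's exact code path.

import functools

import torch
from peft import LoraConfig, get_peft_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import BitsAndBytesConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # placeholder

# Storage dtype of the packed 4-bit weights must match the dtype FSDP shards in.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = LlamaForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    use_cache=False,
)

# Only the LoRA adapters are trainable; the 4-bit base weights stay frozen.
peft_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)

# Wrap each decoder layer as its own FSDP unit and shard it across the process group.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
)
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)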