
adding optimizer overlap for FSDP

Hamid Shojanazeri · 1 year ago
commit 5d16d1a223
4 changed files with 46 additions and 2 deletions:
  1. README.md (+8 −0)
  2. docs/multi_gpu.md (+8 −0)
  3. src/llama_recipes/configs/fsdp.py (+2 −1)
  4. src/llama_recipes/finetuning.py (+28 −1)

README.md (+8 −0)

@@ -144,6 +144,14 @@ Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memor
 torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
 ```
 
+## FSDP optimizer overlap
+
+Setting `optimizer_overlap` in [fsdp_config](./src/llama_recipes/configs/fsdp.py) enables optimizer overlap, which lowers the GPU memory footprint during training. The main idea is to fuse the gradient computation and the parameter update into a single step: each parameter is updated as soon as its gradient is ready in the backward pass, so gradients can be freed immediately. This is a new FSDP feature; on PyTorch versions before 2.1.0 it is only available in nightly binaries.
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --optimizer_overlap
+```
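+
+For illustration, here is a minimal standalone sketch (not the recipe's code) of the optimizer-in-backward idea, assuming a PyTorch build that ships the private `torch.distributed.optim._apply_optimizer_in_backward` helper; the toy model, learning rate, and tensor shapes are made up for this example:
+
+```python
+import torch
+from torch.distributed.optim import _apply_optimizer_in_backward
+
+model = torch.nn.Linear(16, 16)
+
+# Attach a per-parameter AdamW so each parameter is updated as soon as its
+# gradient is produced during backward(), rather than in a separate
+# optimizer.step() after the whole backward pass.
+_apply_optimizer_in_backward(
+    optimizer_class=torch.optim.AdamW,
+    params=model.parameters(),
+    optimizer_kwargs={"lr": 1e-4},
+)
+
+loss = model(torch.randn(4, 16)).sum()
+loss.backward()  # parameter updates happen inside backward; no optimizer.step() call
+```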
+
 ### Fine-tuning using FSDP Only
 
 If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.

docs/multi_gpu.md (+8 −0)

@@ -46,6 +46,14 @@ Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memor
 torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels
 ```
 
+## FSDP optimizer overlap
+
+Setting `optimizer_overlap` in [fsdp_config](./src/llama_recipes/configs/fsdp.py) enables optimizer overlap, which lowers the GPU memory footprint during training. The main idea is to fuse the gradient computation and the parameter update into a single step: each parameter is updated as soon as its gradient is ready in the backward pass, so gradients can be freed immediately. This is a new FSDP feature; on PyTorch versions before 2.1.0 it is only available in nightly binaries.
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --optimizer_overlap
+```
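+
+If you are unsure whether your installed PyTorch build supports this, a quick sanity check (mirroring the import guard used in `src/llama_recipes/finetuning.py`) is:
+
+```python
+# Probe for the private optimizer-in-backward API used by optimizer_overlap.
+try:
+    from torch.distributed.optim import _apply_optimizer_in_backward  # noqa: F401
+    print("optimizer overlap is available in this PyTorch build")
+except ImportError:
+    print("optimizer overlap is not available; use PyTorch >= 2.1.0 or a nightly build")
+```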
+
 ### Fine-tuning using FSDP Only
 
 If interested in running full parameter finetuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.

src/llama_recipes/configs/fsdp.py (+2 −1)

@@ -16,4 +16,5 @@ class fsdp_config:
     fsdp_cpu_offload: bool=False
     pure_bf16: bool = False
     optimizer: str= "AdamW"
-    
+    optimizer_overlap: bool = False
+    

src/llama_recipes/finetuning.py (+28 −1)

@@ -64,7 +64,16 @@ def main(**kwargs):
         torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
-
+    
+    # Import the private _apply_optimizer_in_backward helper needed for FSDP optimizer overlap
+    optimizer_in_backward_available = False
+    if fsdp_config.optimizer_overlap:
+        try:
+            from torch.distributed.optim import _apply_optimizer_in_backward
+            optimizer_in_backward_available = True
+        except ImportError:
+            print("_apply_optimizer_in_backward is not available in torch.distributed.optim; falling back to the standard optimizer setup.")
+            
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
@@ -151,6 +160,7 @@ def main(**kwargs):
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
+            use_orig_params=optimizer_in_backward_available,  # in-backward optimizers need FSDP to expose the original (unflattened) parameters
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
@@ -228,12 +238,29 @@ def main(**kwargs):
             use_kahan_summation=False,
             weight_decay=train_config.weight_decay,
         )
+    elif optimizer_in_backward_available:
+        print("Setting up optimizer overlap (optimizer applied in backward)")
+        # Attach a per-parameter AdamW so each parameter can be updated as soon
+        # as its gradient is ready; register_hook=False because FSDP drives the
+        # in-backward step itself (requires use_orig_params=True above).
+        _apply_optimizer_in_backward(
+            optimizer_class=optim.AdamW,
+            params=model.parameters(),
+            optimizer_kwargs={"lr": train_config.lr},
+            register_hook=False,
+        )
+        for p in model.parameters():
+            assert hasattr(p, "_in_backward_optimizers")
+        # Keep a regular AdamW (weight decay disabled) so the LR scheduler
+        # below still has an optimizer object to step.
+        optimizer = optim.AdamW(
+            model.parameters(),
+            lr=train_config.lr,
+            weight_decay=0.0,
+        )
     else:
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,
             weight_decay=train_config.weight_decay,
         )
+   
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
 
     # Start the training process