Browse Source

Fix checkpoint saving (#650)

Matthias Reso 7 tháng trước cách đây
mục cha
commit
778e31e35c

+ 26 - 49
recipes/quickstart/finetuning/quickstart_peft_finetuning.ipynb

@@ -65,7 +65,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "c7963d43806d432aaa3d00e2055e355c",
+       "model_id": "68838a4f42f84545912e95b339a31034",
        "version_major": 2,
        "version_minor": 0
       },
@@ -75,13 +75,6 @@
      },
      "metadata": {},
      "output_type": "display_data"
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
-     ]
     }
    ],
    "source": [
@@ -101,6 +94,7 @@
     "train_config.context_length = 1024 if torch.cuda.get_device_properties(0).total_memory < 16e9 else 2048 # T4 16GB or A10 24GB\n",
     "train_config.batching_strategy = \"packing\"\n",
     "train_config.output_dir = \"meta-llama-samsum\"\n",
+    "train_config.use_peft = True\n",
     "\n",
     "from transformers import BitsAndBytesConfig\n",
     "config = BitsAndBytesConfig(\n",
@@ -205,7 +199,7 @@
     "model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
     "\n",
     "model.eval()\n",
-    "with torch.no_grad():\n",
+    "with torch.inference_mode():\n",
     "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))"
    ]
   },
@@ -230,34 +224,20 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/datasets/load.py:1486: FutureWarning: The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum\n",
-      "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
-      "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
-      "  warnings.warn(\n",
-      "Preprocessing dataset: 100%|██████████| 14732/14732 [00:02<00:00, 6124.69it/s]\n"
+      "/home/ubuntu/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead\n",
+      "  from torch.distributed._shard.checkpoint import (\n",
+      "Preprocessing dataset: 100%|██████████| 14732/14732 [00:02<00:00, 5872.02it/s]\n"
      ]
     }
    ],
    "source": [
     "from llama_recipes.configs.datasets import samsum_dataset\n",
-    "from llama_recipes.data.concatenator import ConcatDataset\n",
-    "from llama_recipes.utils.config_utils import get_dataloader_kwargs\n",
-    "from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n",
-    "\n",
-    "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')\n",
-    "\n",
-    "train_dl_kwargs = get_dataloader_kwargs(train_config, train_dataset, tokenizer, \"train\")\n",
+    "from llama_recipes.utils.dataset_utils import get_dataloader\n",
     "\n",
-    "if train_config.batching_strategy == \"packing\":\n",
-    "        train_dataset = ConcatDataset(train_dataset, chunk_size=train_config.context_length)\n",
+    "samsum_dataset.trust_remote_code = True\n",
     "\n",
-    "# Create DataLoaders for the training and validation dataset\n",
-    "train_dataloader = torch.utils.data.DataLoader(\n",
-    "    train_dataset,\n",
-    "    num_workers=train_config.num_workers_dataloader,\n",
-    "    pin_memory=True,\n",
-    "    **train_dl_kwargs,\n",
-    ")"
+    "train_dataloader = get_dataloader(tokenizer, samsum_dataset, train_config)\n",
+    "eval_dataloader = get_dataloader(tokenizer, samsum_dataset, train_config, \"val\")"
    ]
   },
   {
@@ -310,17 +290,23 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
+      "/home/ubuntu/llama-recipes/src/llama_recipes/utils/train_utils.py:92: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
+      "  scaler = torch.cuda.amp.GradScaler()\n",
+      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
       "  warnings.warn(\n",
       "Training Epoch: 1:   0%|\u001b[34m          \u001b[0m| 0/319 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
       "To disable this warning, you can either:\n",
       "\t- Avoid using `tokenizers` before the fork if possible\n",
       "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
-      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
-      "  warnings.warn(\n",
+      "/home/ubuntu/llama-recipes/src/llama_recipes/utils/train_utils.py:151: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
+      "  with autocast():\n",
+      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
+      "  return fn(*args, **kwargs)\n",
       "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
       "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
-      "Training Epoch: 1/1, step 1278/1279 completed (loss: 0.27870458364486694): : 320it [2:07:09, 23.84s/it]                      3.94s/it]  \n"
+      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
+      "  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]\n",
+      "Training Epoch: 1/1, step 1278/1279 completed (loss: 0.28094857931137085): : 320it [2:08:50, 24.16s/it]                      4.21s/it]  \n"
      ]
     },
     {
@@ -332,7 +318,7 @@
       "Peak active CUDA memory was 15 GB\n",
       "CUDA Malloc retries : 0\n",
       "CPU Total Peak Memory consumed during the train (max): 2 GB\n",
-      "Epoch 1: train_perplexity=1.3403, train_epoch_loss=0.2929, epoch time 7630.169942979002s\n"
+      "Epoch 1: train_perplexity=1.3404, train_epoch_loss=0.2930, epoch time 7730.981359725998s\n"
      ]
     }
    ],
@@ -354,7 +340,7 @@
     "results = train(\n",
     "    model,\n",
     "    train_dataloader,\n",
-    "    None,\n",
+    "    eval_dataloader,\n",
     "    tokenizer,\n",
     "    optimizer,\n",
     "    scheduler,\n",
@@ -380,16 +366,7 @@
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
-      "  warnings.warn(\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "model.save_pretrained(train_config.output_dir)"
    ]
@@ -440,13 +417,13 @@
       "A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))\n",
       "---\n",
       "Summary:\n",
-      "A wants to get a puppy for her son. She will take him to the animal shelter tomorrow. B is not sure if he can go with her, but he's willing to.\n"
+      "A wants to get a puppy for his son. A took him to the animal shelter last Monday and he showed A one he really liked. A wants to get him one of those little dogs. A and B agree that raising a dog is a tough issue.\n"
      ]
     }
    ],
    "source": [
     "model.eval()\n",
-    "with torch.no_grad():\n",
+    "with torch.inference_mode():\n",
     "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))\n"
    ]
   }
@@ -467,7 +444,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.14"
+   "version": "3.11.9"
   },
   "vscode": {
    "interpreter": {

+ 1 - 1
requirements.txt

@@ -22,7 +22,7 @@ tabulate
 evaluate
 rouge_score
 pyyaml==6.0.1
-faiss-gpu
+faiss-gpu; python_version < '3.11'
 unstructured[pdf]
 langchain_openai
 langchain

+ 1 - 1
src/llama_recipes/configs/fsdp.py

@@ -14,7 +14,7 @@ class fsdp_config:
     hsdp : bool =False # Require HYBRID_SHARD to be set. This flag can extend the HYBRID_SHARD by allowing sharding a model on customized number of GPUs (Sharding_group) and Replicas over Sharding_group.
     sharding_group_size : int=0 # requires hsdp to be set. This specifies the sharding group size, number of GPUs that you model can fit into to form a replica of a model.
     replica_group_size: int=0 #requires hsdp to be set. This specifies the replica group size, which is world_size/sharding_group_size.
-    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
+    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # alternatively FULL_STATE_DICT can be used. SHARDED_STATE_DICT saves one file with sharded weights per rank while FULL_STATE_DICT will collect all weights on rank 0 and save them in a single file.
     fsdp_activation_checkpointing: bool=True
     fsdp_cpu_offload: bool=False
     pure_bf16: bool = False

+ 12 - 1
src/llama_recipes/datasets/__init__.py

@@ -1,7 +1,18 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from functools import partial
+
 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
+from llama_recipes.datasets.custom_dataset import get_custom_dataset
 from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
-from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
+from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
+
+DATASET_PREPROC = {
+    "alpaca_dataset": partial(get_alpaca_dataset),
+    "grammar_dataset": get_grammar_dataset,
+    "samsum_dataset": get_samsum_dataset,
+    "custom_dataset": get_custom_dataset,
+    "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
+}

+ 3 - 0
src/llama_recipes/finetuning.py

@@ -37,6 +37,7 @@ from llama_recipes.utils.config_utils import (
     generate_peft_config,
     generate_dataset_config,
     get_dataloader_kwargs,
+    check_fsdp_config,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
@@ -162,6 +163,8 @@ def main(**kwargs):
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
+        check_fsdp_config(fsdp_config)
+        
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 

+ 2 - 1
src/llama_recipes/model_checkpointing/__init__.py

@@ -3,8 +3,9 @@
 
 from llama_recipes.model_checkpointing.checkpoint_handler import (
     load_model_checkpoint,
-    save_model_checkpoint,
+    save_fsdp_model_checkpoint_full,
     save_peft_checkpoint,
+    save_model_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,

+ 19 - 5
src/llama_recipes/model_checkpointing/checkpoint_handler.py

@@ -123,7 +123,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
         print(
             f"Checkpoint Time = {t1-t0:.4f}\n"
         )
-def save_model_checkpoint(
+def save_fsdp_model_checkpoint_full(
     model,
     optimizer,
     rank,
@@ -152,7 +152,7 @@ def save_model_checkpoint(
         )
         save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
-        save_name = cfg.model_name + "-" + str(epoch) + ".pt"
+        save_name = cfg.model_name.replace("/","--") + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
 
         # save model
@@ -271,6 +271,20 @@ def save_peft_checkpoint(model, model_path):
     """save_pretrained peft model"""
 
     options = StateDictOptions(full_state_dict=True, cpu_offload=True)
-
-    state_dict = get_model_state_dict(model, options=options)
-    model.save_pretrained(model_path, state_dict=state_dict)
+    
+    if isinstance(model, FSDP):
+        state_dict = get_model_state_dict(model, options=options)
+        model.save_pretrained(model_path, state_dict=state_dict)
+    else:
+        model.save_pretrained(model_path)
+    
+    
+def save_model_checkpoint(model, output_dir):
+    """save model when not peft and on single device"""
+    
+    output_file = Path(output_dir) / "model.pt"
+    
+    state_dict = model.state_dict()
+    
+    torch.save(state_dict, output_file)
+    

+ 16 - 0
src/llama_recipes/utils/config_utils.py

@@ -5,6 +5,7 @@ import inspect
 from dataclasses import asdict
 
 import torch.distributed as dist
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
@@ -106,3 +107,18 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
             raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
 
         return kwargs
+
+
+def check_fsdp_config(fsdp_config):
+    VALID_TYPES = (StateDictType.SHARDED_STATE_DICT, StateDictType.FULL_STATE_DICT)
+    if isinstance(fsdp_config.checkpoint_type, str):
+        str_to_obj = {
+            "StateDictType.SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT,
+            "StateDictType.FULL_STATE_DICT": StateDictType.FULL_STATE_DICT,
+        }
+        if fsdp_config.checkpoint_type in str_to_obj:
+            fsdp_config.checkpoint_type = str_to_obj[fsdp_config.checkpoint_type]
+        
+    if not fsdp_config.checkpoint_type in VALID_TYPES:
+        raise ValueError(f"Invalid checkpoint_type {fsdp_config.checkpoint_type}")
+    

+ 21 - 55
src/llama_recipes/utils/dataset_utils.py

@@ -1,63 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import importlib
-from functools import partial
-from pathlib import Path
-
 import torch
 
-from llama_recipes.datasets import (
-    get_grammar_dataset,
-    get_alpaca_dataset,
-    get_samsum_dataset,
-    get_llamaguard_toxicchat_dataset,
-)
-
-
-def load_module_from_py_file(py_file: str) -> object:
-    """
-    This method loads a module from a py file which is not in the Python path
-    """
-    module_name = Path(py_file).name
-    loader = importlib.machinery.SourceFileLoader(module_name, py_file)
-    spec = importlib.util.spec_from_loader(module_name, loader)
-    module = importlib.util.module_from_spec(spec)
-
-    loader.exec_module(module)
-
-    return module
-
-
-def get_custom_dataset(dataset_config, tokenizer, split: str):
-    if ":" in dataset_config.file:
-        module_path, func_name = dataset_config.file.split(":")
-    else:
-        module_path, func_name = dataset_config.file, "get_custom_dataset"
-
-    if not module_path.endswith(".py"):
-        raise ValueError(f"Dataset file {module_path} is not a .py file.")
-
-    module_path = Path(module_path)
-    if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
-
-    module = load_module_from_py_file(module_path.as_posix())
-    try:
-        return getattr(module, func_name)(dataset_config, tokenizer, split)
-    except AttributeError as e:
-        print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
-        raise e
-
-
-DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset),
-    "grammar_dataset": get_grammar_dataset,
-    "samsum_dataset": get_samsum_dataset,
-    "custom_dataset": get_custom_dataset,
-    "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
-
-}
+from llama_recipes.data.concatenator import ConcatDataset
+from llama_recipes.datasets import DATASET_PREPROC, get_custom_dataset
+from llama_recipes.utils.config_utils import get_dataloader_kwargs
 
 
 def get_preprocessed_dataset(
@@ -78,3 +26,21 @@ def get_preprocessed_dataset(
         tokenizer,
         get_split(),
     )
+
+
+def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
+    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
+    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
+    
+    if split == "train" and train_config.batching_strategy == "packing":
+        dataset = ConcatDataset(dataset, chunk_size=train_config.context_length)
+
+    # Create data loader
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=train_config.num_workers_dataloader,
+        pin_memory=True,
+        **dl_kwargs,
+    )
+    return dataloader
+    

+ 22 - 14
src/llama_recipes/utils/train_utils.py

@@ -20,7 +20,7 @@ from transformers import LlamaTokenizer
 import json
 
 
-from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint
+from llama_recipes.model_checkpointing import save_fsdp_model_checkpoint_full, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint, save_model_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
@@ -243,27 +243,35 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
 
                 else:
-                    if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-
-                        save_model_checkpoint(
+                    if not train_config.enable_fsdp:
+                        save_model_checkpoint(model, train_config.output_dir)
+                        
+                    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                        print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                        print("=====================================================")
+                        save_fsdp_model_checkpoint_full(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-                    elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                        print("=====================================================")
+                        
+                        if train_config.save_optimizer:
+                            print(" Saving the FSDP optimizer using FULL_STATE_DICT")
+                            print("=====================================================")
+                            save_optimizer_checkpoint(
+                                model, optimizer, rank, train_config, epoch=epoch
+                            )
+                        
+                    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
 
-                        save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
+                            print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                            print("=====================================================")
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                        else:
                             print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                             print("=====================================================")
+                            save_model_and_optimizer_sharded(model, rank, train_config)
 
-                    if not train_config.use_peft and  train_config.save_optimizer:
-                        save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=epoch
-                        )
-                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
-                        print("=====================================================")
+                        
                 if train_config.enable_fsdp:
                     dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time