
Adjust imports to package structure + clean up imports

Matthias Reso · 1 year ago
commit cf678b9bf0
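
The pattern repeated across every file below: with the sources now living under src/llama_recipes/, sibling modules are imported relative to the package (from .configs import ..., from ..utils import ...) instead of as top-level modules that only resolve when the repo root happens to be on sys.path. A rough before/after sketch in Python, assuming the package has been installed (for example with pip install -e ., which is not shown in this diff); the config and helper names are the ones visible in the diff:

# Before this commit: flat, top-level imports that depended on the working directory
# (or on sys.path.append hacks like the one removed from utils/train_utils.py).
#   from configs import train_config
#   from utils.train_utils import train

# After this commit: absolute imports against the installed package ...
from llama_recipes.configs import train_config
from llama_recipes.utils.train_utils import train

# ... or, inside the package itself, the relative form used throughout the diff:
#   from .configs import train_config
#   from .utils.train_utils import train

One practical consequence of relative imports: modules such as finetuning.py can no longer be executed as standalone files; they need to be run as package modules (for example python -m llama_recipes.finetuning) or through an installed entry point.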

+ 2 - 2
src/llama_recipes/configs/fsdp.py

@@ -1,8 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from dataclasses import dataclass, field
-from typing import ClassVar
+from dataclasses import dataclass
+
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 

+ 1 - 1
src/llama_recipes/configs/peft.py

@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import ClassVar, List
 
 @dataclass

+ 1 - 1
src/llama_recipes/configs/training.py

@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 from dataclasses import dataclass
-from typing import ClassVar
 
 
 @dataclass

+ 2 - 4
src/llama_recipes/datasets/alpaca_dataset.py

@@ -5,12 +5,10 @@
 
 import copy
 import json
-import os
-import torch
 
-from sentencepiece import SentencePieceProcessor
+import torch
 from torch.utils.data import Dataset
-from typing import List
+
 
 PROMPT_DICT = {
     "prompt_input": (

+ 2 - 18
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -4,29 +4,13 @@
 # For dataset details visit: https://huggingface.co/datasets/jfleg
 # For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb
 
-import argparse
-import csv
-import glob
-import os
-import json
-import time
-import logging
-import random
-import re
-from itertools import chain
-from string import punctuation
-
-
-import pandas as pd
-import numpy as np
-import torch
-from torch.utils.data import Dataset
 
 from datasets import load_dataset
 from pathlib import Path
 
-from ft_datasets.utils import ConcatDataset
+from torch.utils.data import Dataset
 
+from ..utils import ConcatDataset
 
 
 class grammar(Dataset):

+ 1 - 0
src/llama_recipes/datasets/samsum_dataset.py

@@ -4,6 +4,7 @@
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
 import datasets
+
 from .utils import Concatenator
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):

+ 1 - 0
src/llama_recipes/datasets/utils.py

@@ -3,6 +3,7 @@
 
 from tqdm import tqdm
 from itertools import chain
+
 from torch.utils.data import Dataset
 
 class Concatenator(object):

+ 8 - 9
src/llama_recipes/finetuning.py

@@ -2,13 +2,13 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
+from pkg_resources import packaging
 
 import fire
 import torch
 import torch.distributed as dist
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
-from pkg_resources import packaging
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
 )
@@ -22,19 +22,18 @@ from transformers import (
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-import policies
-from configs import fsdp_config, train_config
-from policies import AnyPrecisionAdamW
+from .configs import fsdp_config, train_config
+from .policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
-from utils import fsdp_auto_wrap_policy
-from utils.config_utils import (
+from .utils import fsdp_auto_wrap_policy
+from .utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
 )
-from utils.dataset_utils import get_preprocessed_dataset
+from .utils.dataset_utils import get_preprocessed_dataset
 
-from utils.train_utils import (
+from .utils.train_utils import (
     train,
     freeze_transformer_layers,
     setup,
@@ -153,7 +152,7 @@ def main(**kwargs):
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
-            policies.apply_fsdp_checkpointing(model)
+            apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
         model.to("cuda")
 

+ 6 - 6
src/llama_recipes/inference/chat_completion.py

@@ -2,18 +2,18 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
 import fire
-import torch
 import os
 import sys
-import warnings
 from typing import List
 
-from peft import PeftModel, PeftConfig
-from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
-from safety_utils import get_safety_checker
+import torch
 from model_utils import load_model, load_peft_model
-from chat_utils import read_dialogs_from_file, format_tokens
+from transformers import LlamaTokenizer
+from safety_utils import get_safety_checker
+
+from .chat_utils import read_dialogs_from_file, format_tokens
 
 def main(
     model_name,

+ 2 - 1
src/llama_recipes/inference/chat_utils.py

@@ -1,8 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from typing import List, Literal, Optional, Tuple, TypedDict, Union
 import json
+from typing import List, Literal, TypedDict
+
 
 Role = Literal["user", "assistant"]
 

+ 4 - 2
src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py

@@ -4,12 +4,14 @@
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 import fire
-import torch
 import os
 import sys
 import yaml
+
 from transformers import LlamaTokenizer
-from model_utils import  load_llama_from_config
+
+from .model_utils import  load_llama_from_config
+
 # Get the current file's directory
 current_directory = os.path.dirname(os.path.abspath(__file__))
 

+ 5 - 4
src/llama_recipes/inference/inference.py

@@ -4,15 +4,16 @@
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 import fire
-import torch
 import os
 import sys
 import time
-from typing import List
 
+import torch
 from transformers import LlamaTokenizer
-from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model, load_llama_from_config
+
+from .safety_utils import get_safety_checker
+from .model_utils import load_model, load_peft_model
+
 
 def main(
     model_name,

+ 0 - 2
src/llama_recipes/inference/safety_utils.py

@@ -5,8 +5,6 @@ import os
 import torch
 import warnings
 
-from peft import PeftConfig
-from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
 
 # Class for performing safety checks using AuditNLG library
 class AuditNLGSensitiveTopics(object):

+ 2 - 9
src/llama_recipes/inference/vLLM_inference.py

@@ -1,20 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 import fire
+
 import torch
-import os
-import sys
-from peft import PeftModel, PeftConfig
-from transformers import (
-    LlamaConfig,
-    LlamaTokenizer,
-    LlamaForCausalLM
-)
 from vllm import LLM
 from vllm import LLM, SamplingParams
 
+
 torch.cuda.manual_seed(42)
 torch.manual_seed(42)
 

+ 2 - 6
src/llama_recipes/policies/activation_checkpointing_functions.py

@@ -1,18 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import torch
-import os
-import torch.distributed as dist
+from functools import partial
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
     CheckpointImpl,
     apply_activation_checkpointing,
 )
-
-from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from functools import partial
 
 non_reentrant_wrapper = partial(
     checkpoint_wrapper,

+ 0 - 4
src/llama_recipes/policies/mixed_precision.py

@@ -4,11 +4,7 @@
 import torch
 
 from torch.distributed.fsdp import (
-    # FullyShardedDataParallel as FSDP,
-    # CPUOffload,
     MixedPrecision,
-    # BackwardPrefetch,
-    # ShardingStrategy,
 )
 
 # requires grad scaler in main loop

+ 1 - 15
src/llama_recipes/policies/wrapping.py

@@ -1,28 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import torch.distributed as dist
-import torch.nn as nn
-import torch
+import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    FullyShardedDataParallel as FSDP,
-    CPUOffload,
-    BackwardPrefetch,
-    MixedPrecision,
-)
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
     size_based_auto_wrap_policy,
-    enable_wrap,
-    wrap,
 )
 
-import functools
-from typing import Type
-
 
 def get_size_policy(min_params=1e8):
     num_wrap_policy = functools.partial(

+ 2 - 2
src/llama_recipes/utils/config_utils.py

@@ -3,14 +3,14 @@
 
 import inspect
 from dataclasses import fields
+
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
     PrefixTuningConfig,
 )
 
-import configs.datasets as datasets
-from configs import lora_config, llama_adapter_config, prefix_config, train_config
+from ..configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from .dataset_utils import DATASET_PREPROC
 
 

+ 3 - 4
src/llama_recipes/utils/dataset_utils.py

@@ -1,16 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import torch
-
 from functools import partial
 
-from ft_datasets import (
+import torch
+
+from ..datasets import (
     get_grammar_dataset,
     get_alpaca_dataset,
     get_samsum_dataset,
 )
-from typing import Optional
 
 
 DATASET_PREPROC = {

+ 0 - 3
src/llama_recipes/utils/fsdp_utils.py

@@ -3,10 +3,7 @@
 
 def fsdp_auto_wrap_policy(model, transformer_layer_name):
     import functools
-    import os
 
-    from accelerate import FullyShardedDataParallelPlugin
-    from transformers.models.t5.modeling_t5 import T5Block
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
 
     from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

+ 2 - 4
src/llama_recipes/utils/memory_utils.py

@@ -1,12 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 import gc
-import os
-import sys
+import psutil
 import threading
 
-import numpy as np
-import psutil
 import torch
 
 def byte2gb(x):

+ 18 - 33
src/llama_recipes/utils/train_utils.py

@@ -2,40 +2,25 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-import sys
-from typing import List
-import yaml
 import time
+import yaml
+from pathlib import Path
+from pkg_resources import packaging
+
 
-import fire
 import torch
-import transformers
-from datasets import load_dataset
-from tqdm import tqdm
-"""
-Unused imports:
-import torch.nn as nn
-import bitsandbytes as bnb
-"""
-from torch.nn import functional as F
-from peft import (
-    LoraConfig,
-    get_peft_model,
-    get_peft_model_state_dict,
-    prepare_model_for_int8_training,
-    set_peft_model_state_dict,
-)
-from transformers import LlamaForCausalLM, LlamaTokenizer
-from torch.distributed.fsdp import StateDictType
-import torch.distributed as dist
-from pkg_resources import packaging
-from .memory_utils import MemoryTrace
-import model_checkpointing
 import torch.cuda.nccl as nccl
+import torch.distributed as dist
+from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from pathlib import Path
-sys.path.append(str(Path(__file__).resolve().parent.parent))
-from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from tqdm import tqdm
+from transformers import LlamaTokenizer
+
+
+from .memory_utils import MemoryTrace
+from ..model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
+from ..policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -162,21 +147,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                         
-                        model_checkpointing.save_model_checkpoint(
+                        save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
                         
-                        model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
+                        save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
-                            model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
                             print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                             print("=====================================================")
 
                     if not train_config.use_peft and  train_config.save_optimizer:
-                        model_checkpointing.save_optimizer_checkpoint(
+                        save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")