config_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import inspect
from dataclasses import asdict

import torch.distributed as dist
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.utils.data import DistributedSampler
from peft import (
    LoraConfig,
    AdaptionPromptConfig,
    PrefixTuningConfig,
)
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq

from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
from llama_recipes.datasets import DATASET_PREPROC


def update_config(config, **kwargs):
    """Copy keyword overrides onto `config`, or onto each config in a tuple/list of configs."""
    if isinstance(config, (tuple, list)):
        for c in config:
            update_config(c, **kwargs)
    else:
        for k, v in kwargs.items():
            if hasattr(config, k):
                setattr(config, k, v)
            elif "." in k:
                # allow --some_config.some_param=True
                config_name, param_name = k.split(".")
                if type(config).__name__ == config_name:
                    if hasattr(config, param_name):
                        setattr(config, param_name, v)
                    else:
                        # Warn the user when the targeted config does not define this parameter.
                        print(f"Warning: {config_name} does not accept parameter: {k}")
            elif isinstance(config, train_config):
                print(f"Warning: unknown parameter {k}")
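
# Illustrative usage (editor's sketch, not part of the original module). Overrides are
# usually collected from the CLI and splatted in as keyword arguments; the attribute
# names below are examples from llama_recipes.configs.train_config:
#
#   cfg = train_config()
#   update_config(cfg, lr=1e-4, num_epochs=3)          # plain attribute overrides
#   update_config((cfg, other_cfg), model_name="...")  # tuples/lists are handled recursively
#
# Dotted keys such as "train_config.lr" are applied only when the prefix matches the
# class name of the config being updated.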


def generate_peft_config(train_config, kwargs):
    """Build the peft config object matching train_config.peft_method, with overrides applied."""
    configs = (lora_config, llama_adapter_config, prefix_config)
    peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
    # Strip the "_config" suffix to get the method names ("lora", "llama_adapter", "prefix").
    # Slicing is used instead of rstrip, which strips a set of characters rather than a suffix.
    names = tuple(c.__name__[:-len("_config")] for c in configs)

    if train_config.peft_method not in names:
        raise RuntimeError(f"Peft config not found: {train_config.peft_method}")

    if train_config.peft_method == "prefix":
        raise RuntimeError("PrefixTuning is currently not supported (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089350811)")

    if train_config.enable_fsdp and train_config.peft_method == "llama_adapter":
        raise RuntimeError("Llama_adapter is currently not supported in combination with FSDP (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089274425)")

    config = configs[names.index(train_config.peft_method)]()

    update_config(config, **kwargs)
    params = asdict(config)
    peft_config = peft_configs[names.index(train_config.peft_method)](**params)

    return peft_config
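
# Illustrative usage (editor's sketch): with peft_method == "lora" this builds the
# llama_recipes lora_config dataclass, applies any overrides, and re-instantiates it
# as a peft.LoraConfig. The override keys shown are fields of lora_config:
#
#   train_cfg = train_config()
#   train_cfg.peft_method = "lora"
#   peft_cfg = generate_peft_config(train_cfg, {"r": 16, "lora_alpha": 32})
#
# Only keys that exist on the selected llama_recipes config dataclass take effect,
# since the dict is forwarded through update_config.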


def generate_dataset_config(train_config, kwargs):
    """Look up the dataset config class named by train_config.dataset and apply overrides."""
    names = tuple(DATASET_PREPROC.keys())
    assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
    dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
    update_config(dataset_config, **kwargs)
    return dataset_config
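
# Illustrative usage (editor's sketch): for train_config.dataset == "samsum_dataset"
# this returns the samsum_dataset dataclass from llama_recipes.configs.datasets with
# matching overrides applied:
#
#   dataset_cfg = generate_dataset_config(train_cfg, {"train_split": "train"})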


def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
    """Assemble sampler/collator kwargs for a DataLoader according to the batching strategy."""
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
    if train_config.batching_strategy == "padding":
        if train_config.enable_fsdp:
            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                dataset,
                batch_size=batch_size,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=(mode == "train"),
            )
        else:
            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=(mode == "train"))
        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
    elif train_config.batching_strategy == "packing":
        if train_config.enable_fsdp:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=(mode == "train"),
                drop_last=True,
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = default_data_collator
    else:
        raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")

    return kwargs
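
# Illustrative usage (editor's sketch, assuming a tokenizer and a preprocessed dataset):
# the returned dict is meant to be splatted into a torch DataLoader, roughly as the
# finetuning script does:
#
#   dl_kwargs = get_dataloader_kwargs(train_cfg, train_dataset, tokenizer, "train")
#   train_dataloader = torch.utils.data.DataLoader(
#       train_dataset,
#       num_workers=train_cfg.num_workers_dataloader,
#       pin_memory=True,
#       **dl_kwargs,
#   )
#
# "padding" yields a (distributed) length-based batch_sampler plus a seq2seq padding
# collator; "packing" yields a plain batch_size (and a sampler under FSDP) with the
# default collator.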


def check_fsdp_config(fsdp_config):
    """Normalize fsdp_config.checkpoint_type from its string form and validate it."""
    VALID_TYPES = (StateDictType.SHARDED_STATE_DICT, StateDictType.FULL_STATE_DICT)
    if isinstance(fsdp_config.checkpoint_type, str):
        str_to_obj = {
            "StateDictType.SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT,
            "StateDictType.FULL_STATE_DICT": StateDictType.FULL_STATE_DICT,
        }
        if fsdp_config.checkpoint_type in str_to_obj:
            fsdp_config.checkpoint_type = str_to_obj[fsdp_config.checkpoint_type]

    if fsdp_config.checkpoint_type not in VALID_TYPES:
        raise ValueError(f"Invalid checkpoint_type {fsdp_config.checkpoint_type}")
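
# Illustrative usage (editor's sketch): checkpoint_type may arrive from the CLI as the
# string form of the enum and is normalized in place. fsdp_config here refers to the
# dataclass in llama_recipes.configs, which this module does not import:
#
#   fsdp_cfg = fsdp_config()
#   fsdp_cfg.checkpoint_type = "StateDictType.SHARDED_STATE_DICT"
#   check_fsdp_config(fsdp_cfg)
#   assert fsdp_cfg.checkpoint_type is StateDictType.SHARDED_STATE_DICT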