config_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import inspect
from dataclasses import asdict

import torch.distributed as dist
from torch.utils.data import DistributedSampler
from peft import (
    LoraConfig,
    AdaptionPromptConfig,
    PrefixTuningConfig,
)
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq

from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
from llama_recipes.utils.dataset_utils import DATASET_PREPROC

def update_config(config, **kwargs):
    """Override fields of a config dataclass (or a tuple/list of them) from keyword arguments."""
    if isinstance(config, (tuple, list)):
        for c in config:
            update_config(c, **kwargs)
    else:
        for k, v in kwargs.items():
            if hasattr(config, k):
                setattr(config, k, v)
            elif "." in k:
                # allow --some_config.some_param=True
                config_name, param_name = k.split(".")
                if type(config).__name__ == config_name:
                    if hasattr(config, param_name):
                        setattr(config, param_name, v)
                    else:
                        # In case of a specialized config we can warn the user
                        print(f"Warning: {config_name} does not accept parameter: {k}")
            elif isinstance(config, train_config):
                print(f"Warning: unknown parameter {k}")
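
# Illustrative sketch (not part of the upstream module): how update_config is
# typically driven from CLI-style keyword arguments. The helper name and the
# field values below are assumptions for illustration only.
def _example_update_config():
    train_cfg = train_config()
    lora_cfg = lora_config()
    # Plain keys are applied to every config that defines the attribute;
    # dotted keys such as "lora_config.r" are routed to the config whose class name matches.
    update_config((train_cfg, lora_cfg), batch_size_training=2, **{"lora_config.r": 16})
    return train_cfg, lora_cfg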

def generate_peft_config(train_config, kwargs):
    configs = (lora_config, llama_adapter_config, prefix_config)
    peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
    # Derive the peft method names from the config class names, e.g. lora_config -> "lora"
    names = tuple(c.__name__.rstrip("_config") for c in configs)

    assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"

    config = configs[names.index(train_config.peft_method)]()
    update_config(config, **kwargs)

    params = asdict(config)
    peft_config = peft_configs[names.index(train_config.peft_method)](**params)

    return peft_config
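
# Illustrative sketch (not part of the upstream module): building a PEFT config
# from a train_config whose peft_method is "lora". Overrides arrive as a plain
# dict, mirroring how the finetuning entry point forwards its kwargs; the LoRA
# rank of 8 is just an example value.
def _example_generate_peft_config():
    train_cfg = train_config()
    train_cfg.peft_method = "lora"
    peft_config = generate_peft_config(train_cfg, {"r": 8})
    return peft_config  # a peft.LoraConfig instance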

def generate_dataset_config(train_config, kwargs):
    names = tuple(DATASET_PREPROC.keys())

    assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"

    # Look up the dataset config dataclass by name among the members of llama_recipes.configs.datasets
    dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
    update_config(dataset_config, **kwargs)

    return dataset_config
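
# Illustrative sketch (not part of the upstream module): selecting a dataset
# config by name, assuming "samsum_dataset" is one of the keys registered in
# DATASET_PREPROC. An empty dict means no overrides; any entries would be
# forwarded to update_config.
def _example_generate_dataset_config():
    train_cfg = train_config()
    train_cfg.dataset = "samsum_dataset"
    dataset_config = generate_dataset_config(train_cfg, {})
    return dataset_config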

def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
    if train_config.batching_strategy == "padding":
        # Group samples of similar length into batches and pad them to a common length
        if train_config.enable_fsdp:
            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                dataset,
                batch_size=batch_size,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        else:
            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode == "train")
        kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
    elif train_config.batching_strategy == "packing":
        # Samples are already packed to a fixed length, so a plain (distributed) sampler suffices
        if train_config.enable_fsdp:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
                drop_last=True,
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = default_data_collator
    else:
        raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")

    return kwargs
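
# Illustrative sketch (not part of the upstream module): wiring the returned
# kwargs into a torch DataLoader. `my_dataset` and `my_tokenizer` are
# placeholders; with batching_strategy="packing" and enable_fsdp=False no
# torch.distributed initialization is required.
def _example_build_dataloader(my_dataset, my_tokenizer):
    from torch.utils.data import DataLoader

    train_cfg = train_config()
    train_cfg.batching_strategy = "packing"
    train_cfg.enable_fsdp = False
    dl_kwargs = get_dataloader_kwargs(train_cfg, my_dataset, my_tokenizer, mode="train")
    return DataLoader(my_dataset, pin_memory=True, **dl_kwargs)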