finetuning.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. from pkg_resources import packaging
  5. import fire
  6. import random
  7. import torch
  8. import torch.optim as optim
  9. from peft import get_peft_model, prepare_model_for_int8_training
  10. from torch.distributed.fsdp import (
  11. FullyShardedDataParallel as FSDP,
  12. )
  13. from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
  14. from torch.optim.lr_scheduler import StepLR
  15. from transformers import (
  16. LlamaForCausalLM,
  17. LlamaTokenizer,
  18. LlamaConfig,
  19. )
  20. from transformers.models.llama.modeling_llama import LlamaDecoderLayer
  21. from llama_recipes.configs import fsdp_config as FSDP_CONFIG
  22. from llama_recipes.configs import train_config as TRAIN_CONFIG
  23. from llama_recipes.data.concatenator import ConcatDataset
  24. from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
  25. from llama_recipes.utils import fsdp_auto_wrap_policy
  26. from llama_recipes.utils.config_utils import (
  27. update_config,
  28. generate_peft_config,
  29. generate_dataset_config,
  30. get_dataloader_kwargs,
  31. )
  32. from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
  33. from llama_recipes.utils.train_utils import (
  34. train,
  35. freeze_transformer_layers,
  36. setup,
  37. setup_environ_flags,
  38. clear_gpu_cache,
  39. print_model_size,
  40. get_policies
  41. )
  42. from accelerate.utils import is_xpu_available
  43. def main(**kwargs):
  44. # Update the configuration for the training and sharding process
  45. train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
  46. update_config((train_config, fsdp_config), **kwargs)
  47. # Set the seeds for reproducibility
  48. if is_xpu_available():
  49. torch.xpu.manual_seed(train_config.seed)
  50. else:
  51. torch.cuda.manual_seed(train_config.seed)
  52. torch.manual_seed(train_config.seed)
  53. random.seed(train_config.seed)
  54. if train_config.enable_fsdp:
  55. setup()
  56. # torchrun specific
  57. local_rank = int(os.environ["LOCAL_RANK"])
  58. rank = int(os.environ["RANK"])
  59. world_size = int(os.environ["WORLD_SIZE"])
  60. if torch.distributed.is_initialized():
  61. if is_xpu_available():
  62. torch.xpu.set_device(local_rank)
  63. else:
  64. torch.cuda.set_device(local_rank)
  65. clear_gpu_cache(local_rank)
  66. setup_environ_flags(rank)
  67. # Load the pre-trained model and setup its configuration
  68. use_cache = False if train_config.enable_fsdp else None
  69. if train_config.enable_fsdp and train_config.low_cpu_fsdp:
  70. """
  71. for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
  72. this avoids cpu oom when loading large models like llama 70B, in which case
  73. model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
  74. overhead and currently requires latest nightly.
  75. """
  76. v = packaging.version.parse(torch.__version__)
  77. verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
  78. if not verify_latest_nightly:
  79. raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
  80. "please install latest nightly.")
  81. if rank == 0:
  82. model = LlamaForCausalLM.from_pretrained(
  83. train_config.model_name,
  84. load_in_8bit=True if train_config.quantization else None,
  85. device_map="auto" if train_config.quantization else None,
  86. use_cache=use_cache,
  87. )
  88. else:
  89. llama_config = LlamaConfig.from_pretrained(train_config.model_name)
  90. llama_config.use_cache = use_cache
  91. with torch.device("meta"):
  92. model = LlamaForCausalLM(llama_config)
  93. else:
  94. model = LlamaForCausalLM.from_pretrained(
  95. train_config.model_name,
  96. load_in_8bit=True if train_config.quantization else None,
  97. device_map="auto" if train_config.quantization else None,
  98. use_cache=use_cache,
  99. )
  100. if train_config.enable_fsdp and train_config.use_fast_kernels:
  101. """
  102. For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
  103. using of Flash Attention or Xformer memory-efficient kernels
  104. based on the hardware being used. This would speed up fine-tuning.
  105. """
  106. try:
  107. from optimum.bettertransformer import BetterTransformer
  108. model = BetterTransformer.transform(model)
  109. except ImportError:
  110. print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
  111. # Load the tokenizer and add special tokens
  112. tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
  113. tokenizer.pad_token_id = tokenizer.eos_token_id
  114. print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
  115. # Prepare the model for int8 training if quantization is enabled
  116. if train_config.quantization:
  117. model = prepare_model_for_int8_training(model)
  118. # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
  119. if train_config.enable_fsdp and fsdp_config.pure_bf16:
  120. model.to(torch.bfloat16)
  121. if train_config.use_peft:
  122. peft_config = generate_peft_config(train_config, kwargs)
  123. model = get_peft_model(model, peft_config)
  124. model.print_trainable_parameters()
  125. #setting up FSDP if enable_fsdp is enabled
  126. if train_config.enable_fsdp:
  127. if not train_config.use_peft and train_config.freeze_layers:
  128. freeze_transformer_layers(train_config.num_freeze_layers)
  129. mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
  130. my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
  131. model = FSDP(
  132. model,
  133. auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
  134. cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
  135. mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
  136. sharding_strategy=fsdp_config.sharding_strategy,
  137. device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
  138. limit_all_gathers=True,
  139. sync_module_states=train_config.low_cpu_fsdp,
  140. param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
  141. if train_config.low_cpu_fsdp and rank != 0 else None,
  142. )
  143. if fsdp_config.fsdp_activation_checkpointing:
  144. apply_fsdp_checkpointing(model)
  145. elif not train_config.quantization and not train_config.enable_fsdp:
  146. if is_xpu_available():
  147. model.to("xpu:0")
  148. else:
  149. model.to("cuda")
  150. dataset_config = generate_dataset_config(train_config, kwargs)
  151. # Load and preprocess the dataset for training and validation
  152. dataset_train = get_preprocessed_dataset(
  153. tokenizer,
  154. dataset_config,
  155. split="train",
  156. )
  157. if not train_config.enable_fsdp or rank == 0:
  158. print(f"--> Training Set Length = {len(dataset_train)}")
  159. dataset_val = get_preprocessed_dataset(
  160. tokenizer,
  161. dataset_config,
  162. split="test",
  163. )
  164. if not train_config.enable_fsdp or rank == 0:
  165. print(f"--> Validation Set Length = {len(dataset_val)}")
  166. if train_config.batching_strategy == "packing":
  167. dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
  168. train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
  169. # Create DataLoaders for the training and validation dataset
  170. train_dataloader = torch.utils.data.DataLoader(
  171. dataset_train,
  172. num_workers=train_config.num_workers_dataloader,
  173. pin_memory=True,
  174. **train_dl_kwargs,
  175. )
  176. eval_dataloader = None
  177. if train_config.run_validation:
  178. if train_config.batching_strategy == "packing":
  179. dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
  180. val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
  181. eval_dataloader = torch.utils.data.DataLoader(
  182. dataset_val,
  183. num_workers=train_config.num_workers_dataloader,
  184. pin_memory=True,
  185. **val_dl_kwargs,
  186. )
  187. # Initialize the optimizer and learning rate scheduler
  188. if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
  189. optimizer = AnyPrecisionAdamW(
  190. model.parameters(),
  191. lr=train_config.lr,
  192. momentum_dtype=torch.bfloat16,
  193. variance_dtype=torch.bfloat16,
  194. use_kahan_summation=False,
  195. weight_decay=train_config.weight_decay,
  196. )
  197. else:
  198. optimizer = optim.AdamW(
  199. model.parameters(),
  200. lr=train_config.lr,
  201. weight_decay=train_config.weight_decay,
  202. )
  203. scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
  204. # Start the training process
  205. results = train(
  206. model,
  207. train_dataloader,
  208. eval_dataloader,
  209. tokenizer,
  210. optimizer,
  211. scheduler,
  212. train_config.gradient_accumulation_steps,
  213. train_config,
  214. fsdp_config if train_config.enable_fsdp else None,
  215. local_rank if train_config.enable_fsdp else None,
  216. rank if train_config.enable_fsdp else None,
  217. )
  218. if not train_config.enable_fsdp or rank==0:
  219. [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
  220. if __name__ == "__main__":
  221. fire.Fire(main)