llama_finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import sys
from typing import List, Union

import fire
import torch
import transformers
from datasets import load_dataset
import os.path as osp
from tqdm import tqdm

# Unused imports removed
from utils import fsdp_auto_wrap_policy
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    default_data_collator,
    BitsAndBytesConfig,
)
import torch.distributed as dist

# Unused imports removed
from utils.train_utils import (
    set_tokenizer_params,
    train,
    evaluation,
    freeze_transformer_layers,
    check_frozen_layers_peft_model,
    setup,
    setup_environ_flags,
    cleanup,
    clear_gpu_cache,
    get_parameter_dtypes,
    print_model_size,
    get_policies,
)
from utils.dataset_utils import get_preprocessed_dataset
from utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
)
from peft import get_peft_model, TaskType, prepare_model_for_int8_training
import configs
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.utils.data import DistributedSampler
import policies
from policies import AnyPrecisionAdamW
from configs import fsdp_config, train_config
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pkg_resources import packaging
import torch.cuda.nccl as nccl
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def main(**kwargs):
    # Update the configuration for the training and sharding process
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        # Pin this process to its local GPU; the global rank can exceed the
        # per-node device count on multi-node runs.
        torch.cuda.set_device(local_rank)
        setup_environ_flags(rank)

    # Calculate gradient accumulation steps
    gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
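    # Illustrative arithmetic (hypothetical values, not repo defaults): with
    # batch_size_training=8 and micro_batch_size=2, gradients are accumulated over
    # 8 // 2 = 4 micro-batches before each optimizer step.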
    # Load the pre-trained model and setup its configuration
    model = LlamaForCausalLM.from_pretrained(
        train_config.model_name,
        load_in_8bit=True if train_config.quantization else None,
        device_map="auto" if train_config.quantization else None,
    )

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_int8_training(model)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16:
        model.to(torch.bfloat16)

    # Load the tokenizer and add special tokens
    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
    tokenizer.add_special_tokens(
        {
            "pad_token": "<PAD>",
        }
    )

    if train_config.use_peft:
        peft_config = generate_peft_config(train_config, kwargs)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
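    # With a PEFT method such as LoRA, only the adapter parameters remain trainable;
    # print_trainable_parameters() logs the trainable vs. total parameter counts.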
    # Setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

        if fsdp_config.optimizer_overlap:
            try:
                from torch.distributed.optim import _apply_optimizer_in_backward
            except ImportError as e:
                # Optimizer overlap depends on this API; fail fast here instead of
                # hitting a NameError when it is called further below.
                raise ImportError(
                    "fsdp_config.optimizer_overlap is set, but "
                    "torch.distributed.optim._apply_optimizer_in_backward is not "
                    "available in this PyTorch build."
                ) from e
            model = FSDP(
                model,
                auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
                mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
                sharding_strategy=fsdp_config.sharding_strategy,
                device_id=torch.cuda.current_device(),
                limit_all_gathers=True,
                use_orig_params=True,
            )
        else:
            model = FSDP(
                model,
                auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
                mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
                sharding_strategy=fsdp_config.sharding_strategy,
                device_id=torch.cuda.current_device(),
                limit_all_gathers=True,
            )
        if fsdp_config.fsdp_activation_checkpointing:
            policies.apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        model.to("cuda")
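    # The auto-wrap policy controls which submodules become their own FSDP units. When
    # PEFT is used, the custom policy from fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
    # is applied (presumably so frozen base weights and trainable adapter weights land in
    # separate FSDP units); otherwise the wrapping policy returned by get_policies() is used.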
    dataset_config = generate_dataset_config(train_config, kwargs)

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    train_sampler = None
    val_sampler = None
    if train_config.enable_fsdp:
        train_sampler = DistributedSampler(
            dataset_train,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
            shuffle=True,
        )
        if train_config.run_validation:
            val_sampler = DistributedSampler(
                dataset_val,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
            )

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=train_config.batch_size_training,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        sampler=train_sampler if train_sampler else None,
        drop_last=True,
        collate_fn=default_data_collator,
    )
    if train_config.run_validation:
        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=train_config.val_batch_size,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            sampler=val_sampler if val_sampler else None,
            drop_last=True,
            collate_fn=default_data_collator,
        )
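    # When FSDP is enabled, DistributedSampler gives each rank a distinct shard of the
    # data; drop_last=True drops any final partial batch so every step sees a full batch
    # of batch_size_training (or val_batch_size) samples.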
    # Initialize the optimizer and learning rate scheduler
    if not fsdp_config.optimizer_overlap:
        if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
            optimizer = AnyPrecisionAdamW(
                model.parameters(),
                lr=train_config.lr,
                momentum_dtype=torch.bfloat16,
                variance_dtype=torch.bfloat16,
                use_kahan_summation=False,
            )
        else:
            optimizer = optim.AdamW(
                model.parameters(),
                lr=train_config.lr,
                weight_decay=0.0,
            )

    if fsdp_config.optimizer_overlap:
        print("setting up optimizer overlap")
        print("===============================")
        optim_kwargs = {"lr": train_config.lr}
        _apply_optimizer_in_backward(
            optimizer_class=optim.AdamW,
            params=model.parameters(),
            optimizer_kwargs=optim_kwargs,
            register_hook=False,
        )
        for p in model.parameters():
            assert hasattr(p, "_in_backward_optimizers")
        optim_kwargs = {"lr": train_config.lr, "weight_decay": 0.0}
        optimizer = optim.AdamW(
            model.parameters(),
            **optim_kwargs,
        )

    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
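    # StepLR with step_size=1 multiplies the learning rate by gamma on every
    # scheduler.step() call (presumably once per epoch inside train()).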
    # Start the training process
    results = train(
        model,
        train_dataloader,
        # eval_dataloader is only created when run_validation is enabled
        eval_dataloader if train_config.run_validation else None,
        tokenizer,
        optimizer,
        scheduler,
        gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")


if __name__ == "__main__":
    fire.Fire(main)
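# Example invocations (illustrative only; flag names mirror the train_config fields
# used above, and the model path is a placeholder):
#
#   Single GPU, PEFT + int8 quantization:
#     python llama_finetuning.py --use_peft --quantization --model_name <path/to/llama-model>
#
#   Multi-GPU with FSDP via torchrun:
#     torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_peft \
#         --model_name <path/to/llama-model>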