finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from collections import Counter
import os
import dataclasses
import fire
import random
import torch
import torch.optim as optim
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    LlamaConfig,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from llama_recipes.configs import fsdp_config as FSDP_CONFIG
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
    get_dataloader_kwargs,
)
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
from llama_recipes.utils.train_utils import (
    train,
    freeze_transformer_layers,
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    print_model_size,
    get_policies,
)
from accelerate.utils import is_xpu_available
from warnings import warn

def setup_wandb(train_config, fsdp_config, **kwargs):
    """Initialize a Weights & Biases run and log the training and FSDP configs."""
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_recipes.configs import wandb_config as WANDB_CONFIG
    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run
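

# main() ties everything together: it reads config overrides from the CLI via
# fire, sets up the (optional) distributed/FSDP environment, loads and
# optionally quantizes the base Llama model, attaches a PEFT adapter if
# requested, builds the train/validation dataloaders, and finally calls
# llama_recipes.utils.train_utils.train.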
def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None
    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
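
    # Note: 4-bit bitsandbytes quantization can be combined with FSDP, while
    # 8-bit quantization cannot; the check below enforces this, and a bare
    # boolean --quantization flag falls back to "8bit" with a FutureWarning.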
    # Set up the quantization config
    bnb_config = None
    if train_config.quantization:
        if isinstance(train_config.quantization, bool):
            warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization")

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)
    # Load the pre-trained model and set up its configuration
    use_cache = False if train_config.enable_fsdp else None
    model = LlamaForCausalLM.from_pretrained(
        train_config.model_name,
        quantization_config=bnb_config,
        use_cache=use_cache,
        attn_implementation="sdpa" if train_config.use_fast_kernels else None,
        # device_map="auto" only applies to single-process quantized runs; FSDP handles placement itself
        device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
        torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
    )
    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between the tokenizer vocab size and the embedding matrix,
    # print a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if FSDP and pure_bf16 are enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16 and not train_config.quantization:
        model.to(torch.bfloat16)
    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and set up its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
            # peft_config is exposed as an attribute on PeftModel, not a method
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from the original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()
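
    # Hybrid sharding (HSDP) shards model parameters within each sharding group
    # and replicates them across replica groups; the 2-D device mesh below is
    # only built when the sharding strategy is HYBRID_SHARD.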
    hsdp_device_mesh_plan = None
    if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
        hsdp_device_mesh_plan = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
        print("HSDP device mesh is ready")
    # Set up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        # PEFT models get a custom auto-wrap policy built around LlamaDecoderLayer (selected below)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()

        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
            if train_config.low_cpu_fsdp and rank != 0
            else None,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")
    dataset_config = generate_dataset_config(train_config, kwargs)

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")
    if train_config.batching_strategy == "packing":
        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")

    # Create DataLoaders for the training and validation datasets
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        if len(eval_dataloader) == 0:
            raise ValueError("The eval set size is too small for the dataloader to load even one batch. Please increase the size of the eval set.")
        else:
            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
    # Start the training process
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)
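

# Example invocations (for illustration only; the available flags come from the
# llama_recipes train_config/fsdp_config dataclasses in your installation):
#
#   # Single-GPU PEFT fine-tuning with 8-bit quantization:
#   python finetuning.py --use_peft --quantization 8bit --model_name <path/to/model>
#
#   # Multi-GPU FSDP + PEFT fine-tuning:
#   torchrun --nnodes 1 --nproc_per_node 4 finetuning.py \
#       --enable_fsdp --use_peft --model_name <path/to/model> --use_fast_kernels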