finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from collections import Counter
import os

import dataclasses
import fire
import random
import torch
import torch.optim as optim
import numpy as np
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoConfig,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoProcessor,
    LlamaForCausalLM,
    MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaSelfAttentionDecoderLayer,
    MllamaCrossAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)

from llama_recipes.configs import fsdp_config as FSDP_CONFIG
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
    get_dataloader_kwargs,
    check_fsdp_config,
)
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset, get_custom_data_collator
from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
from llama_recipes.utils.train_utils import (
    train,
    freeze_transformer_layers,
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    print_model_size,
    get_policies,
)
from accelerate.utils import is_xpu_available
from warnings import warn


def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_recipes.configs import wandb_config as WANDB_CONFIG
    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run


def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None
    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if type(train_config.quantization) == type(True):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. "
                "Defaulting to '8bit' but this might change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization")

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)

    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None
    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = 'right'
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
    else:
        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16 and not train_config.quantization:
        model.to(torch.bfloat16)

    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()
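
    # HYBRID_SHARD shards parameters FSDP-style within each sharding group and
    # replicates them DDP-style across replica groups, keeping all-gathers inside
    # the (typically node-sized) sharding group.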
    hsdp_device_mesh_plan = None
    if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer
        # and MllamaVisionEncoderLayer in vision models
        if is_vision:
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()
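
        # Wrap the model with FSDP. With low_cpu_fsdp, non-zero ranks allocate
        # empty parameters (param_init_fn) and sync_module_states broadcasts the
        # weights from rank 0 during wrapping, reducing host memory pressure.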
        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
            if train_config.low_cpu_fsdp and rank != 0
            else None,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)
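
    # Vision (mllama) datasets are preprocessed with the multimodal processor,
    # text-only (llama) datasets with the tokenizer.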
    if is_vision:
        dataset_processer = processor
    else:
        dataset_processer = tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
    print("length of dataset_train", len(dataset_train))

    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            else:
                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
            raise ValueError(
                "The eval set size is too small for dataloader to load even one batch. "
                f"Please increase the size of eval set. ({len(eval_dataloader)=})"
            )
        else:
            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")

    # Initialize the optimizer and learning rate scheduler
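    # With pure_bf16, AnyPrecisionAdamW keeps its momentum and variance states in
    # bfloat16, which roughly halves optimizer memory compared to fp32 AdamW.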
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
    if not train_config.enable_fsdp or rank == 0:
        [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)
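
# Example invocations (illustrative; model names and paths should be adjusted to
# your setup, and the flags map onto the train_config/fsdp_config fields consumed
# by update_config above):
#
#   Single-GPU PEFT fine-tuning:
#     python finetuning.py --use_peft --peft_method lora --quantization 8bit \
#         --model_name meta-llama/Llama-3.1-8B --output_dir ./peft_checkpoint
#
#   Multi-GPU fine-tuning with FSDP (LOCAL_RANK/RANK/WORLD_SIZE are set by torchrun):
#     torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp \
#         --use_peft --peft_method lora --model_name meta-llama/Llama-3.1-8B \
#         --output_dir ./peft_checkpoint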