|
@@ -220,70 +220,74 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
|
|
|
# Update the learning rate as needed
|
|
|
lr_scheduler.step()
|
|
|
+ should_save_model = train_config.save_model
|
|
|
if train_config.run_validation:
|
|
|
eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
|
|
|
if train_config.save_metrics:
|
|
|
val_step_loss.extend(temp_val_loss)
|
|
|
val_step_perplexity.extend(temp_step_perplexity)
|
|
|
-
|
|
|
- checkpoint_start_time = time.perf_counter()
|
|
|
- if train_config.save_model and eval_epoch_loss < best_val_loss:
|
|
|
+ should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
|
|
|
+
|
|
|
+ checkpoint_start_time = time.perf_counter()
|
|
|
+ if should_save_model:
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ dist.barrier()
|
|
|
+ if train_config.use_peft:
|
|
|
if train_config.enable_fsdp:
|
|
|
- dist.barrier()
|
|
|
- if train_config.use_peft:
|
|
|
- if train_config.enable_fsdp:
|
|
|
- if rank==0:
|
|
|
- print(f"we are about to save the PEFT modules")
|
|
|
- else:
|
|
|
+ if rank==0:
|
|
|
print(f"we are about to save the PEFT modules")
|
|
|
- save_peft_checkpoint(model, train_config.output_dir)
|
|
|
- if train_config.enable_fsdp:
|
|
|
- if rank==0:
|
|
|
- print(f"PEFT modules are saved in {train_config.output_dir} directory")
|
|
|
- else:
|
|
|
+ else:
|
|
|
+ print(f"we are about to save the PEFT modules")
|
|
|
+ save_peft_checkpoint(model, train_config.output_dir)
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ if rank==0:
|
|
|
print(f"PEFT modules are saved in {train_config.output_dir} directory")
|
|
|
-
|
|
|
else:
|
|
|
- if not train_config.enable_fsdp:
|
|
|
- save_model_checkpoint(model, train_config.output_dir)
|
|
|
-
|
|
|
- elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
|
|
|
- print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
|
|
|
+ print(f"PEFT modules are saved in {train_config.output_dir} directory")
|
|
|
+
|
|
|
+ else:
|
|
|
+ if not train_config.enable_fsdp:
|
|
|
+ save_model_checkpoint(model, train_config.output_dir)
|
|
|
+
|
|
|
+ elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
|
|
|
+ print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
|
|
|
+ print("=====================================================")
|
|
|
+ save_fsdp_model_checkpoint_full(
|
|
|
+ model, optimizer, rank, train_config, epoch=epoch
|
|
|
+ )
|
|
|
+
|
|
|
+ if train_config.save_optimizer:
|
|
|
+ print(" Saving the FSDP optimizer using FULL_STATE_DICT")
|
|
|
print("=====================================================")
|
|
|
- save_fsdp_model_checkpoint_full(
|
|
|
+ save_optimizer_checkpoint(
|
|
|
model, optimizer, rank, train_config, epoch=epoch
|
|
|
)
|
|
|
-
|
|
|
- if train_config.save_optimizer:
|
|
|
- print(" Saving the FSDP optimizer using FULL_STATE_DICT")
|
|
|
- print("=====================================================")
|
|
|
- save_optimizer_checkpoint(
|
|
|
- model, optimizer, rank, train_config, epoch=epoch
|
|
|
- )
|
|
|
-
|
|
|
- elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
|
|
|
-
|
|
|
- if train_config.save_optimizer:
|
|
|
- print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
|
|
|
- print("=====================================================")
|
|
|
- save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
|
|
|
- else:
|
|
|
- print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
|
|
|
- print("=====================================================")
|
|
|
- save_model_and_optimizer_sharded(model, rank, train_config)
|
|
|
+
|
|
|
+ elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
|
|
|
|
|
|
-
|
|
|
- if train_config.enable_fsdp:
|
|
|
- dist.barrier()
|
|
|
- checkpoint_end_time = time.perf_counter() - checkpoint_start_time
|
|
|
- checkpoint_times.append(checkpoint_end_time)
|
|
|
+ if train_config.save_optimizer:
|
|
|
+ print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
|
|
|
+ print("=====================================================")
|
|
|
+ save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
|
|
|
+ else:
|
|
|
+ print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
|
|
|
+ print("=====================================================")
|
|
|
+ save_model_and_optimizer_sharded(model, rank, train_config)
|
|
|
+
|
|
|
+
|
|
|
+ if train_config.enable_fsdp:
|
|
|
+ dist.barrier()
|
|
|
+ checkpoint_end_time = time.perf_counter() - checkpoint_start_time
|
|
|
+ checkpoint_times.append(checkpoint_end_time)
|
|
|
+
|
|
|
+ if train_config.run_validation:
|
|
|
if eval_epoch_loss < best_val_loss:
|
|
|
best_val_loss = eval_epoch_loss
|
|
|
if train_config.enable_fsdp:
|
|
|
if rank==0:
|
|
|
print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
|
|
|
else:
|
|
|
- print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
|
|
|
+ print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
|
|
|
val_loss.append(float(best_val_loss))
|
|
|
val_prep.append(float(eval_ppl))
|
|
|
if train_config.enable_fsdp:
|