@@ -10,6 +10,7 @@ import torch
import transformers
from datasets import load_dataset
from tqdm import tqdm
+import time
"""
Unused imports:
import torch.nn as nn
@@ -73,7 +74,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
val_loss =[]
results = {}
best_val_loss = float("inf")
+ epoch_times = []  # wall-clock duration of each epoch
for epoch in range(train_config.num_epochs):
+ start_epoch = time.perf_counter()
with MemoryTrace() as memtrace: # track the memory usage
model.train()
total_loss = 0.0
@@ -104,10 +107,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
optimizer.zero_grad()

print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+ end_epoch = time.perf_counter()
+ epoch_time = end_epoch - start_epoch  # wall-clock time for this epoch
+ print(f"epoch time is {epoch_time}")
+ print("==================================================")
+ epoch_times.append(epoch_time)
# Reducing total_loss across all devices if there's more than one CUDA device
if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
- train_epoch_loss = total_loss / data_set_len
+ world_size = int(os.environ.get("WORLD_SIZE", 1))  # set by torchrun; default to 1 for single-process runs
+ train_epoch_loss = total_loss / len(train_dataloader)
+ train_epoch_loss = train_epoch_loss / world_size  # all_reduce summed the loss across ranks, so take the mean
train_perplexity = torch.exp(train_epoch_loss)

train_prep.append(train_perplexity)
@@ -160,7 +170,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
lr_scheduler.step()
-
+ avg_epoch_time = sum(epoch_times) / len(epoch_times)
+ print(f"avg epoch time is {avg_epoch_time}")
+ print("==========================================")
avg_train_prep = sum(train_prep)/len(train_prep)
avg_train_loss = sum(train_loss)/len(train_loss)
if train_config.run_validation:
@@ -217,9 +229,11 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
# If there's more than one CUDA device, reduce evaluation loss across all devices
if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
+ world_size = int(os.environ.get("WORLD_SIZE", 1))  # set by torchrun; default to 1 for single-process runs

# Compute average loss and perplexity
- eval_epoch_loss = eval_loss / eval_dataset_len
+ eval_epoch_loss = eval_loss / len(eval_dataloader)
+ eval_epoch_loss = eval_epoch_loss / world_size  # all_reduce summed the loss across ranks, so take the mean
eval_ppl = torch.exp(eval_epoch_loss)

# Print evaluation metrics
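For context, a minimal sketch of the averaging this patch introduces: `dist.all_reduce` with `ReduceOp.SUM` leaves every rank holding the sum of all ranks' losses, so the per-rank mean loss is recovered by dividing first by the number of batches and then by the world size before taking `torch.exp` for perplexity. The helper and names below (`mean_epoch_loss`, `loss_sum_across_ranks`, `num_batches`) are illustrative stand-ins, not code from this patch.

```python
import os
import torch

def mean_epoch_loss(loss_sum_across_ranks, num_batches, world_size=None):
    """Turn a loss summed across ranks (all_reduce with SUM) into the
    per-batch, per-rank average that feeds torch.exp for perplexity."""
    if world_size is None:
        # WORLD_SIZE is set by torchrun; fall back to 1 for single-process runs.
        world_size = int(os.environ.get("WORLD_SIZE", 1))
    per_rank = loss_sum_across_ranks / num_batches  # average over batches
    return per_rank / world_size                    # average over ranks

# Example: 2 ranks, each accumulating a total loss of 6.0 over 3 batches.
# After all_reduce(SUM) every rank holds 12.0; the mean per-batch loss is 2.0.
loss = mean_epoch_loss(torch.tensor(12.0), num_batches=3, world_size=2)
print(loss.item(), torch.exp(loss).item())  # 2.0  ~7.389
```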