@@ -4,6 +4,7 @@
 import os
 import sys
 from typing import List
+import yaml

 import fire
 import torch
@@ -67,7 +68,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
-
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"])
     train_prep = []
     train_loss = []
     val_prep = []
@@ -80,7 +82,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            data_set_len = 0
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -90,8 +91,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                data_set_len += len(batch[first_key])
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
@@ -122,12 +121,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche

         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
-
-        print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-        print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-        print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        if train_config.enable_fsdp:
+            if rank==0:
+                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")

         # Update the learning rate as needed
         lr_scheduler.step()
@@ -135,35 +141,53 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
             if train_config.save_model and eval_epoch_loss < best_val_loss:
-
-                if train_config.use_peft:
-
-                    print(f"we are in the saving the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)
-                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
-
+                if train_config.enable_fsdp:
+                    dist.barrier()
+                if train_config.use_peft:
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            print(f"we are about to save the PEFT modules")
+                    else:
+                        print(f"we are about to save the PEFT modules")
+                    model.save_pretrained(train_config.output_dir)
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                    else:
+                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
+
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:

                         model_checkpointing.save_model_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-                        print(" we are about to save the models *******")
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")

                         model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                            print("=====================================================")

                     if not train_config.use_peft and train_config.save_optimizer:
                         model_checkpointing.save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
-                        )
-
+                            model, optimizer, rank, train_config, epoch=epoch
+                        )
+                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
+                        print("=====================================================")
+                if train_config.enable_fsdp:
+                    dist.barrier()

-            if local_rank == 0 and eval_epoch_loss < best_val_loss:
+            if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
-                print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                if train_config.enable_fsdp:
+                    if rank==0:
+                        print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                else:
+                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)

@@ -171,7 +195,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
         lr_scheduler.step()
     avg_epoch_time = sum(epoch_times)/len(epoch_times)
-    print("avg epoch time is {avg_epoch_time}")
+    print(f"avg epoch time is {avg_epoch_time}")
     print("==========================================")
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
@@ -185,7 +209,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     results['avg_eval_prep'] = avg_eval_prep
     results['avg_eval_loss'] = avg_eval_loss

-
+    # Save the training params, including the FSDP settings, for reference.
+    if train_config.enable_fsdp and not train_config.use_peft:
+        save_train_params(train_config, fsdp_config, rank)
+
     return results

 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
@@ -200,10 +227,11 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):

     Returns: eval_ppl, eval_epoch_loss
     """
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
     eval_loss = 0.0 # Initialize evaluation loss
-    eval_dataset_len = 0
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
             for key in batch.keys():
@@ -217,9 +245,6 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 outputs = model(**batch)
                 loss = outputs.loss
                 eval_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                eval_dataset_len+= len(batch[first_key])
-
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
@@ -233,11 +258,17 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):

     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
-    eval_epoch_loss = eval_epoch_loss/world_size
+    if train_config.enable_fsdp:
+        eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)

     # Print evaluation metrics
-    print(f" {eval_ppl=} {eval_epoch_loss=}")
+    if train_config.enable_fsdp:
+        if local_rank==0:
+            print(f" {eval_ppl=} {eval_epoch_loss=}")
+    else:
+        print(f" {eval_ppl=} {eval_epoch_loss=}")
+
     return eval_ppl, eval_epoch_loss

 def freeze_transformer_layers(model, num_layer):
@@ -262,7 +293,10 @@ def setup_environ_flags(rank):
     """Set environment flags for debugging purposes"""
     os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
-    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # This flag helps with CUDA memory fragmentation that can lead to OOM in some cases.
+    # Note this is only available in PyTorch Nightlies (as of July 30, 2023).
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")

@@ -336,3 +370,42 @@ def get_policies(cfg, rank):
             print(f"bFloat16 support not present. Using FP32, and not mixed precision")
     wrapping_policy = get_llama_wrapper()
     return mixed_precision_policy, wrapping_policy
+
+def save_train_params(train_config, fsdp_config, rank):
+    """
+    This function saves the train_config and FSDP config into a train_params.yaml.
+    This will be used by the converter script in the inference folder to fetch the HF model name or path.
+    It would also be helpful as a log for future reference.
+    """
+    # Convert the train_config and fsdp_config objects to dictionaries,
+    # converting all values to strings to ensure they can be serialized into a YAML file
+    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
+    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    # Merge the two dictionaries into one
+    train_params_dict = {**train_config_dict, **fsdp_config_dict}
+    # Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
+    folder_name = (
+        train_config.dist_checkpoint_root_folder
+        + "/"
+        + train_config.dist_checkpoint_folder
+        + "-"
+        + train_config.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    # If the directory does not exist, create it
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    # Convert the dictionary to a YAML string
+    config_yaml = yaml.dump(train_params_dict, indent=4)
+    file_name = os.path.join(save_dir,'train_params.yaml')
+
+    # Check if there's a directory with the same name as the file
+    if os.path.isdir(file_name):
+        print(f"Error: {file_name} is a directory, not a file.")
+    else:
+        # Write the YAML string to the file
+        with open(file_name, 'w') as f:
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")
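
Note on the new save_train_params helper above: the converter script referenced in its docstring only needs to read train_params.yaml back and look up the HF model name or path. The snippet below is a minimal sketch of that read-back, assuming the folder layout produced by save_train_params; the load_train_params name and the example arguments are illustrative, not part of this patch.

import os
import yaml

def load_train_params(dist_checkpoint_root_folder, dist_checkpoint_folder, model_name):
    # Mirror the folder naming used on the save side:
    # <dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>/train_params.yaml
    folder_name = (
        dist_checkpoint_root_folder
        + "/"
        + dist_checkpoint_folder
        + "-"
        + model_name
    )
    file_name = os.path.join(folder_name, "train_params.yaml")
    with open(file_name, "r") as f:
        # All values were serialized as strings, so callers may need to cast them back.
        return yaml.safe_load(f)

# Example usage (arguments are placeholders):
# params = load_train_params("model_checkpoints", "fine-tuned", "meta-llama/Llama-2-7b-hf")
# print(params["model_name"])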