@@ -3,7 +3,7 @@

 import os
 import sys
-from typing import List
+from typing import List, Optional
 import yaml

 import fire
@@ -36,6 +36,9 @@ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
+import torch.autograd.profiler as profiler
+from torch.cuda._memory_viz import profile_plot
+from pickle import dump
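+# Helpers for the CUDA memory-history recording and snapshot export added to train() below.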

 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -64,6 +67,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
+    torch.cuda.memory._record_memory_history(True,
+        # keep 100,000 alloc/free events from before the snapshot
+        trace_alloc_max_entries=100000,
+
+        # record stack information for the trace events
+        trace_alloc_record_context=True)
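+    # Recording is switched on before the training loop, so the snapshot dumped a few steps
+    # into the first epoch (below) carries the full alloc/free trace with stack context.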
+
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
@@ -82,7 +92,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace: # track the memory usage
             model.train()
             total_loss = 0.0
+            # if fsdp_config.profile_mem:
+            #     with torch.profiler.profile(
+            #         schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
+            #         activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA],
+            #         on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/llama2-7b'),
+            #         record_shapes=True,
+            #         profile_memory=True,
+            #         with_stack=True,
+            #     ) as prof:
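+            # Commented-out kineto profiler wrapper kept for reference: when enabled (with
+            # record_shapes/profile_memory/with_stack), the resulting `prof` object is what the
+            # export_memory_timeline()/_memory_profile() helpers at the bottom of this file expect.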
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
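+                # NOTE: the epoch is cut short; a handful of steps is enough to exercise the
+                # allocator and capture the memory snapshot taken at step 4 below.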
+                if step > 10:
+                    break
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
@@ -104,8 +125,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
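+                # Dump a CUDA memory snapshot on rank 0 a few steps into training; the resulting
+                # pickle can be loaded into PyTorch's memory visualizer (e.g. torch.cuda._memory_viz
+                # or https://pytorch.org/memory_viz).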
+                if step == 4:
+                    if rank==0:
+                        snapshot = torch.cuda.memory._snapshot()
+                        with open('snapshot.pickle', 'wb') as f:
+                            dump(snapshot, f)

                 print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+
         end_epoch = time.perf_counter()
         epoch_time = end_epoch- start_epoch
         print(f"epoch time is {epoch_time}")
@@ -235,6 +262,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_loss = 0.0 # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
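+            # NOTE: evaluation is likewise cut short to keep the profiling run quick.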
+            if step > 6:
+                break
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
@@ -409,4 +438,27 @@ def save_train_params(train_config, fsdp_config, rank):
         with open(file_name, 'w') as f:
             f.write(config_yaml)
         if rank==0:
-            print(f"training params are saved in {file_name}")
+            print(f"training params are saved in {file_name}")
+
+
+def export_memory_timeline(prof, path: str, device: Optional[str] = None) -> None:
+    # `prof` is a finished torch.profiler.profile run with record_shapes/profile_memory/with_stack enabled.
+    try:
+        from torch.profiler._memory_profiler import MemoryProfileTimeline
+    except ImportError:
+        print("The required module 'MemoryProfileTimeline' is not available.")
+        return
+    if device is None:
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # Write the per-category memory timeline for `device` to `path`.
+    MemoryProfileTimeline(_memory_profile(prof)).export_memory_timeline(path, device)
+
+
+def _memory_profile(prof):
+    # Build a MemoryProfile from the kineto results of a finished torch.profiler.profile run.
+    try:
+        from torch.profiler._memory_profiler import MemoryProfile
+    except ImportError:
+        print("The required module 'MemoryProfile' is not available.")
+        return None
+    required = ("record_shapes", "profile_memory", "with_stack")
+    missing = [f"{i}=True" for i in required if not getattr(prof, i)]
+    if missing:
+        raise ValueError(f"{', '.join(missing)} required for memory profiling.")
+
+    assert prof.profiler is not None and prof.profiler.kineto_results is not None
+    return MemoryProfile(prof.profiler.kineto_results)
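+
+# Example (sketch): with the commented-out profiler block in train() enabled and completed, a call like
+#     export_memory_timeline(prof, "memory_timeline.json", device="cuda:0")
+# would write the per-device memory timeline, complementing the snapshot.pickle dumped during training.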
|