浏览代码

first draft of flop counter feature

Kai Wu 1 年之前
父节点
当前提交
2a5de9b448
共有 2 个文件被更改,包括 579 次插入65 次删除
  1. 471 0
      src/llama_recipes/utils/tflop_counter.py
  2. 108 65
      src/llama_recipes/utils/train_utils.py

+ 471 - 0
src/llama_recipes/utils/tflop_counter.py

@@ -0,0 +1,471 @@
+# Temp copy of Horace Flops Counter.
+# This supports distributed to avoid printing * every GPU.
+# Remove after main file is updated.
+
+import torch
+from torch.utils._pytree import tree_map
+from typing import List, Any, Dict, Optional, Union
+from collections import defaultdict
+from torch.utils._python_dispatch import TorchDispatchMode
+from math import prod
+
+__all__ = ["FlopCounterMode"]
+
+aten = torch.ops.aten
+
+
+def get_shape(i):
+    if isinstance(i, torch.Tensor):
+        return i.shape
+    return i
+
+
+def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
+    """
+    Count flops for matmul.
+    """
+    # Inputs should be a list of length 2.
+    # Inputs contains the shapes of two matrices.
+    m, k = a_shape
+    k2, n = b_shape
+    assert k == k2
+    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
+    return m * n * 2 * k
+
+
+def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
+    """
+    Count flops for addmm
+    """
+    return mm_flop(a_shape, b_shape)
+
+
+def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
+    """
+    Count flops for the bmm operation.
+    """
+    # Inputs should be a list of length 2.
+    # Inputs contains the shapes of two tensor.
+    b, m, k = a_shape
+    b2, k2, n = b_shape
+    assert b == b2
+    assert k == k2
+    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
+    flop = b * m * n * 2 * k
+    return flop
+
+
+def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
+    """
+    Count flops for the baddbmm operation.
+    """
+    # Inputs should be a list of length 3.
+    # Inputs contains the shapes of three tensors.
+    return bmm_flop(a_shape, b_shape)
+
+
+def conv_flop_count(
+    x_shape: List[int],
+    w_shape: List[int],
+    out_shape: List[int],
+    transposed: bool = False,
+) -> int:
+    """
+    Count flops for convolution. Note only multiplication is
+    counted. Computation for bias are ignored.
+    Flops for a transposed convolution are calculated as
+    flops = (x_shape[2:] * prod(w_shape) * batch_size).
+    Args:
+        x_shape (list(int)): The input shape before convolution.
+        w_shape (list(int)): The filter shape.
+        out_shape (list(int)): The output shape after convolution.
+        transposed (bool): is the convolution transposed
+    Returns:
+        int: the number of flops
+    """
+    batch_size = x_shape[0]
+    conv_shape = (x_shape if transposed else out_shape)[2:]
+    c_out, c_in, *dims = w_shape
+
+    # NB(chilli): I don't think this properly accounts for padding :think:
+    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
+    flop = batch_size * prod(conv_shape) * c_out * prod(dims) * 2 * c_in
+    return flop
+
+
+def conv_flop(
+    x_shape,
+    w_shape,
+    _bias,
+    _stride,
+    _padding,
+    _dilation,
+    transposed,
+    *args,
+    out_shape=None,
+    **kwargs
+) -> int:
+    """
+    Count flops for convolution.
+    """
+    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
+
+
+def transpose_shape(shape):
+    return [shape[1], shape[0]] + list(shape[2:])
+
+
+def conv_backward_flop(
+    grad_out_shape,
+    x_shape,
+    w_shape,
+    _bias,
+    _stride,
+    _padding,
+    _dilation,
+    transposed,
+    _output_padding,
+    _groups,
+    output_mask,
+    out_shape,
+) -> int:
+    flop_count = 0
+
+    if output_mask[0]:
+        grad_input_shape = get_shape(out_shape[0])
+        flop_count += conv_flop_count(
+            grad_out_shape, w_shape, grad_input_shape, not transposed
+        )
+    if output_mask[1]:
+        grad_weight_shape = get_shape(out_shape[1])
+        flop_count += conv_flop_count(
+            transpose_shape(x_shape), grad_out_shape, grad_weight_shape, transposed
+        )
+
+    return flop_count
+
+
+def sdpa_flop_count(query_shape, key_shape, value_shape):
+    """
+    Count flops for self-attention.
+    NB: We can assume that value_shape == key_shape
+    """
+    b, h, s_q, d_q = query_shape
+    _b2, _h2, s_k, _d2 = key_shape
+    _b3, _h3, _s3, d_v = value_shape
+    assert (
+        b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2
+    )
+    total_flops = 0
+    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
+    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
+    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
+    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
+    return total_flops
+
+
+def sdpa_flop(
+    query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
+) -> int:
+    """
+    Count flops for self-attention.
+    """
+    # NB: We aren't accounting for causal attention here
+    return sdpa_flop_count(query_shape, key_shape, value_shape)
+
+
+def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
+    total_flops = 0
+    b, h, s_q, d_q = query_shape
+    _b2, _h2, s_k, _d2 = key_shape
+    _b3, _h3, _s3, d_v = value_shape
+    _b4, _h4, _s4, _d4 = grad_out_shape
+    assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
+    assert d_v == _d4 and s_k == _s3 and s_q == _s4
+    total_flops = 0
+    # Step 1: We recompute the scores matrix.
+    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
+    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
+
+    # Step 2: We propagate the gradients through the score @ v operation.
+    # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
+    total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
+    # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
+    total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
+
+    # Step 3: We propagate th gradients through the k @ v operation
+    # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
+    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
+    # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
+    total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
+    return total_flops
+
+
+def sdpa_backward_flop(
+    grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs
+) -> int:
+    """
+    Count flops for self-attention backward.
+    """
+    return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
+
+
+flop_mapping = {
+    aten.mm: mm_flop,
+    aten.addmm: addmm_flop,
+    aten.bmm: bmm_flop,
+    aten.baddbmm: baddbmm_flop,
+    aten.convolution: conv_flop,
+    aten._convolution: conv_flop,
+    aten.convolution_backward: conv_backward_flop,
+    aten._scaled_dot_product_efficient_attention: sdpa_flop,
+    aten._scaled_dot_product_flash_attention: sdpa_flop,
+    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
+    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
+}
+
+
+def normalize_tuple(x):
+    if not isinstance(x, tuple):
+        return (x,)
+    return x
+
+
+# Define the suffixes for different orders of magnitude
+suffixes = ["", "K", "M", "B", "T"]
+
+
+def get_suffix_str(number):
+    # Find the index of the appropriate suffix based on the number of digits
+    # with some additional overflow.
+    # i.e. 1.01B should be displayed as 1001M, not 1.001B
+    index = max(0, min(len(suffixes) - 1, (len(str(number)) - 3) // 3))
+    return suffixes[index]
+
+
+def convert_num_with_suffix(number, suffix):
+    index = suffixes.index(suffix)
+    # Divide the number by 1000^index and format it to two decimal places
+    value = "{:.3f}".format(number / (1000**index))
+    # Return the value and the suffix as a string
+    return value + suffixes[index]
+
+
+class FlopCounterMode(TorchDispatchMode):
+    """
+    ``FlopCounterMode`` is a context manager that counts the number of
+    flops within its context. It does this using a ``TorchDispatchMode``.
+
+    It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
+
+    Example usage
+
+    .. code-block:: python
+
+        mod = ...
+        flop_counter = FlopCounterMode(mod)
+        with flop_counter:
+            mod.sum().backward()
+
+    """
+
+    def __init__(
+        self,
+        mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
+        depth: int = 2,
+        display: bool = True,
+        custom_mapping: Dict[Any, Any] = None,
+        rank=None,
+    ):
+        self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(
+            lambda: defaultdict(int)
+        )
+        self.depth = depth
+        self.parents = ["Global"]
+        self.display = display
+        self.rank = rank
+        self.ready = False
+        if custom_mapping is None:
+            custom_mapping = {}
+        if isinstance(mods, torch.nn.Module):
+            mods = [mods]
+        self.mods = mods
+        if mods is not None:
+            for mod in mods:
+                prefix = type(mod).__name__
+                for name, module in dict(mod.named_modules()).items():
+                    if name == "":
+                        name = prefix
+                    else:
+                        name = ".".join([prefix, name])
+                    module.register_forward_pre_hook(self._enter_module(name))
+                    module.register_forward_hook(self._exit_module(name))
+        self.flop_mapping = {**flop_mapping, **custom_mapping}
+
+    def _enter_module(self, name):
+        def f(module, inputs):
+            inputs = normalize_tuple(inputs)
+            out = self._create_pre_module(name)(*inputs)
+            return out
+
+        return f
+
+    def _exit_module(self, name):
+        def f(module, inputs, outputs):
+            outputs = normalize_tuple(outputs)
+            return self._create_post_module(name)(*outputs)
+
+        return f
+
+    def _create_post_module(self, name):
+        class PushState(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, *args):
+                assert self.parents[-1] == name
+                self.parents.pop()
+                args = tree_map(
+                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
+                )
+                if len(args) == 1:
+                    return args[0]
+                return args
+
+            @staticmethod
+            def backward(ctx, *grad_outs):
+                self.parents.append(name)
+                return grad_outs
+
+        return PushState.apply
+
+    def _create_pre_module(self, name):
+        class PopState(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, *args):
+                self.parents.append(name)
+                args = tree_map(
+                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
+                )
+                if len(args) == 1:
+                    return args[0]
+                return args
+
+            @staticmethod
+            def backward(ctx, *grad_outs):
+                assert self.parents[-1] == name
+                self.parents.pop()
+                return grad_outs
+
+        return PopState.apply
+
+    def get_total_flops(self) -> int:
+        return sum(self.flop_counts["Global"].values())
+
+    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
+        """Returns the flop counts as a dictionary of dictionaries. The outer
+        dictionary is keyed by module name, and the inner dictionary is keyed by
+        operation name.
+
+        Returns:
+            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
+        """
+        return dict(self.flop_counts)
+
+    def get_table(self, depth=None):
+        if depth is None:
+            depth = self.depth
+        if depth is None:
+            depth = 999999
+
+        import tabulate
+
+        tabulate.PRESERVE_WHITESPACE = True
+        header = ["Module", "FLOP", "% Total"]
+        values = []
+        global_flops = self.get_total_flops()
+        global_suffix = get_suffix_str(global_flops)
+        is_global_subsumed = False
+
+        def process_mod(mod_name, depth):
+            nonlocal is_global_subsumed
+
+            total_flops = sum(self.flop_counts[mod_name].values())
+
+            is_global_subsumed |= total_flops >= global_flops
+
+            padding = " " * depth
+            values = []
+            values.append(
+                [
+                    padding + mod_name,
+                    convert_num_with_suffix(total_flops, global_suffix),
+                    "{:.2f}%".format(total_flops / global_flops * 100),
+                ]
+            )
+            for k, v in self.flop_counts[mod_name].items():
+                values.append(
+                    [
+                        padding + " - " + str(k),
+                        convert_num_with_suffix(v, global_suffix),
+                        "{:.2f}%".format(v / global_flops * 100),
+                    ]
+                )
+            return values
+
+        for mod in self.flop_counts.keys():
+            if mod == "Global":
+                continue
+            mod_depth = mod.count(".") + 1
+            if mod_depth > depth:
+                continue
+
+            cur_values = process_mod(mod, mod_depth - 1)
+            for value in cur_values:
+                values.append(value)
+
+        # We do a bit of messing around here to only output the "Global" value
+        # if there are any FLOPs in there that aren't already fully contained by
+        # a module.
+        if "Global" in self.flop_counts and not is_global_subsumed:
+            for idx, value in enumerate(values):
+                values[idx][0] = " " + values[idx][0]
+
+            values = process_mod("Global", 0) + values
+
+        if len(values) == 0:
+            values = [["Global", "0", "0%"]]
+
+        return tabulate.tabulate(
+            values, headers=header, colalign=("left", "right", "right")
+        )
+
+    def __enter__(self):
+        self.flop_counts.clear()
+        self.ready = False
+        super().__enter__()
+        return self
+
+    def __exit__(self, *args):
+        self.stop_counting()
+        if self.display:
+            if self.rank is None or self.rank == 0:
+                print(self.get_table(self.depth))
+        super().__exit__(*args)
+    def start_counting(self):
+        self.flop_counts.clear()
+        self.ready = True
+    def stop_counting(self):
+        self.ready = False
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if not self.ready:
+            return
+        kwargs = kwargs if kwargs else {}
+        out = func(*args, **kwargs)
+        func_packet = func._overloadpacket
+        if func_packet in self.flop_mapping:
+            flop_count_func = self.flop_mapping[func_packet]
+            args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
+            flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape)  # type: ignore[operator]
+            for par in self.parents:
+                self.flop_counts[par][func_packet] += flop_count
+
+        return out

+ 108 - 65
src/llama_recipes/utils/train_utils.py

@@ -24,11 +24,45 @@ from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
+from llama_recipes.utils.tflop_counter import FlopCounterMode
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
 
+@contextlib.contextmanager
+def throughput_measure_context(cfg, local_rank=None):
+    use_profiler: bool = cfg.use_profiler
+    use_flop_counter: bool = cfg.flop_counter
+    if use_flop_counter and use_profiler:
+        raise ValueError("Cannot use both profiler and flop counter")
+    if use_profiler:
+        print(f"profiling is activated and results will be saved in {cfg.profile_dir}")
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                cfg.profile_dir
+            ),
+            profile_memory=True,
+            with_stack=False,
+            record_shapes=True,
+        ) as torch_profiler:
+            yield torch_profiler
+    elif use_flop_counter:
+        torch_profiler = contextlib.nullcontext()
+        with FlopCounterMode(rank=local_rank) as flop_counter:
+            yeild flop_counter
+    else:
+        torch_profiler = contextlib.nullcontext()
+        yield None
+
+def get_total_flops(model):
+    return (sum([v for _, v in model.flop_counts["Global"].items()]))
+
 # Converting Bytes to Megabytes
 def byte2mb(x):
     return int(x / 2**20)
@@ -92,73 +126,81 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            for step, batch in enumerate(train_dataloader):
-                total_train_steps += 1
-                # stop when the maximum number of training steps is reached
-                if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
-                    max_steps_reached = True
-                    if not train_config.enable_fsdp or local_rank==0:
-                        print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
-                    break
-                for key in batch.keys():
-                    if train_config.enable_fsdp:
-                        if is_xpu_available():
-                            batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+            with throughput_measure_context(train_config,local_rank) as measure_context:
+                for step, batch in enumerate(train_dataloader):
+                    total_train_steps += 1
+                    # stop when the maximum number of training steps is reached
+                    if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                        max_steps_reached = True
+                        if not train_config.enable_fsdp or local_rank==0:
+                            print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
+                        break
+                    if traing_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
+                        print("start flop counting at the step: ", total_train_steps)
+                        measure_context.start()
+                    if traing_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint + 1:
+                        print("stop flop counting at the step: ", total_train_steps)
+                        TFlops = get_total_flops(flop_counter) / 1e12
+                        measure_context.stop()
+                    for key in batch.keys():
+                        if train_config.enable_fsdp:
+                            if is_xpu_available():
+                                batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                            else:
+                                batch[key] = batch[key].to(local_rank)
                         else:
-                            batch[key] = batch[key].to(local_rank)
-                    else:
 
-                        if is_xpu_available():
-                            batch[key] = batch[key].to('xpu:0')
-                        else:
-                            batch[key] = batch[key].to('cuda:0')
-                with autocast():
-                    loss = model(**batch).loss
-                loss = loss / gradient_accumulation_steps
-                if train_config.save_metrics:
-                    train_step_loss.append(loss.detach().float().item())
-                    train_step_perplexity.append(float(torch.exp(loss.detach().float())))
-                total_loss += loss.detach().float()
-                if train_config.use_fp16:
-                    # if fp16 is enabled, use gradient scaler to handle gradient update
-                    scaler.scale(loss).backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
-                            scaler.unscale_(optimizer)
-                            if train_config.enable_fsdp:
-                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                            if is_xpu_available():
+                                batch[key] = batch[key].to('xpu:0')
                             else:
-                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
-                        scaler.step(optimizer)
-                        scaler.update()
-                        optimizer.zero_grad()
-                        pbar.update(1)
-                else:
-                    # regular backpropagation when fp16 is not used
-                    loss.backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
-                            if train_config.enable_fsdp:
-                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
-                            else:
-                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
-                        optimizer.step()
-                        optimizer.zero_grad()
-                        pbar.update(1)
-
-                if wandb_run:
-                    if not train_config.enable_fsdp or rank==0:
-                        wandb_run.log({
-                            'train/epoch': epoch + 1,
-                            'train/step': epoch * len(train_dataloader) + step,
-                            'train/loss': loss.detach().float(),
-                        })
-
-                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
-
-                if train_config.save_metrics:
-                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
-            pbar.close()
+                                batch[key] = batch[key].to('cuda:0')
+                    with autocast():
+                        loss = model(**batch).loss
+                    loss = loss / gradient_accumulation_steps
+                    if train_config.save_metrics:
+                        train_step_loss.append(loss.detach().float().item())
+                        train_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                    total_loss += loss.detach().float()
+                    if train_config.use_fp16:
+                        # if fp16 is enabled, use gradient scaler to handle gradient update
+                        scaler.scale(loss).backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                                scaler.unscale_(optimizer)
+                                if train_config.enable_fsdp:
+                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                else:
+                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                            scaler.step(optimizer)
+                            scaler.update()
+                            optimizer.zero_grad()
+                            pbar.update(1)
+                    else:
+                        # regular backpropagation when fp16 is not used
+                        loss.backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                                if train_config.enable_fsdp:
+                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                else:
+                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                            optimizer.step()
+                            optimizer.zero_grad()
+                            pbar.update(1)
+
+                    if wandb_run:
+                        if not train_config.enable_fsdp or rank==0:
+                            wandb_run.log({
+                                'train/epoch': epoch + 1,
+                                'train/step': epoch * len(train_dataloader) + step,
+                                'train/loss': loss.detach().float(),
+                            })
+
+                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
+                    if train_config.save_metrics:
+                        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+                pbar.close()
 
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
@@ -266,7 +308,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     results["avg_checkpoint_time"] = avg_checkpoint_time
     if train_config.save_metrics:
         results["metrics_filename"] = metrics_filename
-
+    if train_config.flop_counter:
+        results["model_flops"]= TFlops
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft and rank==0:
         save_train_params(train_config, fsdp_config, rank)