
Fix src/tests/test_train_utils.py

Matthias Reso, 7 months ago
Parent
Commit fac71cd136
1 changed file with 10 additions and 1 deletion
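The diff below adds the FLOP-counter options that `train()` now reads from `train_config`, so the mocked config objects in these tests have to define them explicitly. A minimal sketch of how such a mocked config might be assembled follows; the attribute names come from the diff itself, while the `MagicMock` setup around them is assumed from the test's apparent structure, not copied from the commit:

```python
# Hypothetical sketch of the mocked train_config these tests pass to train().
# The attribute names match the diff; the surrounding setup is an assumption.
from unittest.mock import MagicMock

train_config = MagicMock(name="train_config")
train_config.max_train_step = 0
train_config.max_eval_step = 0
train_config.save_metrics = False
train_config.use_profiler = False      # keep the profiler context disabled
train_config.flop_counter = True       # enable FLOP counting inside train()
train_config.flop_counter_start = 0    # step at which FLOP counting begins
```

Because `train_config` is a `MagicMock`, any attribute that `train()` reads but the test does not set would silently return another mock rather than a usable value, which is why the new flags are assigned concrete values here.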

src/tests/test_train_utils.py  (+10 −1)

@@ -27,7 +27,12 @@ def temp_output_dir():
 @patch("llama_recipes.utils.train_utils.nullcontext")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
-def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
+def test_gradient_accumulation(
+    autocast,
+    scaler,
+    nullcontext,
+    mem_trace,
+    mocker):
 
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
@@ -47,6 +52,9 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.max_train_step = 0
     train_config.max_eval_step = 0
     train_config.save_metrics = False
+    train_config.flop_counter_start = 0
+    train_config.use_profiler = False
+    train_config.flop_counter = True
 
     train(
         model,
@@ -103,6 +111,7 @@ def test_save_to_json(temp_output_dir, mocker):
     train_config.max_train_step = 0
     train_config.max_eval_step = 0
     train_config.output_dir = temp_output_dir
+    train_config.flop_counter_start = 0
     train_config.use_profiler = False
 
     results = train(