Browse Source

Fix fixture in test_train_utils

Matthias Reso 6 months ago
parent
commit
d9ca099613
1 changed files with 2 additions and 0 deletions
  1. 2 0
      src/tests/test_train_utils.py

+ 2 - 0
src/tests/test_train_utils.py

@@ -36,6 +36,7 @@ def test_gradient_accumulation(
 
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
@@ -94,6 +95,7 @@ def test_gradient_accumulation(
 def test_save_to_json(temp_output_dir, mocker):
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]