Преглед на файлове

Fix fixture in test_train_utils

Matthias Reso преди 6 месеца
родител
ревизия
d9ca099613
променени са 1 файла, в които са добавени 2 реда и са изтрити 0 реда
  1. 2 0
      src/tests/test_train_utils.py

+ 2 - 0
src/tests/test_train_utils.py

@@ -36,6 +36,7 @@ def test_gradient_accumulation(
 
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
@@ -94,6 +95,7 @@ def test_gradient_accumulation(
 def test_save_to_json(temp_output_dir, mocker):
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]