Explorar el Código

Fix fixture in test_train_utils

Matthias Reso hace 6 meses
padre
commit
d9ca099613
Se han modificado 1 ficheros con 2 adiciones y 0 borrados
  1. 2 0
      src/tests/test_train_utils.py

+ 2 - 0
src/tests/test_train_utils.py

@@ -36,6 +36,7 @@ def test_gradient_accumulation(
 
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
@@ -94,6 +95,7 @@ def test_gradient_accumulation(
 def test_save_to_json(temp_output_dir, mocker):
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]