test_train_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from unittest.mock import patch

import pytest

import torch
import os
import shutil

from llama_recipes.utils.train_utils import train

TEMP_OUTPUT_DIR = os.getcwd() + "/tmp"


@pytest.fixture(scope="session")
def temp_output_dir():
    # Create the directory during the session-level setup
    temp_output_dir = "tmp"
    os.mkdir(os.path.join(os.getcwd(), temp_output_dir))

    yield temp_output_dir

    # Delete the directory during the session-level teardown
    shutil.rmtree(temp_output_dir)


@patch("llama_recipes.utils.train_utils.MemoryTrace")
@patch("llama_recipes.utils.train_utils.nullcontext")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
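    # Mock out the model, data, and training config so train() can run
    # without a real model, tokenizer, or GPU.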
    model = mocker.MagicMock(name="model")
    model().loss.__truediv__().detach.return_value = torch.tensor(1)
    model().loss.detach.return_value = torch.tensor(1)

    mock_tensor = mocker.MagicMock(name="tensor")
    batch = {"input": mock_tensor}
    train_dataloader = [batch, batch, batch, batch, batch]
    eval_dataloader = None

    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()
    gradient_accumulation_steps = 1

    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    train_config.run_validation = False
    train_config.gradient_clipping = False
    train_config.max_train_step = 0
    train_config.max_eval_step = 0
    train_config.save_metrics = False
    train_config.flop_counter_start = 0
    train_config.use_profiler = False
    train_config.flop_counter = True
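
    # With gradient_accumulation_steps=1 and fp16 disabled, the optimizer should
    # step on all 5 batches and nullcontext should wrap every forward pass.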
    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )

    assert optimizer.zero_grad.call_count == 5
    optimizer.zero_grad.reset_mock()

    assert nullcontext.call_count == 5
    nullcontext.reset_mock()

    assert autocast.call_count == 0
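
    # With accumulation of 2 over 5 batches, the optimizer steps on every second
    # batch plus the final one (3 steps in total), and enabling fp16 routes each
    # of the 5 forward passes through autocast instead of nullcontext.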
    gradient_accumulation_steps = 2
    train_config.use_fp16 = True

    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )

    assert optimizer.zero_grad.call_count == 3
    assert nullcontext.call_count == 0
    assert autocast.call_count == 5
def test_save_to_json(temp_output_dir, mocker):
    model = mocker.MagicMock(name="model")
    model().loss.__truediv__().detach.return_value = torch.tensor(1)
    model().loss.detach.return_value = torch.tensor(1)

    mock_tensor = mocker.MagicMock(name="tensor")
    batch = {"input": mock_tensor}
    train_dataloader = [batch, batch, batch, batch, batch]
    eval_dataloader = None

    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()
    gradient_accumulation_steps = 1

    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    train_config.run_validation = False
    train_config.gradient_clipping = False
    train_config.save_metrics = True
    train_config.max_train_step = 0
    train_config.max_eval_step = 0
    train_config.output_dir = temp_output_dir
    train_config.flop_counter_start = 0
    train_config.use_profiler = False
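
    # The returned results should carry a non-empty metrics_filename pointing at
    # an existing file; local_rank=0 is supplied since train() uses the rank
    # when building that filename.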
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
        local_rank=0
    )

    assert results["metrics_filename"] not in ["", None]
    assert os.path.isfile(results["metrics_filename"])