# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import shutil
from unittest.mock import patch

import pytest
import torch

from llama_recipes.utils.train_utils import train

TEMP_OUTPUT_DIR = os.getcwd() + "/tmp"
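# NOTE: this matches the directory created by the temp_output_dir fixture below.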


@pytest.fixture(scope="session")
def temp_output_dir():
    # Create the directory during the session-level setup
    temp_output_dir = "tmp"
    os.mkdir(os.path.join(os.getcwd(), temp_output_dir))
    yield temp_output_dir
    # Delete the directory during the session-level teardown
    shutil.rmtree(temp_output_dir)
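

# @patch decorators are applied bottom-up, so the mocks arrive in reverse
# order: `autocast` matches the innermost (last) patch and `mem_trace` the
# outermost (first). `scaler` and `mem_trace` are injected but unused; patching
# GradScaler and MemoryTrace keeps train() from touching real CUDA state.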
@patch("llama_recipes.utils.train_utils.MemoryTrace")
@patch("llama_recipes.utils.train_utils.nullcontext")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
def test_gradient_accumulation(
    autocast,
    scaler,
    nullcontext,
    mem_trace,
    mocker,
):
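    """Check that train() steps the optimizer once per accumulation window
    and wraps the forward pass in the context that matches use_fp16."""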
    model = mocker.MagicMock(name="model")
    # train() divides the loss by gradient_accumulation_steps and then
    # detaches it; make that chained call return a real tensor so the
    # downstream float arithmetic works on the mock.
    model().loss.__truediv__().detach.return_value = torch.tensor(1)
    mock_tensor = mocker.MagicMock(name="tensor")
    batch = {"input": mock_tensor}
    train_dataloader = [batch, batch, batch, batch, batch]
    eval_dataloader = None
    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()
    gradient_accumulation_steps = 1
    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    train_config.run_validation = False
    train_config.gradient_clipping = False
    train_config.max_train_step = 0
    train_config.max_eval_step = 0
    train_config.save_metrics = False
    train_config.flop_counter_start = 0
    train_config.use_profiler = False
    train_config.flop_counter = True
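
    # Pass 1: accumulation of 1 with fp16 off. Every batch should trigger an
    # optimizer step, and each forward pass should run under nullcontext
    # rather than autocast.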
    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )
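    # 5 batches / accumulation of 1 -> 5 optimizer steps, 5 nullcontext
    # entries, and no autocast usage.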
    assert optimizer.zero_grad.call_count == 5
    optimizer.zero_grad.reset_mock()
    assert nullcontext.call_count == 5
    nullcontext.reset_mock()
    assert autocast.call_count == 0
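
    # Pass 2: accumulation of 2 with fp16 on. Over 5 batches the optimizer
    # steps after batches 2 and 4 and once more for the trailing batch
    # (3 zero_grad calls), while autocast wraps all 5 forward passes and
    # nullcontext is never entered.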
    gradient_accumulation_steps = 2
    train_config.use_fp16 = True
    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )
    assert optimizer.zero_grad.call_count == 3
    assert nullcontext.call_count == 0
    assert autocast.call_count == 5


def test_save_to_json(temp_output_dir, mocker):
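    """Check that train() with save_metrics enabled writes a metrics JSON
    file into output_dir and reports its path in the results dict."""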
    model = mocker.MagicMock(name="model")
    model().loss.__truediv__().detach.return_value = torch.tensor(1)
    mock_tensor = mocker.MagicMock(name="tensor")
    batch = {"input": mock_tensor}
    train_dataloader = [batch, batch, batch, batch, batch]
    eval_dataloader = None
    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()
    gradient_accumulation_steps = 1
    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    train_config.run_validation = False
    train_config.gradient_clipping = False
    train_config.save_metrics = True
    train_config.max_train_step = 0
    train_config.max_eval_step = 0
    train_config.output_dir = temp_output_dir
    train_config.flop_counter_start = 0
    train_config.use_profiler = False
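
    # local_rank is passed explicitly here, presumably because the metrics
    # filename is derived from the rank and needs a concrete value rather
    # than a MagicMock default.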
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
        local_rank=0,
    )
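    # train() should report where the metrics were written, and the file
    # should actually exist on disk.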
    assert results["metrics_filename"] not in ["", None]
    assert os.path.isfile(results["metrics_filename"])