# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
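

# Illustrative sketch: how a training step might drive the schedule selected by
# get_forward_backward_func(). The caller-supplied names (my_forward_step_func,
# data_iterator, model_chunks, optimizer, timers) are hypothetical placeholders;
# the argument order matches the schedule functions defined in this module.
def _example_train_step(my_forward_step_func, data_iterator, model_chunks,
                        optimizer, timers):
    # model_chunks is a list of modules: one entry per virtual pipeline chunk
    # when interleaving is enabled, otherwise a single-element list.
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        my_forward_step_func, data_iterator, model_chunks,
        optimizer, timers, forward_only=False)
    return losses_reduced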


def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    timers = get_timers()
    timers('forward-compute').start()

    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)

    timers('forward-compute').stop()

    return output_tensor
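

# Illustrative sketch of the contract forward_step_func must satisfy: it returns
# (output_tensor, loss_func), and loss_func is only invoked on the last pipeline
# stage, where it maps the output tensor to (loss, reduced_loss_dict). The
# get_batch helper and the model call below are hypothetical placeholders:
#
#     def example_forward_step(data_iterator, model):
#         tokens, labels, loss_mask = get_batch(data_iterator)
#         output_tensor = model(tokens, labels=labels)
#
#         def loss_func(output_tensor):
#             loss = torch.sum(output_tensor.view(-1) * loss_mask) / loss_mask.sum()
#             return loss, {'lm loss': loss.detach().clone()}
#
#         return output_tensor, loss_func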


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""
    args = get_args()
    timers = get_timers()
    timers('backward-compute').start()

    # Retain the grad on the input_tensor.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    timers('backward-compute').stop()

    return input_tensor_grad
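

# Note on retain_grad() above: PyTorch only populates .grad on leaf tensors by
# default, and the input_tensor handed to this stage may be a non-leaf tensor,
# so retain_grad() ensures input_tensor.grad is available to send upstream.
# A minimal standalone sketch of the same mechanism:
#
#     x = torch.randn(4, requires_grad=True)
#     h = 2 * x                       # non-leaf tensor, analogous to input_tensor
#     h.retain_grad()
#     (h * 3).sum().backward()
#     assert h.grad is not None       # would be None without retain_grad()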


@contextmanager
def dummy_handler():
    try:
        yield
    finally:
        pass


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Returns a list of reduced-loss dictionaries (one per microbatch)."""
    assert len(model) == 1
    model = model[0]

    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for i in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator, model,
                                         input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced
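

# Note on the context handler above: it implements gradient accumulation. When
# the model is wrapped in torch DDP, model.no_sync() suppresses the gradient
# all-reduce for the first (num_microbatches - 1) microbatches, and the final
# microbatch runs outside the context so gradients are reduced exactly once per
# step. A minimal sketch of the same pattern with plain PyTorch DDP (ddp_model
# and microbatches are hypothetical placeholders):
#
#     with ddp_model.no_sync():
#         for batch in microbatches[:-1]:
#             ddp_model(batch).sum().backward()     # grads accumulate locally
#     ddp_model(microbatches[-1]).sum().backward()  # triggers the all-reduce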


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                                                  optimizer, timers, forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns a list of reduced-loss dictionaries if on the last pipeline stage,
    an empty list otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if the number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size warmup
        # microbatches on all workers, followed by additional microbatches
        # depending on stage ID (more forward passes for earlier stages; later
        # stages can start 1F1B immediately).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
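
    # Worked example (illustrative): with pipeline_parallel_size = 4,
    # pipeline_parallel_rank = 1, num_model_chunks = 2, and
    # get_num_microbatches() = 8, we get num_microbatches = 16 and
    # num_warmup_microbatches = (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8, leaving
    # num_microbatches_remaining = 8 steady-state 1F1B iterations.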

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id
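
    # Worked example (illustrative): with pipeline_parallel_size = 2 and
    # num_model_chunks = 2, the group size is 4, so forward microbatch IDs
    # 0, 1, 2, 3, 4, 5, ... map to model chunks 0, 0, 1, 1, 0, 0, ...; for
    # backward passes the chunk order is reversed, so the same IDs map to
    # chunks 1, 1, 0, 0, 1, 1, ...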

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev, timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor, input_tensor_grad,
                recv_prev=recv_prev, recv_next=recv_next,
                timers=timers)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers))

    return losses_reduced


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                      model, optimizer, timers,
                                                      forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns a list of reduced-loss dictionaries if on the last pipeline stage,
    an empty list otherwise."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
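
    # Worked example (illustrative): with a pipeline_model_parallel_world_size
    # of 4, rank 0 runs 3 warmup forward passes before entering 1F1B, rank 1
    # runs 2, rank 2 runs 1, and the last stage (rank 3) runs 0 and starts
    # 1F1B immediately, all capped at num_microbatches.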

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced