# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def get_forward_backward_func():
    """Return the forward-backward schedule matching the current
    pipeline-parallel configuration."""
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
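

# Example usage (an illustrative sketch only, not part of this module): the
# returned schedule is typically invoked once per training iteration. The
# surrounding names (train_data_iterator, etc.) are assumptions; `model` is the
# list of model chunks (length 1 unless interleaving is used).
#
#   forward_backward_func = get_forward_backward_func()
#   losses_reduced = forward_backward_func(
#       forward_step_func, train_data_iterator, model, optimizer, timers,
#       forward_only=False)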


def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

    On the first stage, the input is obtained from the data_iterator;
    otherwise, the passed-in input_tensor (received from the previous
    pipeline stage) is used.

    Returns the output tensor."""
    timers = get_timers()
    timers('forward-compute').start()

    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        # Average the loss over microbatches and collect the reduced loss.
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)

    timers('forward-compute').stop()

    return output_tensor
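

# For reference, forward_step_func is expected to take (data_iterator, model)
# and return (output_tensor, loss_func), where loss_func maps the output tensor
# to (loss, reduced_loss_dict) on the last stage. A minimal sketch of such a
# function (the batch layout and compute_loss helper are assumptions, not part
# of this module):
#
#   def forward_step_func(data_iterator, model):
#       tokens, labels = next(data_iterator)
#       output_tensor = model(tokens)
#
#       def loss_func(output_tensor):
#           loss = compute_loss(output_tensor, labels)
#           return loss, {'lm loss': loss.detach()}
#
#       return output_tensor, loss_func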


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    On the last stage, output_tensor_grad is None; otherwise it is the
    gradient of the loss with respect to this stage's output tensor.

    Returns the gradient of the loss with respect to the input tensor
    (None on the first stage)."""
    args = get_args()
    timers = get_timers()
    timers('backward-compute').start()

    # Retain the grad on the input_tensor.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass (on the last stage, scale the loss before calling backward).
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    timers('backward-compute').stop()

    return input_tensor_grad
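

# For intuition (illustrative only): on an intermediate stage the
# torch.autograd.backward() call above performs a vector-Jacobian product,
# equivalent to
#
#   output_tensor.backward(gradient=output_tensor_grad)
#
# which populates input_tensor.grad (retained above) so it can be sent to the
# previous pipeline stage; on the last stage it reduces to loss.backward().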


@contextmanager
def dummy_handler():
    try:
        yield
    finally:
        pass


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Returns the list of reduced losses collected by forward_step()."""
    assert len(model) == 1
    model = model[0]

    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for i in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator, model,
                                         input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for the last microbatch outside of the context handler
    # (we want to synchronize gradients on this final backward pass).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced
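

# Note: when the model is wrapped in torch DistributedDataParallel, the
# no_sync() context manager above defers the gradient all-reduce so that it
# only runs during the final microbatch's backward pass, i.e. roughly
# (illustrative only):
#
#   with model.no_sync():
#       ...  # forward/backward for all but the last microbatch
#   ...      # last microbatch's backward triggers the gradient all-reduce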


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                                                  optimizer, timers, forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns the list of reduced losses on the last stage, an empty list on
    the other stages."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if the number of
        # microbatches equals the number of pipeline stages.
        # Otherwise, perform (num_model_chunks - 1) * pipeline_parallel_size
        # warmup microbatches on all workers, followed by additional
        # microbatches that depend on the stage ID (more forward passes for
        # earlier stages; later stages can immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
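
    # Worked example (illustrative): with pipeline_parallel_size = 4,
    # num_model_chunks = 2 and get_num_microbatches() = 8, num_microbatches is
    # 16; rank 0 warms up with (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10 microbatches
    # and rank 3 with (4 - 3 - 1) * 2 + (2 - 1) * 4 = 4, so later stages enter
    # the 1F1B steady state sooner.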

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id
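
    # Worked example (illustrative): with pipeline_parallel_size = 4 and
    # num_model_chunks = 2, forward microbatches 0-3 map to model chunk 0,
    # microbatches 4-7 to model chunk 1, and then the pattern repeats; in the
    # backward direction the chunk order is reversed (chunk 1 first).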

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev, timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor, input_tensor_grad,
                recv_prev=recv_prev, recv_next=recv_next,
                timers=timers)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers))

    return losses_reduced


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns the list of reduced losses on the last stage, an empty list on
    the other stages."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
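
    # Worked example (illustrative): with 4 pipeline stages and 8 microbatches,
    # stage 0 runs 3 warmup forward passes, stage 1 runs 2, stage 2 runs 1 and
    # the last stage runs 0; each stage then alternates one forward and one
    # backward pass (1F1B) for its remaining microbatches before the cooldown
    # backward passes flush the pipeline.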

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced