random.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import contextlib

import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

from megatron import get_args
from megatron.memory import allocate_mem_buff

from .initialize import get_data_parallel_rank
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size


# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

# Whether to apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None


def init_checkpointed_activations_memory_buffer():
    """Initialize the memory buffer for the checkpointed activations."""
    args = get_args()

    per_layer = args.micro_batch_size * args.max_position_embeddings * \
                args.hidden_size // args.tensor_model_parallel_size
    assert args.num_layers % args.checkpoint_num_layers == 0, \
        'number of layers is not divisible by checkpoint-num-layers'
    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
    numel = per_layer * num_checkpointer_layers
    dtype = torch.half
    if not args.fp16:
        dtype = torch.float

    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
        'checkpointed activations memory buffer is already allocated.'
    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
        'checkpointed activations', numel, dtype, track_usage=False)


def reset_checkpointed_activations_memory_buffer():
    """Reset the memory used for checkpointing."""
    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
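
# Sizing sketch (illustrative only; the argument values below are hypothetical):
# with micro_batch_size=4, max_position_embeddings=1024, hidden_size=2048,
# tensor_model_parallel_size=2, num_layers=24, checkpoint_num_layers=1,
# the buffer reserves one partitioned activation chunk per checkpointed layer:
#   per_layer = 4 * 1024 * 2048 // 2      # 4,194,304 elements
#   numel = per_layer * (24 // 1)         # 100,663,296 elements
# i.e. roughly 192 MiB in fp16 for this hypothetical configuration.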


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state
    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for 4+ GPU cases.
    """
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with device_ctx_manager(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device('cuda')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device('cuda', device)

        def cb():
            idx = device.index
            if idx is None:
                idx = torch.cuda.current_device()
            default_generator = torch.cuda.default_generators[idx]
            default_generator.set_state(new_state)

    _lazy_call(cb)
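
# Usage sketch (illustrative): save and restore the current device's RNG state
# around a temporary reseed, without the clone that torch.cuda.set_rng_state
# performs:
#   saved_state = torch.cuda.get_rng_state()
#   torch.cuda.manual_seed(1234)          # perturb the generator
#   _ = torch.randn(8, device='cuda')     # draw some random numbers
#   _set_cuda_rng_state(saved_state)      # restore; saved_state is not cloned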


def split_tensor_into_1d_equal_chunks(tensor):
    """Break a tensor into equal 1D chunks."""
    data = tensor.view(-1)
    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
    start_index = partition_size * get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    return data[start_index:end_index]
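
# Example sketch (illustrative; assumes model parallelism is already initialized
# with a tensor model parallel world size of 4): a [2, 8] tensor has 16 elements,
# so each rank keeps a contiguous 1D slice of 4 elements; rank r holds elements
# [4*r, 4*r + 4) of the flattened tensor. The element count is assumed to be
# divisible by the world size.
#   chunk = split_tensor_into_1d_equal_chunks(torch.arange(16.).view(2, 8))
#   # on rank 1 of 4: tensor([4., 5., 6., 7.])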


def gather_split_1d_tensor(tensor):
    """Opposite of above function, gather values from model parallel ranks."""
    world_size = get_tensor_model_parallel_world_size()
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    torch.distributed.all_gather(chunks, tensor,
                                 group=get_tensor_model_parallel_group())
    return gathered
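
# Round-trip sketch (illustrative): if every rank passes its own chunk from
# split_tensor_into_1d_equal_chunks, the all-gather reassembles the flattened
# tensor on every rank. `full_tensor` is a hypothetical CUDA tensor whose
# element count is divisible by the tensor model parallel world size:
#   chunk = split_tensor_into_1d_equal_chunks(full_tensor)
#   rebuilt = gather_split_1d_tensor(chunk).view(full_tensor.shape)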


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for bookkeeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        states = {}
        for name in self.states_:
            states[name] = self.states_[name]
        return states

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the states for compatibility."""
        self.states_ = states

    def add(self, name, seed):
        """Track the rng state."""
        # Check that the seed is not already used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        self.seeds_.add(seed)
        # Check that the state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        # Get the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Set the new state and store it.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state."""
        # Check that we have added the state.
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store the current rng state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        # Set the rng state to the desired one.
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = torch.cuda.get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)
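
# Usage sketch (illustrative; the name 'my-rng' and seed 1234 are made up):
# after `add`, entering `fork` swaps in the named state, and leaving it restores
# the caller's state while remembering where the forked stream left off:
#   tracker = CudaRNGStatesTracker()
#   tracker.add('my-rng', 1234)
#   with tracker.fork('my-rng'):
#       x = torch.randn(8, device='cuda')   # drawn from the tracked stream
#   y = torch.randn(8, device='cuda')       # drawn from the default stream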


# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Get cuda rng tracker."""
    return _CUDA_RNG_STATE_TRACKER


def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is a replacement for that
    function.
    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
                       set of model parallel GPUs but different across
                       different model parallel groups. This is used, for
                       example, for dropout in the non-tensor-model-parallel
                       regions.
        tensor-model-parallel state: This state is different among a set of
                       model parallel GPUs, but the same across data parallel
                       groups. This is used, for example, for dropout in
                       model parallel regions.
    """
    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    if torch.distributed.get_rank() == 0:
        print('> initializing model parallel cuda seeds on global rank {}, '
              'model parallel rank {}, and data parallel rank {} with '
              'model parallel seed: {} and data parallel seed: {}'.format(
                  torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
                  get_data_parallel_rank(), tensor_model_parallel_seed,
                  data_parallel_seed), flush=True)
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                tensor_model_parallel_seed)
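
# Usage sketch (illustrative; 1234 is an arbitrary base seed and `hidden` a
# hypothetical activation tensor): after model parallel initialization, dropout
# in tensor-model-parallel regions runs under the forked state so that different
# tensor parallel ranks drop different units:
#   model_parallel_cuda_manual_seed(1234)
#   with get_cuda_rng_tracker().fork():
#       out = torch.nn.functional.dropout(hidden, p=0.1, training=True)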


class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
       two main changes:
       1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
       2) the states in the model parallel tracker are also properly
          tracked/set/reset.
    """

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            outputs = run_function(*args)

        # Divide hidden states across the model parallel group and only keep
        # the chunk corresponding to the current rank.
        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
            ctx.input_0_shape = args[0].data.shape
            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
                args[0].data)

        # Store everything.
        ctx.save_for_backward(*args)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what they were before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Compute the forward pass.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        # Set the states back to what they were at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None,) + grads


def checkpoint(function, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
    return CheckpointFunction.apply(function, *args)
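
# Usage sketch (illustrative; `layer`, `hidden_states`, and `attention_mask` are
# hypothetical): the custom forward is re-run during backward() with the CPU,
# CUDA, and tracker RNG states restored, so dropout patterns match the first pass:
#   def custom_forward(hidden_states, attention_mask):
#       return layer(hidden_states, attention_mask)
#   output = checkpoint(custom_forward, hidden_states, attention_mask)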