# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import math

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility

from megatron import get_args


_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                      'partition_dim': -1,
                                      'partition_stride': 1}


def param_is_not_tensor_parallel_duplicate(param):
    return (hasattr(param, 'tensor_model_parallel') and
            param.tensor_model_parallel) or (
                get_tensor_model_parallel_rank() == 0)


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    # Make sure the attributes are not set.
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attribute)
    # Set the attributes.
    setattr(tensor, 'tensor_model_parallel', is_parallel)
    setattr(tensor, 'partition_dim', dim)
    setattr(tensor, 'partition_stride', stride)


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    def maybe_set(attribute, value):
        if not hasattr(tensor, attribute):
            setattr(tensor, attribute, value)
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])


def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    def maybe_copy(attribute):
        if hasattr(source_tensor, attribute):
            setattr(destination_tensor, attribute,
                    getattr(source_tensor, attribute))
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_copy(attribute)
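
# A minimal usage sketch (not part of the original file): mark a sharded
# parameter with the attributes above, then skip duplicated (non-parallel)
# parameters on non-zero ranks, e.g. when accumulating a global grad norm.
# `some_module` is a hypothetical nn.Module whose parameters carry these tags.
#
#     w = Parameter(torch.empty(256, 1024))
#     set_tensor_model_parallel_attributes(w, is_parallel=True, dim=0, stride=1)
#     grads = [p.grad for p in some_module.parameters()
#              if param_is_not_tensor_parallel_duplicate(p)]
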

def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    with get_cuda_rng_tracker().fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    args = get_args()
    master_weight = master_weight.to(dtype=args.params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None
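
# Worked example of the strided split above (illustration only): with
# world_size=2 and stride=2, a length-8 master weight is cut into chunks of
# per_partition_per_stride_size=2, and weight_list[rank::world_size]
# interleaves them so each rank owns one piece of every stride group rather
# than a single contiguous block:
#
#     master = torch.arange(8.)           # master weight, output_size = 8
#     chunks = torch.split(master, 2)     # 4 chunks of size 2
#     rank0 = torch.cat(chunks[0::2])     # tensor([0., 1., 4., 5.])
#     rank1 = torch.cat(chunks[1::2])     # tensor([2., 3., 6., 7.])
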

class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Allocate weights and initialize.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=args.params_dtype))
            _initialize_affine_weight_cpu(
                self.weight, self.num_embeddings, self.embedding_dim,
                self.num_embeddings_per_partition, 0, init_method)
        else:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=1)

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output
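
# Usage sketch (assumes megatron's tensor-model-parallel groups and get_args()
# are already initialized; sizes are illustrative). With a 50k vocabulary on
# 2 ranks, each rank holds a 25k-row shard; out-of-range token ids are masked
# to zero locally and the closing all-reduce reassembles the full lookup:
#
#     emb = VocabParallelEmbedding(50000, 1024)
#     tokens = torch.randint(0, 50000, (8, 512), device='cuda')
#     hidden = emb(tokens)    # [8, 512, 1024], identical on every rank
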

class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other elementwise operations. We
                       skip adding the bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.output_size_per_partition, 0, init_method,
                stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=stride)

        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # Set up backprop all-reduce.
        input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
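
# Usage sketch (assumes initialized model-parallel groups; sizes illustrative).
# With gather_output=False each rank keeps its own slice Y_i = XA_i, the usual
# setting when a RowParallelLinear consumes the output directly:
#
#     fc1 = ColumnParallelLinear(1024, 4096, gather_output=False)
#     x = torch.randn(8, 512, 1024, device='cuda')
#     y_i, _ = fc1(x)         # [8, 512, 4096 // world_size] on each rank
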

class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:

               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other elementwise operations. We
                       skip adding the bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=args.params_dtype))
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.input_size_per_partition, 1, init_method,
                stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=1, stride=stride)
        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size, device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight)
        # All-reduce across all the partitions.
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
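
# Usage sketch (assumes initialized model-parallel groups; sizes illustrative).
# The canonical pairing avoids communication between the two layers:
# ColumnParallelLinear(gather_output=False) emits an already-split activation,
# which RowParallelLinear(input_is_parallel=True) consumes and all-reduces once:
#
#     fc1 = ColumnParallelLinear(1024, 4096, gather_output=False)
#     fc2 = RowParallelLinear(4096, 1024, input_is_parallel=True)
#     h, _ = fc1(x)           # split along the 4096 dimension
#     y, _ = fc2(F.gelu(h))   # full [..., 1024] output on every rank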