test_layers.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import sys

import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter

# Make the parent package importable so that `mpu` resolves when this file
# is run directly from the tests directory.
sys.path.append("../..")

import mpu
from mpu import layers
from commons import initialize_distributed
from commons import print_separator
from commons import set_random_seed
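
# These tests compare mpu's tensor-parallel layers against serial references
# built from the same random seeds: losses and appropriately sliced gradients
# must agree (to a small tolerance) on every rank.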

def test_parallel_embedding(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print(' error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print(' error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)
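
    # Slice the serial embedding's gradient to match each rank's shard:
    # ParallelEmbedding partitions the table along the hidden dimension
    # (dim 1), VocabParallelEmbedding along the vocabulary dimension (dim 0).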
    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   hidden_size // tensor_model_parallel_size,
                                   1)[mpu.get_tensor_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(' error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   vocab_size // tensor_model_parallel_size,
                                   0)[mpu.get_tensor_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(
        weight_grad_orig).abs().max()
    print(' error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


def test_initialize_affine_weight(tensor_model_parallel_size):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
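
    # The weight of a parallel linear layer is stored as [output, input].
    # Column-parallel layers partition the output dimension (dim 0) and
    # row-parallel layers partition the input dimension (dim 1), so the
    # initialized shard should equal the rank's slice of a master weight
    # drawn from the same seed.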
    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size,
                                     output_size_coeff, 0,
                                     torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    mpu.layers._initialize_affine_weight(weight, output_size, input_size,
                                         input_size_coeff, 1,
                                         torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


class IdentityLayer2D(torch.nn.Module):

    def __init__(self, m, n):
        super(IdentityLayer2D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight
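
# IdentityLayer2D (and IdentityLayer3D below) simply returns its own trainable
# weight, standing in for the network input so the tests can also check input
# gradients via identity_layer.weight.grad.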

def test_column_parallel_linear(tensor_model_parallel_size):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.ColumnParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()
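
    # Reference gradients, computed by hand for Y = X A^T + b with
    # L = sum(Y * loss_weight), so dL/dY = loss_weight:
    #   dL/dA = dLdY^T X,  dL/db = 1^T dLdY,  dL/dX = dLdY A.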
    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    my_dLdb = torch.split(dLdb, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


def test_row_parallel_linear(tensor_model_parallel_size):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.RowParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()
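
    # Same hand-computed reference as in the column-parallel test, but here
    # the weight gradient is sliced along the input dimension (dim 1), and
    # the bias is not partitioned, so dLdb is compared in full.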
    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


class IdentityLayer3D(torch.nn.Module):

    def __init__(self, m, n, k):
        super(IdentityLayer3D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n, k))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight


def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                    dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_tensor_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer


def test_parallel_self_attention(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing ParallelSelfAttention with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    dropout_prob = 0.0  # has to be zero
    batch_size = 5
    sequence_length = 13

    rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
        attention_layer_1, identity_layer_1 = parallel_self_attention(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)

    rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer = parallel_self_attention(
            tensor_model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
    assert hidden_size_1 == hidden_size

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print(' loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6
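
    # The serial layer's fused query_key_value weight stacks Q, K, and V along
    # the output dimension, and each block is partitioned across ranks, so the
    # rank's slice of the serial gradient is gathered with a stride of
    # tensor_model_parallel_size and re-concatenated before comparison.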
    my_lin_grad_list = torch.split(
        attention_layer_1.query_key_value.weight.grad,
        hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
    my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
    error = my_lin_grad.sub(
        attention_layer.query_key_value.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' weight gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = mpu.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_tensor_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer


def test_parallel_transformer_layer(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing ParallelTransformerLayer with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    batch_size = 5
    sequence_length = 13

    rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
        transformer_layer_1, identity_layer_1 = parallel_transformer(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, batch_size, sequence_length)

    rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer = parallel_transformer(
            tensor_model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, batch_size, sequence_length)

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print(' loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
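

# The driver below sweeps the tensor-model-parallel size over powers of two
# (1, 2, 4, ...) up to the total world size and runs every test at each size.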
if __name__ == '__main__':

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    print_separator('test initialize affine weight')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_initialize_affine_weight(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test parallel embedding')
        test_parallel_embedding(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test column-parallel linear')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_column_parallel_linear(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test row-parallel linear')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_row_parallel_linear(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test parallel self-attention')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_parallel_self_attention(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test parallel transformer')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_parallel_transformer_layer(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
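
# Note: this script expects one process per GPU. It is typically launched with
# a distributed launcher, e.g. `torchrun --nproc_per_node=<num_gpus> test_layers.py`
# (assumption: commons.initialize_distributed picks up the standard
# torch.distributed environment variables or launcher arguments).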