# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward

# These are needed to unwrap the model; it would be nice to move them into
# megatron.utils if possible.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids

def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a batch of logits with top-k and/or nucleus (top-p) filtering.
    Mostly taken from the huggingface conversational AI code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-
    conversational-ai-with-transfer-learning-2d818ac26313"""

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort the logits in descending order.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold.
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def generate_samples_input_from_file(model):
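    """Generate a completion for every prompt line in args.sample_input_file
    and write the contexts and completions to args.sample_output_file
    (defaults to args.sample_input_file + ".out")."""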
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1

# We added this function to support evaluation tasks such as SQuAD and DROP
# in the https://github.com/EleutherAI/lm-evaluation-harness codebase. The
# lm-evaluation-harness code can now call this function in the same way it
# calls its existing generate function for gpt style models.
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    # Generate samples for lm evaluation
    # NEED TO THINK ABOUT eos token

    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.eos_id = eos_token_id

    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        for counter, decode_tokens in enumerate(token_stream):
            if counter == args.out_seq_length:
                break

    decode_tokens, _ = decode_tokens
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    trim_decode_tokens = tokenizer.detokenize(
        decode_tokens)[raw_text_len:]

    return trim_decode_tokens


def generate_samples_interactive(model, print_frequency=24):
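    """Interactively prompt for contexts on stdin and print the generated
    continuation, refreshing the display every `print_frequency` tokens.
    Entering a prompt containing "stop" exits."""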
    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1


def generate_samples_unconditional(model):
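    """Yield args.num_samples unconditionally generated samples as dicts with
    'text', 'length' and 'finished' keys, seeding each context with only the
    end-of-document token (non-participating ranks yield None instead)."""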
    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
                start_time = time.time()
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):
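    """Generate unconditional samples and write them to args.genfile, one JSON
    object per line."""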
    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            if mpu.is_pipeline_last_stage() and \
               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args):
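    """Pad each context in the batch with pad_id up to args.seq_length and
    return the padded batch together with the original context lengths."""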
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def get_token_stream(model, context_tokens):
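    """Pad and broadcast the contexts across the tensor-model-parallel group,
    then yield (tokens, lengths) pairs as the batch is decoded token by
    token."""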
    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


def switch(val1, val2, boolean):
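    """Element-wise select: take val1 where boolean is False and val2 where it
    is True."""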
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
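    """Run one forward pass on this pipeline stage, receiving the input
    activations from the previous stage and sending the output to the next.
    If get_key_value is set, also return the key/value cache."""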
    # Hidden size changes when not using recompute, need to tell p2p_communicate
    # functions the correct size
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

    input_tensor = recv_forward()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          tokentype_ids=tokentype_ids,
                          layer_past=layer_past,
                          get_key_value=get_key_value,
                          forward_method_parallel_output=forward_method_parallel_output)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    send_forward(output_tensor)

    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
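    """Autoregressively sample the batch one token per step, yielding the
    (tokens, lengths) state after every step. A row stops once eos_id is
    produced, and the loop ends at maxlen or when every row is done."""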
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):

            if args.recompute:
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                output, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
                                                  layer_past=layer_past,
                                                  get_key_value=True,
                                                  tokentype_ids=types2use,
                                                  forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break