# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward

# These are needed to unwrap the model, would be nice to put these in
# megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """This function has been mostly taken from the Hugging Face
    conversational AI code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Convert to 1D.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold.
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def generate_samples_input_from_file(model):
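    """Read prompts from args.sample_input_file, generate a completion for
    each one with get_token_stream, and write the prompt/completion pairs to
    args.sample_output_file (defaulting to the input file name plus ".out").
    A prompt containing "stop", or exhausting the input file, ends the run."""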
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct.
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1


# We added this function to support task evaluation (such as SQuAD and
# DROP) in the https://github.com/EleutherAI/lm-evaluation-harness
# codebase. The lm-evaluation-harness code can call this function in the
# same way it calls its current generate function for GPT-style models.
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
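    """Generate up to max_gen_length tokens after `context` and return the
    detokenized continuation with the original context text stripped off.
    eos_token_id is stored on args.eos_id so sampling terminates at that id
    instead of the tokenizer's default EOD id."""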
    # Generate samples for lm evaluation.
    # NEED TO THINK ABOUT eos token

    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.eos_id = eos_token_id

    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        for counter, decode_tokens in enumerate(token_stream):
            if counter == args.out_seq_length:
                break

    decode_tokens, _ = decode_tokens
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    trim_decode_tokens = tokenizer.detokenize(
        decode_tokens)[raw_text_len:]

    return trim_decode_tokens


def generate_samples_interactive(model, print_frequency=24):
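    """Interactive generation loop: prompt for context on the first pipeline
    stage (tensor-parallel rank 0), stream the model's continuation to the
    terminal every `print_frequency` tokens, and exit when the prompt
    contains "stop"."""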
    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct.
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1


def generate_samples_unconditional(model):
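    """Yield unconditional samples, each seeded only with the EOD token, as
    dicts with 'text', 'length', and 'finished' keys until args.num_samples
    have been produced. Ranks other than tensor-parallel rank 0 of the last
    pipeline stage yield None placeholders instead."""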
    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
                start_time = time.time()
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):
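    """Drive generate_samples_unconditional and write each sample as a line
    of JSON to args.genfile; only the last pipeline stage with
    tensor-parallel rank 0 actually writes."""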
    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            if mpu.is_pipeline_last_stage() and \
               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args):
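    """Pad every sequence in `batch` in place with `pad_id` up to
    args.seq_length and return the padded batch along with the original
    (unpadded) context lengths."""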
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def get_token_stream(model, context_tokens):
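    """Pad the contexts, broadcast them across the tensor-model-parallel
    group, and yield (tokens, lengths) pairs from sample_sequence_batch,
    truncating the returned tokens to the current generation length."""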
    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


def switch(val1, val2, boolean):
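    """Elementwise select: where `boolean` is 0 keep `val1`, where it is 1
    take `val2` (after casting `boolean` to val1's dtype)."""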
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
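    """Run one forward pass of a (possibly pipelined) model stage: receive
    the activation from the previous stage, call the model, and send the
    output to the next stage. Returns the output tensor, along with the
    updated layer_past when get_key_value is set."""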
    # Hidden size changes when not using recompute, need to tell
    # p2p_communicate functions the correct size.
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

    input_tensor = recv_forward()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          tokentype_ids=tokentype_ids,
                          layer_past=layer_past,
                          get_key_value=get_key_value,
                          forward_method_parallel_output=forward_method_parallel_output)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    send_forward(output_tensor)

    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
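    """Autoregressively sample a batch of sequences from the padded contexts,
    yielding the token tensor and per-sample lengths after every step. Only
    the last pipeline stage picks the next token (greedy or top-k/top-p
    sampling) and broadcasts it; generation stops at `maxlen` or once every
    sequence has produced the EOS id."""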
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # Added eos_id to support the function generate_samples_eval that
        # passes eos_id as an argument and needs termination when that id
        # is found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen
        while context_length <= (maxlen):
            if args.recompute:
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                output, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
                                                  layer_past=layer_past,
                                                  get_key_value=True,
                                                  tokentype_ids=types2use,
                                                  forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None
                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
|