|
@@ -0,0 +1,126 @@
|
|
|
+# coding=utf-8
|
|
|
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
|
+#
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
+# You may obtain a copy of the License at
|
|
|
+#
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+#
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
+# limitations under the License.
|
|
|
+
|
|
|
+"""Pretrain GPT"""
|
|
|
+
|
|
|
+import torch
|
|
|
+from functools import partial
|
|
|
+from megatron import get_args
|
|
|
+from megatron import print_rank_0
|
|
|
+from megatron import get_timers
|
|
|
+from megatron import get_tokenizer
|
|
|
+from megatron import mpu
|
|
|
+from megatron.data.gpt_dataset import build_train_valid_test_datasets
|
|
|
+from megatron.model import GPTModel
|
|
|
+from megatron.training import pretrain
|
|
|
+from megatron.utils import get_ltor_masks_and_position_ids
|
|
|
+from megatron.utils import average_losses_across_data_parallel_group
|
|
|
+import pyprof
|
|
|
+pyprof.init(enable_function_stack=True)
|
|
|
+def model_provider(pre_process=True, post_process=True):
|
|
|
+ """Build the model."""
|
|
|
+
|
|
|
+ print_rank_0('building GPT model ...')
|
|
|
+ model = GPTModel(
|
|
|
+ num_tokentypes=0,
|
|
|
+ parallel_output=True,
|
|
|
+ pre_process=pre_process,
|
|
|
+ post_process=post_process
|
|
|
+ )
|
|
|
+ return model
|
|
|
+
|
|
|
+
|
|
|
+def get_batch(data_iterator):
|
|
|
+ """Generate a batch"""
|
|
|
+ args = get_args()
|
|
|
+ tokenizer = get_tokenizer()
|
|
|
+
|
|
|
+ # Items and their type.
|
|
|
+ keys = ['text']
|
|
|
+ datatype = torch.int64
|
|
|
+
|
|
|
+ # Broadcast data.
|
|
|
+ if data_iterator is not None:
|
|
|
+ data = next(data_iterator)
|
|
|
+ else:
|
|
|
+ data = None
|
|
|
+ data_b = mpu.broadcast_data(keys, data, datatype)
|
|
|
+
|
|
|
+ # Unpack.
|
|
|
+ tokens_ = data_b['text'].long()
|
|
|
+ labels = tokens_[:, 1:].contiguous()
|
|
|
+ tokens = tokens_[:, :-1].contiguous()
|
|
|
+
|
|
|
+ # Get the masks and postition ids.
|
|
|
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
|
|
|
+ tokens,
|
|
|
+ tokenizer.eod,
|
|
|
+ args.reset_position_ids,
|
|
|
+ args.reset_attention_mask,
|
|
|
+ args.eod_mask_loss)
|
|
|
+
|
|
|
+ return tokens, labels, loss_mask, attention_mask, position_ids
|
|
|
+
|
|
|
+def loss_func(loss_mask, output_tensor):
|
|
|
+ losses = output_tensor.float()
|
|
|
+ loss_mask = loss_mask.view(-1).float()
|
|
|
+ loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
|
|
|
+
|
|
|
+ # Reduce loss for logging.
|
|
|
+ averaged_loss = average_losses_across_data_parallel_group([loss])
|
|
|
+
|
|
|
+ return loss, {'lm loss': averaged_loss[0]}
|
|
|
+
|
|
|
+
|
|
|
+def forward_step(data_iterator, model):
|
|
|
+ """Forward step."""
|
|
|
+ args = get_args()
|
|
|
+ timers = get_timers()
|
|
|
+
|
|
|
+ # Get the batch.
|
|
|
+ timers('batch-generator').start()
|
|
|
+ tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
|
|
|
+ data_iterator)
|
|
|
+ timers('batch-generator').stop()
|
|
|
+
|
|
|
+ output_tensor = model(tokens, position_ids, attention_mask,
|
|
|
+ labels=labels)
|
|
|
+
|
|
|
+ return output_tensor, partial(loss_func, loss_mask)
|
|
|
+
|
|
|
+
|
|
|
+def train_valid_test_datasets_provider(train_val_test_num_samples):
|
|
|
+ """Build train, valid, and test datasets."""
|
|
|
+ args = get_args()
|
|
|
+
|
|
|
+ print_rank_0('> building train, validation, and test datasets '
|
|
|
+ 'for GPT ...')
|
|
|
+ train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
|
|
|
+ data_prefix=args.data_path,
|
|
|
+ data_impl=args.data_impl,
|
|
|
+ splits_string=args.split,
|
|
|
+ train_valid_test_num_samples=train_val_test_num_samples,
|
|
|
+ seq_length=args.seq_length,
|
|
|
+ seed=args.seed,
|
|
|
+ skip_warmup=(not args.mmap_warmup))
|
|
|
+ print_rank_0("> finished creating GPT datasets ...")
|
|
|
+
|
|
|
+ return train_ds, valid_ds, test_ds
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ with torch.autograd.profiler.emit_nvtx():
|
|
|
+ pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
|
|
|
+ args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
|