# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Zero-shot datasets."""

import json
import math

import numpy as np
import torch

from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from .detokenizer import get_detokenizer


def build_dataset(task):
    """Helper function to select and build dataset."""
    if task == 'LAMBADA':
        return _build_lambada_dataset()
    if task == 'WIKITEXT103':
        return _build_wikitext103_dataset()

    raise NotImplementedError('dataset for {} task is not '
                              'implemented.'.format(task))


class _LMDataset(torch.utils.data.Dataset):

    def __init__(self, tokens, seq_len, pad_idx, num_original_tokens,
                 num_tokenized_tokens, overalapping_eval=None):
        self.tokens = tokens
        self.seq_len = seq_len
        self.pad_idx = pad_idx
        self.overalapping_eval = overalapping_eval
        if self.overalapping_eval is None:
            # Default to non-overlapping evaluation: stride by a full sequence.
            self.overalapping_eval = self.seq_len
        self.overalapping_eval = max(1, self.overalapping_eval)
        self.num_original_tokens = num_original_tokens
        self.num_tokenized_tokens = num_tokenized_tokens
        self.total_targets = len(self.tokens) - 1
        # remove first sequence tokens
        targets = max(self.total_targets - self.overalapping_eval, 0)
        self.total_sequences = max(
            math.ceil(targets / self.overalapping_eval) + 1, 1)

    def __len__(self):
        return self.total_sequences

    def __getitem__(self, idx):
        start_idx = idx * self.overalapping_eval
        end_idx = start_idx + self.seq_len
        tokens = self.tokens[start_idx:end_idx + 1]
        num_tokens = len(tokens)
        pad_mask = [1] * num_tokens
        if num_tokens < self.seq_len + 1:
            num_pad = (self.seq_len + 1 - num_tokens)
            pad_mask += [0] * (num_pad)
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])
        if self.overalapping_eval != self.seq_len and idx != 0:
            # Only score the trailing `overalapping_eval` targets; the earlier
            # positions were already evaluated by the previous window.
            pad_mask[:-self.overalapping_eval] *= 0

        return {'text': np.array(tokens), 'pad_mask': pad_mask}


class _LambadaDataset(torch.utils.data.Dataset):

    def __init__(self, path, pad_idx, tokenizer, seq_len, strict=False):
        print_rank_0('> building lambada dataset from {} ...'.format(path))
        self.seq_len = seq_len
        self.pad_idx = pad_idx
        self.tokenizer = tokenizer
        self.strict = strict

        self.tokens = []
        self.labels = []
        with open(path, 'r') as f:
            for line in f.readlines():
                text = json.loads(line)['text']
                tokens, labels = self.get_tokens(text)
                self.tokens.append(tokens)
                self.labels.append(labels)

    def get_tokens(self, text):
        if not self.strict:
            # Non-strict mode: predict only the final token of the passage.
            tokens = self.tokenizer.tokenize(text)
            return tokens[:-1], [tokens[-1]]
        # Strict mode: predict the last whitespace-separated word.
        last_token = text.split()[-1]
        start_idx = text.rfind(last_token)
        beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
        last_token = self.tokenizer.tokenize(' ' + last_token)
        return beginning_tokens, last_token

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        tokens = self.tokens[idx]
        num_tokens = len(tokens)
        # Only the label (target word) positions contribute to the loss.
        pad_mask = [0] * num_tokens
        labels = self.labels[idx]
        pad_mask += [1] * len(labels)
        tokens = tokens + labels
        num_tokens = len(tokens)
        if num_tokens < self.seq_len + 1:
            num_pad = (self.seq_len + 1 - num_tokens)
            pad_mask += [0] * (num_pad)
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])

        return {'text': np.array(tokens), 'pad_mask': pad_mask}


def _build_lambada_dataset():
    """Build lambada dataset."""
    args = get_args()
    tokenizer = get_tokenizer()

    assert len(args.valid_data) == 1
    val_dataset = _LambadaDataset(args.valid_data[0], tokenizer.eod, tokenizer,
                                  args.seq_length, args.strict_lambada)
    print_rank_0(' > found {} samples.'.format(len(val_dataset)))

    return val_dataset


def _build_wikitext103_dataset():
    """Build wikitext103 dataset."""
    args = get_args()
    tokenizer = get_tokenizer()

    assert len(args.valid_data) == 1
    with open(args.valid_data[0], "rb") as reader:
        entire_data = reader.read().decode('utf-8')
    # Number of whitespace-separated tokens in the raw (non-detokenized) file.
    num_original_tokens = len(entire_data.strip().split(" "))
    entire_data = get_detokenizer(args.valid_data[0])(entire_data)
    tokenized_data = tokenizer.tokenize(entire_data)
    num_tokenized_tokens = len(tokenized_data)

    val_dataset = _LMDataset(tokenized_data, args.seq_length, tokenizer.eod,
                             num_original_tokens, num_tokenized_tokens,
                             args.overlapping_eval)
    print_rank_0(' > number of original tokens: {}, number of detokenized '
                 'tokens: {}'.format(num_original_tokens,
                                     num_tokenized_tokens))

    return val_dataset
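

# Example usage (a minimal sketch, assuming Megatron-LM has already been
# initialized and the relevant arguments -- e.g. --valid-data, --seq-length,
# --overlapping-eval, --strict-lambada -- have been parsed by the framework's
# argument parser; adjust the import path to match your checkout):
#
#     from tasks.zeroshot_gpt.datasets import build_dataset
#
#     dataset = build_dataset('LAMBADA')  # or 'WIKITEXT103'
#     sample = dataset[0]
#     # sample['text'] holds the token ids for one evaluation sequence and
#     # sample['pad_mask'] marks which target positions count toward the loss.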