datasets.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. """Zero-shot datasets."""
  16. import json
  17. import math
  18. import numpy as np
  19. import torch
  20. from megatron import get_args
  21. from megatron import print_rank_0
  22. from megatron import get_tokenizer
  23. from .detokenizer import get_detokenizer


def build_dataset(task):
    """Helper function to select and build dataset."""

    if task == 'LAMBADA':
        return _build_lambada_dataset()
    if task == 'WIKITEXT103':
        return _build_wikitext103_dataset()

    raise NotImplementedError('dataset for {} task is not '
                              'implemented.'.format(task))
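
# Example usage (a minimal sketch, not part of the original module). It
# assumes Megatron's global state has already been set up (e.g. via
# megatron.initialize.initialize_megatron) so that get_args() and
# get_tokenizer() work, and that args.valid_data points at the evaluation
# file:
#
#     dataset = build_dataset('LAMBADA')
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8)
#     for batch in loader:
#         tokens = batch['text']        # (batch, seq_len + 1) token ids
#         pad_mask = batch['pad_mask']  # (batch, seq_len); 1 = score this target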


class _LMDataset(torch.utils.data.Dataset):

    def __init__(self, tokens, seq_len, pad_idx, num_original_tokens,
                 num_tokenized_tokens, overlapping_eval=None):
        self.tokens = tokens
        self.seq_len = seq_len
        self.pad_idx = pad_idx
        self.overlapping_eval = overlapping_eval
        if self.overlapping_eval is None:
            self.overlapping_eval = self.seq_len
        self.overlapping_eval = max(1, self.overlapping_eval)
        self.num_original_tokens = num_original_tokens
        self.num_tokenized_tokens = num_tokenized_tokens
        self.total_targets = len(self.tokens) - 1
        # remove first sequence tokens
        targets = max(self.total_targets - self.overlapping_eval, 0)
        self.total_sequences = max(
            math.ceil(targets / self.overlapping_eval) + 1, 1)

    def __len__(self):
        return self.total_sequences

    def __getitem__(self, idx):
        start_idx = idx * self.overlapping_eval
        end_idx = start_idx + self.seq_len
        tokens = self.tokens[start_idx:end_idx + 1]
        num_tokens = len(tokens)
        pad_mask = [1] * num_tokens
        if num_tokens < self.seq_len + 1:
            num_pad = (self.seq_len + 1 - num_tokens)
            pad_mask += [0] * num_pad
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])
        if self.overlapping_eval != self.seq_len and idx != 0:
            pad_mask[:-self.overlapping_eval] *= 0

        return {'text': np.array(tokens), 'pad_mask': pad_mask}
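
# Worked example of the sliding-window indexing above (illustrative numbers,
# not from the original source). _LMDataset itself has no Megatron
# dependency, so this can be tried directly:
#
#     ds = _LMDataset(tokens=list(range(11)), seq_len=4, pad_idx=0,
#                     num_original_tokens=11, num_tokenized_tokens=11,
#                     overlapping_eval=2)
#     len(ds)            # ceil((10 - 2) / 2) + 1 = 5 windows
#     ds[1]['text']      # tokens[2:7] -> array([2, 3, 4, 5, 6])
#     ds[1]['pad_mask']  # array([0, 0, 1, 1]): for idx > 0 only the last
#                        # overlapping_eval targets are kept, so every target
#                        # token in the corpus is scored exactly once overall.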


class _LambadaDataset(torch.utils.data.Dataset):

    def __init__(self, path, pad_idx, tokenizer, seq_len, strict=False):
        print_rank_0('> building lambada dataset from {} ...'.format(path))
        self.seq_len = seq_len
        self.pad_idx = pad_idx
        self.tokenizer = tokenizer
        self.strict = strict

        self.tokens = []
        self.labels = []
        with open(path, 'r') as f:
            for line in f.readlines():
                text = json.loads(line)['text']
                tokens, labels = self.get_tokens(text)
                self.tokens.append(tokens)
                self.labels.append(labels)

    def get_tokens(self, text):
        if not self.strict:
            tokens = self.tokenizer.tokenize(text)
            return tokens[:-1], [tokens[-1]]
        last_token = text.split()[-1]
        start_idx = text.rfind(last_token)
        beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
        last_token = self.tokenizer.tokenize(' ' + last_token)
        return beginning_tokens, last_token

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        tokens = self.tokens[idx]
        num_tokens = len(tokens)
        pad_mask = [0] * num_tokens
        labels = self.labels[idx]
        pad_mask += [1] * len(labels)
        tokens = tokens + labels
        num_tokens = len(tokens)
        if num_tokens < self.seq_len + 1:
            num_pad = (self.seq_len + 1 - num_tokens)
            pad_mask += [0] * num_pad
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])

        return {'text': np.array(tokens), 'pad_mask': pad_mask}
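
# Note on strict vs. non-strict LAMBADA scoring (descriptive comment added
# for clarity; the example text is made up):
#
#     text = "... and with that, she finally wrote her own book"
#     strict=False: context = tokenize(text)[:-1], target = [tokenize(text)[-1]]
#     strict=True : context = tokenize("... and with that, she finally wrote her own"),
#                   target  = tokenize(" book")   # may be several sub-word ids
#
# Strict mode therefore requires predicting the whole final word, which is
# closer to the original word-level LAMBADA task.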


def _build_lambada_dataset():
    """Build lambada dataset."""
    args = get_args()
    tokenizer = get_tokenizer()

    assert len(args.valid_data) == 1
    val_dataset = _LambadaDataset(args.valid_data[0], tokenizer.eod, tokenizer,
                                  args.seq_length, args.strict_lambada)
    print_rank_0(' > found {} samples.'.format(len(val_dataset)))

    return val_dataset
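
# The LAMBADA validation file is expected to be JSON lines, one passage per
# line with a 'text' field (format inferred from _LambadaDataset above), e.g.:
#
#     {"text": "... and with that, she finally wrote her own book"}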


def _build_wikitext103_dataset():
    """Build wikitext103 dataset."""
    args = get_args()
    tokenizer = get_tokenizer()

    assert len(args.valid_data) == 1
    with open(args.valid_data[0], "rb") as reader:
        entire_data = reader.read().decode('utf-8')
    num_original_tokens = len(entire_data.strip().split(" "))
    entire_data = get_detokenizer(args.valid_data[0])(entire_data)
    tokenized_data = tokenizer.tokenize(entire_data)
    num_tokenized_tokens = len(tokenized_data)

    val_dataset = _LMDataset(tokenized_data, args.seq_length, tokenizer.eod,
                             num_original_tokens, num_tokenized_tokens,
                             args.overlapping_eval)
    print_rank_0(' > number of original tokens: {}, number of detokenized '
                 'tokens: {}'.format(num_original_tokens,
                                     num_tokenized_tokens))

    return val_dataset
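
# num_original_tokens / num_tokenized_tokens are stored on the dataset so the
# evaluation loop can rescale the measured loss to the original whitespace
# tokenization that WikiText-103 perplexity is conventionally reported on.
# A hedged sketch of that adjustment (the actual reporting code lives outside
# this file):
#
#     # avg_nll: mean cross-entropy per *tokenized* target token
#     adjusted_ppl = math.exp(avg_nll * num_tokenized_tokens
#                             / num_original_tokens)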