# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Data Loader for Google NQ dataset
"""

from abc import ABC
import csv
from collections import OrderedDict

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask


def get_nq_dataset(qa_data, split):
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = NQDataset('Google NQ {} Split'.format(split),
                        'Google Natural Questions',
                        qa_data,
                        tokenizer,
                        args.retriever_seq_length)
    return dataset


def process_nq_batch(batch):
    query_tokens = batch['token_ids'].long().cuda()
    query_mask = (batch['token_mask'] < 0.5).cuda()
    query_types = batch['token_types'].long().cuda()
    query_len = batch['seq_len'].long().cuda()
    reference = batch['reference']

    return query_tokens, query_mask, query_types, query_len, reference


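# Note added for clarity (not part of the original file): assuming
# make_attention_mask marks valid positions with 1, the `< 0.5` comparison above
# turns 'token_mask' into a boolean tensor that is True at padded positions.
# All tensors are moved to the current CUDA device, while 'reference' (the raw
# answer lists) stays on the host as a plain Python list.

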
class CustomDataLoader(DataLoader):
    def __init__(self, dataset, eval=False, **kwargs):
        if kwargs.get('collate_fn', None) is None:
            kwargs['collate_fn'] = self._collate_fn
        self.eval = eval
        super().__init__(dataset, **kwargs)

    def _collate_fn(self, batch_data):
        # Generate a batch by gathering each field across the samples.
        batch_size = len(batch_data)
        tensorized = OrderedDict()
        for d in batch_data:
            for k, v in d.items():
                tensorized.setdefault(k, []).append(v)
        # The five keys produced by build_sample().
        assert len(tensorized) == 5

        tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids'])
        tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask'])
        tensorized['token_types'] = torch.LongTensor(tensorized['token_types'])
        tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len'])
        return tensorized


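# Illustrative note (not part of the original file): _collate_fn stacks the
# per-sample fields produced by build_sample() further below, except 'reference',
# which is left as a plain Python list of answer lists. Assuming
# make_attention_mask returns a [seq_len, seq_len] array, a collated batch of
# size B looks like:
#
#     batch['token_ids']    # LongTensor of shape [B, max_seq_length]
#     batch['token_types']  # LongTensor of shape [B, max_seq_length]
#     batch['token_mask']   # LongTensor of shape [B, max_seq_length, max_seq_length]
#     batch['seq_len']      # LongTensor of shape [B]
#     batch['reference']    # list of length B

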
def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
    """Data loader. Note that batch-size is the local (per GPU) batch-size.
    NOTE: This dataloader is not distributed !!!
    """

    args = get_args()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # Importantly, drop_last must be False to get all the data.
    batch_sampler = BatchSampler(sampler,
                                 batch_size=micro_batch_size,
                                 drop_last=False)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = CustomDataLoader(dataset,
                                   batch_sampler=batch_sampler,
                                   num_workers=num_workers,
                                   pin_memory=True)
    return data_loader


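# Illustrative end-to-end sketch (not part of the original file) showing how the
# pieces above are typically wired together. It assumes Megatron's global args
# and tokenizer have already been initialized, and '/path/to/nq_dev.tsv' is a
# hypothetical tab-separated question/answers file.
def _example_iterate_nq_dev():
    dataset = get_nq_dataset('/path/to/nq_dev.tsv', 'dev')
    dataloader = get_one_epoch_nq_dataloader(dataset)
    for batch in dataloader:
        query_tokens, query_mask, query_types, query_len, reference = \
            process_nq_batch(batch)
        # ... feed the query tensors to a retriever / query encoder here ...

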
def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    src_text_ids = tokenizer.tokenize(src_text)

    return build_tokens_types_paddings_from_ids(src_text_ids,
                                                max_seq_length,
                                                tokenizer.cls,
                                                tokenizer.sep,
                                                tokenizer.pad)


def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id,
                                         sep_id, pad_id):
    """
    Build token types and paddings, trim if needed, and pad if needed.
    TODO: Design modular interface to reuse this function. This is getting
    repeated multiple times in different tasks.
    """

    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(src_ids)
    enc_ids.extend(src_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)
    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    return enc_ids, tokentypes_enc, num_tokens_enc


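# Worked example (added for illustration, not in the original file), using
# BERT-style special token ids (cls=101, sep=102, pad=0) as an assumption.
def _example_token_layout():
    enc_ids, tokentypes_enc, num_tokens_enc = \
        build_tokens_types_paddings_from_ids([2054, 2003, 23564],
                                             max_seq_length=8,
                                             cls_id=101, sep_id=102, pad_id=0)
    # [CLS] question tokens [SEP], then padded out to max_seq_length.
    assert enc_ids == [101, 2054, 2003, 23564, 102, 0, 0, 0]
    # Token types are all zeros; the padding region reuses pad_id (0 here).
    assert tokentypes_enc == [0] * 8
    # num_tokens_enc counts only the non-padding tokens, including [CLS]/[SEP].
    assert num_tokens_enc == 5

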
def build_sample(token_ids, token_types, num_tokens, reference):
    """
    Convert to numpy and return a sample consumed by the
    batch producer.
    """

    token_ids = np.array(token_ids, dtype=np.int64)
    token_types = np.array(token_types, dtype=np.int64)
    token_mask = make_attention_mask(token_ids, token_ids)

    sample = ({
        'token_ids': token_ids,
        'token_mask': token_mask,
        'token_types': token_types,
        'seq_len': num_tokens,
        'reference': reference
    })
    return sample


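# Illustrative note (not part of the original file): each sample is a plain dict.
# 'token_ids' and 'token_types' are int64 numpy arrays of length max_seq_length,
# 'token_mask' is the 2-D attention mask returned by make_attention_mask,
# 'seq_len' is the unpadded length as a Python int, and 'reference' carries the
# raw answers through untouched for downstream evaluation.

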
class NQDataset(ABC, Dataset):
    """
    Open Retrieval Question Answering evaluation using Google NQ dataset.
    """

    def __init__(self, task_name, dataset_name, datapath,
                 tokenizer, max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        print_rank_0(datapath)
        self.samples = self.process_samples_from_single_path(datapath)
        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        ques_tokens, tokentypes_enc, num_tokens_ques = \
            build_tokens_types_paddings_from_text(raw_sample['question'],
                                                  self.tokenizer,
                                                  self.max_seq_length)

        sample = build_sample(ques_tokens,
                              tokentypes_enc,
                              num_tokens_ques,
                              raw_sample['answers'])
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r') as ifile:
            reader = csv.reader(ifile, delimiter='\t')
            for row in reader:
                question = row[0]
                # The answers column holds a Python list literal.
                answers = eval(row[1])

                sample = {'question': question, 'answers': answers}
                total += 1
                samples.append(sample)

                if total % 1000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
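

# Illustrative note (not part of the original file): the loader expects a
# tab-separated file with one question per row and the answers column written as
# a Python list literal (hence the eval() above), e.g. a hypothetical row:
#
#     what is the capital of france\t["Paris"]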