data.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ORQA dataset."""

import json
import random
from abc import ABC
from abc import abstractmethod

import numpy as np
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args
from megatron.data.biencoder_dataset_utils import make_attention_mask


def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
    ctx_id_list, ctx_types_list = [], []
    for context in ctx_list:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids

        ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(
            ctx_ids, max_seq_length, tokenizer.cls,
            tokenizer.sep, tokenizer.pad)
        ctx_id_list.append(ctx_ids)
        ctx_types_list.append(ctx_types)

    return ctx_id_list, ctx_types_list


def build_tokens_types_paddings_from_text(query, context,
                                          tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    query_ids = tokenizer.tokenize(query)
    query_ids, query_types, query_pad_mask = \
        build_tokens_types_paddings_from_ids(query_ids, max_seq_length,
            tokenizer.cls, tokenizer.sep, tokenizer.pad)

    # Append the title of the context at the front.
    extended_ctx_ids = None
    if context is not None:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids

    ctx_ids, ctx_types, ctx_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_ctx_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return query_ids, query_types, query_pad_mask, \
        ctx_ids, ctx_types, ctx_pad_mask


# Similar to the code in tasks/data_utils.py, with some changes.
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size so that the trailing [SEP] still fits.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask
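
# Worked example (illustrative only; 101, 102 and 0 are the conventional BERT
# ids for [CLS], [SEP] and [PAD] and are assumptions here, not values defined
# in this file):
#
#   enc_ids, tokentypes_enc, pad_mask = build_tokens_types_paddings_from_ids(
#       [7, 8, 9], max_seq_length=8, cls_id=101, sep_id=102, pad_id=0)
#   enc_ids        -> [101, 7, 8, 9, 102, 0, 0, 0]
#   tokentypes_enc -> [0, 0, 0, 0, 0, 0, 0, 0]
#   pad_mask       -> array([1, 1, 1, 1, 1, 0, 0, 0])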


def build_sample(query_ids, query_types, query_pad_mask,
                 ctx_ids, ctx_types, ctx_pad_mask, answers,
                 neg_ctx_id_list=None, neg_ctx_types_list=None,
                 include_neg=False):
    """Convert to numpy and return a sample consumed by the batch producer."""

    query_ids = np.array(query_ids, dtype=np.int64)
    query_types = np.array(query_types, dtype=np.int64)
    query_mask = make_attention_mask(query_ids, query_ids)

    ctx_ids = np.array(ctx_ids, dtype=np.int64)
    ctx_types = np.array(ctx_types, dtype=np.int64)
    ctx_mask = make_attention_mask(ctx_ids, ctx_ids)

    sample = {
        'query': query_ids,
        'query_mask': query_mask,
        'query_types': query_types,
        'query_pad_mask': query_pad_mask,
        'context': ctx_ids,
        'context_mask': ctx_mask,
        'context_types': ctx_types,
        'context_pad_mask': ctx_pad_mask,
        'reference': answers
    }

    if include_neg:
        neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
        neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
        neg_ctx_mask = np.array([make_attention_mask(ids, ids)
                                 for ids in neg_ctx_ids], dtype=np.int64)

        sample['neg_context'] = neg_ctx_ids
        sample['neg_context_types'] = neg_ctx_id_types
        sample['neg_context_mask'] = neg_ctx_mask

    return sample
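
# Shape sketch (illustrative; it assumes make_attention_mask from
# megatron.data.biencoder_dataset_utils produces a 2-D source-by-target mask):
# 'query' and 'context' are 1-D int64 arrays of length max_seq_length,
# 'query_mask' and 'context_mask' are [max_seq_length, max_seq_length] arrays,
# and the 'neg_context*' entries gain a leading dimension with one row per
# negative context.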


class OpenRetrievalAbstractDataset(ABC, Dataset):
    """Open Retrieval base dataset class."""

    def __init__(self, task_name, dataset_name, datapaths, tokenizer,
                 max_seq_length, evaluate=False):
        # Store inputs.
        args = get_args()
        self.evaluate = evaluate
        self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
        self.val_av_rank_other_neg = args.val_av_rank_other_neg
        self.train_with_neg = args.train_with_neg
        self.train_hard_neg = args.train_hard_neg

        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))

        # Process the files.
        string = ' > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)

        self.samples = []
        for datapath in datapaths:
            self.samples.extend(self.process_samples_from_single_path(datapath))

        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
            ctx_pad_mask = build_tokens_types_paddings_from_text(
                raw_sample['question'], raw_sample['pos_context'],
                self.tokenizer, self.max_seq_length)

        if self.evaluate:
            neg_ctx_list = \
                raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
                raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)

        elif self.train_with_neg:
            hard_negative_ctx = raw_sample['hard_negative_context']
            negative_ctx = raw_sample['negative_context']
            if True:  # TODO: fix this or remove this condition
                random.shuffle(hard_negative_ctx)
                random.shuffle(negative_ctx)

            neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
            # In the Google NQ dataset prepared by the DPR paper, roughly 50
            # training examples are missing hard negatives. In those cases,
            # substitute hard negatives with simple negatives.
            if len(neg_ctx_list) < self.train_hard_neg:
                neg_ctx_list += negative_ctx[:self.train_hard_neg -
                                             len(neg_ctx_list)]

            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)
        else:
            neg_ctx_id_list = None
            neg_ctx_types_list = None

        sample = build_sample(query_ids, query_types, query_pad_mask,
                              ctx_ids, ctx_types, ctx_pad_mask,
                              raw_sample['answers'],
                              neg_ctx_id_list, neg_ctx_types_list,
                              include_neg=self.evaluate or self.train_with_neg)

        return sample

    @staticmethod
    @abstractmethod
    def process_samples_from_single_path(filename):
        """Abstract method that takes a filename and returns a list of
        dataset samples. Each sample is a dict with the keys 'question',
        'pos_context', 'hard_negative_context', 'negative_context' and
        'answers'.
        """
        pass


def normalize_question(question):
    # Strip a trailing question mark, if any.
    if question.endswith('?'):
        question = question[:-1]
    return question
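
# Example (illustrative): normalize_question('who wrote hamlet?') returns
# 'who wrote hamlet'; a question without a trailing '?' is returned unchanged.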


# The following class reads the datasets for training the retriever as
# prepared by the DPR codebase (https://github.com/facebookresearch/DPR).
class NQSupervisedDataset(OpenRetrievalAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 evaluate=False):
        super().__init__('natural_questions_ret',
                         name,
                         datapaths,
                         tokenizer,
                         max_seq_length,
                         evaluate=evaluate)

    @staticmethod
    def process_samples_from_single_path(filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r', encoding="utf-8") as f:
            data = json.load(f)
            for row in data:
                question = normalize_question(row['question'])
                pos_context = row['positive_ctxs'][0]

                # Hard negative contexts.
                if len(row['hard_negative_ctxs']) > 0:
                    hard_neg_context = row['hard_negative_ctxs']
                else:
                    hard_neg_context = []

                # Negative contexts.
                if len(row['negative_ctxs']) > 0:
                    neg_context = row['negative_ctxs']
                else:
                    neg_context = []

                answers = row['answers']
                sample = {'question': question,
                          'pos_context': pos_context,
                          'hard_negative_context': hard_neg_context,
                          'negative_context': neg_context,
                          'answers': answers}
                total += 1
                samples.append(sample)

                if total % 5000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
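
# Sketch of the input format this dataset expects (illustrative; the field
# values below are made up, but the keys are the ones read by
# process_samples_from_single_path and build_tokens_types_paddings_from_text).
# Each DPR-style retriever file is a JSON list of rows such as:
#
#   [
#     {
#       "question": "who wrote hamlet?",
#       "answers": ["William Shakespeare"],
#       "positive_ctxs": [{"title": "Hamlet", "text": "..."}],
#       "hard_negative_ctxs": [{"title": "Macbeth", "text": "..."}],
#       "negative_ctxs": [{"title": "Othello", "text": "..."}]
#     }
#   ]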