# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ORQA dataset."""
import json
import random
from abc import ABC
from abc import abstractmethod

import numpy as np
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args
from megatron.data.biencoder_dataset_utils import make_attention_mask

def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
    """Tokenize a list of contexts as 'title' + [SEP] + 'text' and return
    the per-context token-id and token-type lists, each trimmed and padded
    to max_seq_length."""
    ctx_id_list, ctx_types_list = [], []
    for context in ctx_list:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        # Prepend the title, separated from the passage text by [SEP].
        ctx_ids = title_ids + [tokenizer.sep] + ctx_ids

        ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(
            ctx_ids, max_seq_length, tokenizer.cls, tokenizer.sep,
            tokenizer.pad)
        ctx_id_list.append(ctx_ids)
        ctx_types_list.append(ctx_types)

    return ctx_id_list, ctx_types_list

def build_tokens_types_paddings_from_text(query, context,
                                          tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""
    query_ids = tokenizer.tokenize(query)
    query_ids, query_types, query_pad_mask = \
        build_tokens_types_paddings_from_ids(query_ids, max_seq_length,
            tokenizer.cls, tokenizer.sep, tokenizer.pad)

    # Prepend the context's title to its text, separated by [SEP]. If no
    # context is given, the context fields are returned as None.
    ctx_ids, ctx_types, ctx_pad_mask = None, None, None
    if context is not None:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids
        ctx_ids, ctx_types, ctx_pad_mask = \
            build_tokens_types_paddings_from_ids(extended_ctx_ids,
                max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return query_ids, query_types, query_pad_mask, \
        ctx_ids, ctx_types, ctx_pad_mask

# Similar to the code in tasks/data_utils.py, with some changes.
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size, leaving room for the trailing [SEP].
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Padding. Token types at padding positions are 0, matching the segment
    # id used for the rest of this single-segment encoding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([0] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask

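# Worked example (illustrative only, assuming BERT-style special ids
# cls_id=101, sep_id=102, pad_id=0): with max_seq_length=8 and
# text_ids=[t1, t2, t3], the function above returns
#   enc_ids        = [101, t1, t2, t3, 102, 0, 0, 0]
#   tokentypes_enc = [0,   0,  0,  0,  0,   0, 0, 0]
#   pad_mask       = [1,   1,  1,  1,  1,   0, 0, 0]
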
def build_sample(query_ids, query_types, query_pad_mask,
                 ctx_ids, ctx_types, ctx_pad_mask, answers,
                 neg_ctx_id_list=None, neg_ctx_types_list=None,
                 include_neg=False):
    """Convert to numpy and return a sample consumed by the batch producer."""
    query_ids = np.array(query_ids, dtype=np.int64)
    query_types = np.array(query_types, dtype=np.int64)
    query_mask = make_attention_mask(query_ids, query_ids)

    ctx_ids = np.array(ctx_ids, dtype=np.int64)
    ctx_types = np.array(ctx_types, dtype=np.int64)
    ctx_mask = make_attention_mask(ctx_ids, ctx_ids)

    sample = ({
        'query': query_ids,
        'query_mask': query_mask,
        'query_types': query_types,
        'query_pad_mask': query_pad_mask,
        'context': ctx_ids,
        'context_mask': ctx_mask,
        'context_types': ctx_types,
        'context_pad_mask': ctx_pad_mask,
        'reference': answers
    })

    if include_neg:
        neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
        neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
        neg_ctx_mask = np.array([make_attention_mask(ids, ids)
                                 for ids in neg_ctx_ids], dtype=np.int64)

        sample['neg_context'] = neg_ctx_ids
        sample['neg_context_types'] = neg_ctx_id_types
        sample['neg_context_mask'] = neg_ctx_mask

    return sample

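# Note on the sample layout: 'query'/'context' hold the padded token ids,
# '*_types' the token-type ids, and '*_pad_mask' the 1-D padding indicators
# built above, while '*_mask' come from make_attention_mask. When include_neg
# is set, the 'neg_context*' arrays stack one slice per negative passage
# along a new leading dimension.
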
class OpenRetrievalAbstractDataset(ABC, Dataset):
    """Open Retrieval base dataset class."""

    def __init__(self, task_name, dataset_name, datapaths, tokenizer,
                 max_seq_length, evaluate=False):
        # Store inputs.
        args = get_args()
        self.evaluate = evaluate
        self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
        self.val_av_rank_other_neg = args.val_av_rank_other_neg
        self.train_with_neg = args.train_with_neg
        self.train_hard_neg = args.train_hard_neg

        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        string = '  > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)

        self.samples = []
        for datapath in datapaths:
            self.samples.extend(self.process_samples_from_single_path(datapath))

        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
            ctx_pad_mask = build_tokens_types_paddings_from_text(
                raw_sample['question'], raw_sample['pos_context'],
                self.tokenizer, self.max_seq_length)

        if self.evaluate:
            neg_ctx_list = \
                raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
                raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)

        elif self.train_with_neg:
            hard_negative_ctx = raw_sample['hard_negative_context']
            negative_ctx = raw_sample['negative_context']
            random.shuffle(hard_negative_ctx)
            random.shuffle(negative_ctx)

            neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
            # In the Google NQ dataset prepared by the DPR paper, roughly 50
            # training samples are missing hard negatives. In those cases,
            # substitute simple negatives for the missing hard negatives.
            if len(neg_ctx_list) < self.train_hard_neg:
                neg_ctx_list += negative_ctx[:self.train_hard_neg -
                                             len(neg_ctx_list)]

            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)
        else:
            neg_ctx_id_list = None
            neg_ctx_types_list = None

        sample = build_sample(query_ids, query_types, query_pad_mask,
                              ctx_ids, ctx_types, ctx_pad_mask,
                              raw_sample['answers'],
                              neg_ctx_id_list, neg_ctx_types_list,
                              include_neg=self.evaluate or self.train_with_neg)

        return sample

    @staticmethod
    @abstractmethod
    def process_samples_from_single_path(filename):
        """Abstract method that takes a filename and
        returns a list of dataset samples, each sample being a dict of
            {'question': ..., 'pos_context': ..., 'hard_negative_context': ...,
             'negative_context': ..., 'answers': ...}
        """
        pass


def normalize_question(question):
    """Strip a single trailing question mark, if present."""
    if question[-1] == '?':
        question = question[:-1]
    return question

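# For example, normalize_question("who wrote hamlet?") returns
# "who wrote hamlet"; a question without a trailing '?' is unchanged.
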
# The following class reads the datasets for training the retriever as
# prepared by the DPR codebase (https://github.com/facebookresearch/DPR).
class NQSupervisedDataset(OpenRetrievalAbstractDataset):
    """Natural Questions dataset in the DPR retriever-training format."""

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 evaluate=False):
        super().__init__('natural_questions_ret',
                         name,
                         datapaths,
                         tokenizer,
                         max_seq_length,
                         evaluate=evaluate)

    @staticmethod
    def process_samples_from_single_path(filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r', encoding="utf-8") as f:
            data = json.load(f)
            for row in data:
                question = normalize_question(row['question'])
                pos_context = row['positive_ctxs'][0]

                # Hard negative and simple negative contexts; either list
                # may be empty.
                hard_neg_context = row['hard_negative_ctxs']
                neg_context = row['negative_ctxs']

                answers = row['answers']
                sample = {'question': question,
                          'pos_context': pos_context,
                          'hard_negative_context': hard_neg_context,
                          'negative_context': neg_context,
                          'answers': answers}
                total += 1
                samples.append(sample)

                if total % 5000 == 0:
                    print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
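
# Illustrative sketch of one row in the DPR-format JSON files consumed by
# process_samples_from_single_path (the field values here are hypothetical;
# the keys match those read above):
#
# {
#     "question": "who wrote hamlet?",
#     "answers": ["William Shakespeare"],
#     "positive_ctxs": [{"title": "Hamlet", "text": "..."}],
#     "negative_ctxs": [{"title": "...", "text": "..."}],
#     "hard_negative_ctxs": [{"title": "...", "text": "..."}]
# }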