data.py

import glob
import json
import os
import time

from torch.utils.data import Dataset

from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_ids
from tasks.data_utils import clean_text


NUM_CHOICES = 4
MAX_QA_LENGTH = 128


class RaceDataset(Dataset):
    """RACE multiple-choice reading-comprehension dataset."""

    def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,
                 max_qa_length=MAX_QA_LENGTH):

        self.dataset_name = dataset_name
        print_rank_0(' > building RACE dataset for {}:'.format(
            self.dataset_name))

        string = ' > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)

        self.samples = []
        for datapath in datapaths:
            self.samples.extend(process_single_datapath(datapath, tokenizer,
                                                        max_qa_length,
                                                        max_seq_length))

        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

        # Each "sample" here holds NUM_CHOICES sub-samples that will be
        # collapsed into the batch dimension downstream.
        self.sample_multiplier = NUM_CHOICES

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
    """Read in RACE files, combine, clean up, tokenize, and convert to
    samples."""

    print_rank_0(' > working on {}'.format(datapath))
    start_time = time.time()

    # Get the list of files.
    filenames = glob.glob(os.path.join(datapath, '*.txt'))

    samples = []
    num_docs = 0
    num_questions = 0
    num_samples = 0
    # Load all the files.
    for filename in filenames:
        with open(filename, 'r') as f:
            for line in f:
                data = json.loads(line)
                num_docs += 1

                context = data["article"]
                questions = data["questions"]
                choices = data["options"]
                answers = data["answers"]
                # Check the lengths.
                assert len(questions) == len(answers)
                assert len(questions) == len(choices)

                # Context: clean up and convert to ids.
                context = clean_text(context)
                context_ids = tokenizer.tokenize(context)

                # Loop over questions.
                for qi, question in enumerate(questions):
                    num_questions += 1
                    # Label: answers are letters 'A'..'D'; map to 0..3.
                    label = ord(answers[qi]) - ord("A")
                    assert label >= 0
                    assert label < NUM_CHOICES
                    assert len(choices[qi]) == NUM_CHOICES

                    # For each question, build num-choices samples.
                    ids_list = []
                    types_list = []
                    paddings_list = []
                    for ci in range(NUM_CHOICES):
                        choice = choices[qi][ci]
                        # Merge the question with the choice: fill in the
                        # blank if present, otherwise append the choice.
                        if "_" in question:
                            qa = question.replace("_", choice)
                        else:
                            qa = " ".join([question, choice])
                        # Clean QA.
                        qa = clean_text(qa)
                        # Tokenize.
                        qa_ids = tokenizer.tokenize(qa)
                        # Trim if needed.
                        if len(qa_ids) > max_qa_length:
                            qa_ids = qa_ids[0:max_qa_length]
                        # Build the sample.
                        ids, types, paddings \
                            = build_tokens_types_paddings_from_ids(
                                qa_ids, context_ids, max_seq_length,
                                tokenizer.cls, tokenizer.sep, tokenizer.pad)
                        ids_list.append(ids)
                        types_list.append(types)
                        paddings_list.append(paddings)

                    # Convert to numpy and add to samples.
                    samples.append(build_sample(ids_list, types_list,
                                                paddings_list, label,
                                                num_samples))
                    num_samples += 1

    elapsed_time = time.time() - start_time
    print_rank_0(' > processed {} documents, {} questions, and {} samples'
                 ' in {:.2f} seconds'.format(num_docs, num_questions,
                                             num_samples, elapsed_time))
    return samples
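

# Usage sketch (illustrative, not part of the original module). It assumes
# Megatron has already been initialized so that get_tokenizer() returns a
# tokenizer exposing tokenize(), cls, sep, and pad, and that the listed
# directories contain RACE *.txt files in JSON-lines format; the paths below
# are placeholders.
#
#     from megatron import get_tokenizer
#     tokenizer = get_tokenizer()
#     train_ds = RaceDataset('training',
#                            ['data/RACE/train/middle', 'data/RACE/train/high'],
#                            tokenizer, max_seq_length=512)
#     print(len(train_ds), train_ds.sample_multiplier)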