mnli.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """MNLI dataset."""
  16. from megatron import print_rank_0
  17. from tasks.data_utils import clean_text
  18. from .data import GLUEAbstractDataset
  19. LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
  20. class MNLIDataset(GLUEAbstractDataset):
  21. def __init__(self, name, datapaths, tokenizer, max_seq_length,
  22. test_label='contradiction'):
  23. self.test_label = test_label
  24. super().__init__('MNLI', name, datapaths,
  25. tokenizer, max_seq_length)
  26. def process_samples_from_single_path(self, filename):
  27. """"Implement abstract method."""
  28. print_rank_0(' > Processing {} ...'.format(filename))
  29. samples = []
  30. total = 0
  31. first = True
  32. is_test = False
  33. with open(filename, 'r') as f:
  34. for line in f:
  35. row = line.strip().split('\t')
  36. if first:
  37. first = False
  38. if len(row) == 10:
  39. is_test = True
  40. print_rank_0(
  41. ' reading {}, {} and {} columns and setting '
  42. 'labels to {}'.format(
  43. row[0].strip(), row[8].strip(),
  44. row[9].strip(), self.test_label))
  45. else:
  46. print_rank_0(' reading {} , {}, {}, and {} columns '
  47. '...'.format(
  48. row[0].strip(), row[8].strip(),
  49. row[9].strip(), row[-1].strip()))
  50. continue
  51. text_a = clean_text(row[8].strip())
  52. text_b = clean_text(row[9].strip())
  53. unique_id = int(row[0].strip())
  54. label = row[-1].strip()
  55. if is_test:
  56. label = self.test_label
  57. assert len(text_a) > 0
  58. assert len(text_b) > 0
  59. assert label in LABELS
  60. assert unique_id >= 0
  61. sample = {'text_a': text_a,
  62. 'text_b': text_b,
  63. 'label': LABELS[label],
  64. 'uid': unique_id}
  65. total += 1
  66. samples.append(sample)
  67. if total % 50000 == 0:
  68. print_rank_0(' > processed {} so far ...'.format(total))
  69. print_rank_0(' >> processed {} samples.'.format(len(samples)))
  70. return samples