qqp.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """QQP dataset."""
  16. from megatron import print_rank_0
  17. from tasks.data_utils import clean_text
  18. from .data import GLUEAbstractDataset
  19. LABELS = [0, 1]
  20. class QQPDataset(GLUEAbstractDataset):
  21. def __init__(self, name, datapaths, tokenizer, max_seq_length,
  22. test_label=0):
  23. self.test_label = test_label
  24. super().__init__('QQP', name, datapaths,
  25. tokenizer, max_seq_length)
  26. def process_samples_from_single_path(self, filename):
  27. """"Implement abstract method."""
  28. print_rank_0(' > Processing {} ...'.format(filename))
  29. samples = []
  30. total = 0
  31. first = True
  32. is_test = False
  33. with open(filename, 'r') as f:
  34. for line in f:
  35. row = line.strip().split('\t')
  36. if first:
  37. first = False
  38. if len(row) == 3:
  39. is_test = True
  40. print_rank_0(' reading {}, {}, and {} columns and '
  41. 'setting labels to {}'.format(
  42. row[0].strip(), row[1].strip(),
  43. row[2].strip(), self.test_label))
  44. else:
  45. assert len(row) == 6
  46. print_rank_0(' reading {}, {}, {}, and {} columns'
  47. ' ...'.format(
  48. row[0].strip(), row[3].strip(),
  49. row[4].strip(), row[5].strip()))
  50. continue
  51. if is_test:
  52. assert len(row) == 3, 'expected length 3: {}'.format(row)
  53. uid = int(row[0].strip())
  54. text_a = clean_text(row[1].strip())
  55. text_b = clean_text(row[2].strip())
  56. label = self.test_label
  57. assert len(text_a) > 0
  58. assert len(text_b) > 0
  59. else:
  60. if len(row) == 6:
  61. uid = int(row[0].strip())
  62. text_a = clean_text(row[3].strip())
  63. text_b = clean_text(row[4].strip())
  64. label = int(row[5].strip())
  65. else:
  66. print_rank_0('***WARNING*** index error, '
  67. 'skipping: {}'.format(row))
  68. continue
  69. if len(text_a) == 0:
  70. print_rank_0('***WARNING*** zero length a, '
  71. 'skipping: {}'.format(row))
  72. continue
  73. if len(text_b) == 0:
  74. print_rank_0('***WARNING*** zero length b, '
  75. 'skipping: {}'.format(row))
  76. continue
  77. assert label in LABELS
  78. assert uid >= 0
  79. sample = {'uid': uid,
  80. 'text_a': text_a,
  81. 'text_b': text_b,
  82. 'label': label}
  83. total += 1
  84. samples.append(sample)
  85. if total % 50000 == 0:
  86. print_rank_0(' > processed {} so far ...'.format(total))
  87. print_rank_0(' >> processed {} samples.'.format(len(samples)))
  88. return samples