ensemble_classifier.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import os
  2. import argparse
  3. import collections
  4. import numpy as np
  5. import torch
  6. def process_files(args):
  7. all_predictions = collections.OrderedDict()
  8. all_labels = collections.OrderedDict()
  9. all_uid = collections.OrderedDict()
  10. for path in args.paths:
  11. path = os.path.join(path, args.prediction_name)
  12. try:
  13. data = torch.load(path)
  14. for dataset in data:
  15. name, d = dataset
  16. predictions, labels, uid = d
  17. if name not in all_predictions:
  18. all_predictions[name] = np.array(predictions)
  19. if args.labels is None:
  20. args.labels = [i for i in range(all_predictions[name].shape[1])]
  21. if args.eval:
  22. all_labels[name] = np.array(labels)
  23. all_uid[name] = np.array(uid)
  24. else:
  25. all_predictions[name] += np.array(predictions)
  26. assert np.allclose(all_uid[name], np.array(uid))
  27. except Exception as e:
  28. print(e)
  29. continue
  30. return all_predictions, all_labels, all_uid
  31. def get_threshold(all_predictions, all_labels, one_threshold=False):
  32. if one_threshold:
  33. all_predictons = {'combined': np.concatenate(list(all_predictions.values()))}
  34. all_labels = {'combined': np.concatenate(list(all_predictions.labels()))}
  35. out_thresh = []
  36. for dataset in all_predictions:
  37. preds = all_predictions[dataset]
  38. labels = all_labels[dataset]
  39. out_thresh.append(calc_threshold(preds, labels))
  40. return out_thresh
  41. def calc_threshold(p, l):
  42. trials = [(i) * (1. / 100.) for i in range(100)]
  43. best_acc = float('-inf')
  44. best_thresh = 0
  45. for t in trials:
  46. acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
  47. if acc > best_acc:
  48. best_acc = acc
  49. best_thresh = t
  50. return best_thresh
  51. def apply_threshold(preds, t):
  52. assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
  53. prob = preds[:, -1]
  54. thresholded = (prob >= t).astype(int)
  55. preds = np.zeros_like(preds)
  56. preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
  57. return preds
  58. def threshold_predictions(all_predictions, threshold):
  59. if len(threshold) != len(all_predictions):
  60. threshold = [threshold[-1]] * (len(all_predictions) - len(threshold))
  61. for i, dataset in enumerate(all_predictions):
  62. thresh = threshold[i]
  63. preds = all_predictions[dataset]
  64. all_predictions[dataset] = apply_threshold(preds, thresh)
  65. return all_predictions
  66. def postprocess_predictions(all_predictions, all_labels, args):
  67. for d in all_predictions:
  68. all_predictions[d] = all_predictions[d] / len(args.paths)
  69. if args.calc_threshold:
  70. args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
  71. print('threshold', args.threshold)
  72. if args.threshold is not None:
  73. all_predictions = threshold_predictions(all_predictions, args.threshold)
  74. return all_predictions, all_labels
  75. def write_predictions(all_predictions, all_labels, all_uid, args):
  76. all_correct = 0
  77. count = 0
  78. for dataset in all_predictions:
  79. preds = all_predictions[dataset]
  80. preds = np.argmax(preds, -1)
  81. if args.eval:
  82. correct = (preds == all_labels[dataset]).sum()
  83. num = len(all_labels[dataset])
  84. accuracy = correct / num
  85. count += num
  86. all_correct += correct
  87. accuracy = (preds == all_labels[dataset]).mean()
  88. print(accuracy)
  89. if not os.path.exists(os.path.join(args.outdir, dataset)):
  90. os.makedirs(os.path.join(args.outdir, dataset))
  91. outpath = os.path.join(
  92. args.outdir, dataset, os.path.splitext(
  93. args.prediction_name)[0] + '.tsv')
  94. with open(outpath, 'w') as f:
  95. f.write('id\tlabel\n')
  96. f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
  97. for uid, p in zip(all_uid[dataset], preds.tolist())))
  98. if args.eval:
  99. print(all_correct / count)
  100. def ensemble_predictions(args):
  101. all_predictions, all_labels, all_uid = process_files(args)
  102. all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
  103. write_predictions(all_predictions, all_labels, all_uid, args)
  104. def main():
  105. parser = argparse.ArgumentParser()
  106. parser.add_argument('--paths', required=True, nargs='+',
  107. help='paths to checkpoint directories used in ensemble')
  108. parser.add_argument('--eval', action='store_true',
  109. help='compute accuracy metrics against labels (dev set)')
  110. parser.add_argument('--outdir',
  111. help='directory to place ensembled predictions in')
  112. parser.add_argument('--prediction-name', default='test_predictions.pt',
  113. help='name of predictions in checkpoint directories')
  114. parser.add_argument('--calc-threshold', action='store_true',
  115. help='calculate threshold classification')
  116. parser.add_argument('--one-threshold', action='store_true',
  117. help='use on threshold for all subdatasets')
  118. parser.add_argument('--threshold', nargs='+', default=None, type=float,
  119. help='user supplied threshold for classification')
  120. parser.add_argument('--labels', nargs='+', default=None,
  121. help='whitespace separated list of label names')
  122. args = parser.parse_args()
  123. ensemble_predictions(args)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()