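"""Ensemble classification predictions from multiple checkpoint directories.

Each checkpoint directory is expected to contain a torch-saved file of
(dataset_name, (predictions, labels, uids)) tuples. Prediction scores are
summed across checkpoints, averaged, optionally thresholded, and written
out as one TSV file per dataset.
"""
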
import os
import argparse
import collections

import numpy as np
import torch


def process_files(args):
    """Accumulate predictions, labels, and uids from each checkpoint directory."""
    all_predictions = collections.OrderedDict()
    all_labels = collections.OrderedDict()
    all_uid = collections.OrderedDict()
    for path in args.paths:
        path = os.path.join(path, args.prediction_name)
        try:
            data = torch.load(path)
            for dataset in data:
                name, d = dataset
                predictions, labels, uid = d
                if name not in all_predictions:
                    all_predictions[name] = np.array(predictions)
                    if args.labels is None:
                        args.labels = list(range(all_predictions[name].shape[1]))
                    if args.eval:
                        all_labels[name] = np.array(labels)
                    all_uid[name] = np.array(uid)
                else:
                    # Sum raw prediction scores across checkpoints; they are
                    # averaged later in postprocess_predictions.
                    all_predictions[name] += np.array(predictions)
                    # Every checkpoint must score the same examples in the same order.
                    assert np.allclose(all_uid[name], np.array(uid))
        except Exception as e:
            print(e)
            continue
    return all_predictions, all_labels, all_uid


def get_threshold(all_predictions, all_labels, one_threshold=False):
    """Find the best classification threshold per dataset (or one overall)."""
    if one_threshold:
        all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_labels.values()))}
    out_thresh = []
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        labels = all_labels[dataset]
        out_thresh.append(calc_threshold(preds, labels))
    return out_thresh


def calc_threshold(p, l):
    """Grid-search thresholds in [0, 0.99] and return the most accurate one."""
    trials = [i / 100. for i in range(100)]
    best_acc = float('-inf')
    best_thresh = 0
    for t in trials:
        acc = (apply_threshold(p, t).argmax(-1) == l).astype(float).mean()
        if acc > best_acc:
            best_acc = acc
            best_thresh = t
    return best_thresh


def apply_threshold(preds, t):
    """Convert binary softmax scores to one-hot decisions at threshold t."""
    # Scores must already be normalized probabilities.
    assert np.allclose(preds.sum(-1), np.ones(preds.shape[0]))
    prob = preds[:, -1]
    thresholded = (prob >= t).astype(int)
    preds = np.zeros_like(preds)
    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return preds
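
# Illustration (not part of the original source): with binary softmax outputs
#   preds = np.array([[0.7, 0.3], [0.2, 0.8]])
# apply_threshold(preds, 0.5) looks at the last column ([0.3, 0.8]) and marks
# the positive class only where it reaches the threshold, returning
#   [[1., 0.], [0., 1.]]
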
def threshold_predictions(all_predictions, threshold):
    if len(threshold) != len(all_predictions):
        # Pad the threshold list with its last value so every dataset gets one.
        threshold = threshold + [threshold[-1]] * (len(all_predictions) - len(threshold))
    for i, dataset in enumerate(all_predictions):
        thresh = threshold[i]
        preds = all_predictions[dataset]
        all_predictions[dataset] = apply_threshold(preds, thresh)
    return all_predictions


def postprocess_predictions(all_predictions, all_labels, args):
    # Average the summed scores over the number of ensembled checkpoints.
    for d in all_predictions:
        all_predictions[d] = all_predictions[d] / len(args.paths)

    if args.calc_threshold:
        args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
        print('threshold', args.threshold)

    if args.threshold is not None:
        all_predictions = threshold_predictions(all_predictions, args.threshold)

    return all_predictions, all_labels


def write_predictions(all_predictions, all_labels, all_uid, args):
    all_correct = 0
    count = 0
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        preds = np.argmax(preds, -1)
        if args.eval:
            correct = (preds == all_labels[dataset]).sum()
            num = len(all_labels[dataset])
            accuracy = correct / num
            count += num
            all_correct += correct
            print(dataset, 'accuracy:', accuracy)
        os.makedirs(os.path.join(args.outdir, dataset), exist_ok=True)
        outpath = os.path.join(
            args.outdir, dataset,
            os.path.splitext(args.prediction_name)[0] + '.tsv')
        with open(outpath, 'w') as f:
            f.write('id\tlabel\n')
            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
                              for uid, p in zip(all_uid[dataset], preds.tolist())))
    if args.eval:
        print('overall accuracy:', all_correct / count)


def ensemble_predictions(args):
    all_predictions, all_labels, all_uid = process_files(args)
    all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
    write_predictions(all_predictions, all_labels, all_uid, args)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir', required=True,
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of predictions file in each checkpoint directory')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate the best classification threshold from the labels')
    parser.add_argument('--one-threshold', action='store_true',
                        help='use one threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user-supplied threshold(s) for classification')
    parser.add_argument('--labels', nargs='+', default=None,
                        help='whitespace-separated list of label names')
    args = parser.parse_args()
    ensemble_predictions(args)


if __name__ == '__main__':
    main()
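
# Example invocation (script and directory names below are illustrative
# placeholders, not from the original source):
#   python ensemble_classifier.py \
#       --paths checkpoints/run1 checkpoints/run2 checkpoints/run3 \
#       --outdir ensembled \
#       --prediction-name dev_predictions.pt \
#       --eval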