edit_curve.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #!/usr/bin/env python
  2. import csv
  3. import glob
  4. import numpy as np
  5. history_files = glob.glob("cifar100_*.csv")
  6. print(history_files)
  7. data = []
  8. loss = []
  9. for filename in history_files:
  10. print(filename)
  11. with open(filename) as f:
  12. reader = csv.reader(f)
  13. datathis = [row for row in reader]
  14. if len(data) == 0:
  15. for i, el in enumerate(datathis):
  16. if i == 0:
  17. continue
  18. data.append([float(el[2])])
  19. loss.append([float(el[1])])
  20. else:
  21. for i, el in enumerate(datathis):
  22. if i == 0:
  23. continue
  24. if i == len(data):
  25. break
  26. data[i - 1].append(float(el[2]))
  27. loss[i - 1].append(float(el[1]))
  28. # crop to where all are trained
  29. print(len(data))
  30. orderings = []
  31. for i, el in enumerate(data):
  32. if len(el) != len(history_files):
  33. break
  34. data = data[:i]
  35. loss = loss[:i]
  36. print(len(data))
  37. # orderings
  38. def get_changes(ord1, ord2):
  39. """Count how often the order changes between two elements."""
  40. changes = 0
  41. for i in range(10):
  42. for j in range(i + 1, 10):
  43. o1go2 = (ord1.index(i) > ord1.index(j) and
  44. ord2.index(i) > ord2.index(j))
  45. o1lo2 = (ord1.index(i) < ord1.index(j) and
  46. ord2.index(i) < ord2.index(j))
  47. if not (o1go2 or o1lo2):
  48. changes += 1
  49. return changes
  50. if len(history_files) > 1:
  51. orderings = []
  52. for row in data:
  53. ordering = zip(range(10), row)
  54. ordering = sorted(ordering, key=lambda n: n[1])
  55. ordering = [el[0] for el in ordering]
  56. orderings.append(ordering)
  57. get_changes(orderings[0], orderings[1])
  58. change_list = []
  59. for ord1, ord2 in zip(orderings, orderings[1:]):
  60. changes = get_changes(ord1, ord2)
  61. change_list.append(changes)
  62. change_list = np.array(change_list)
  63. print("change_list = {}".format(change_list.mean()))
  64. # write
  65. max_range = 0
  66. with open('baseline_cifar_test_acc.csv', 'w') as fp:
  67. writer = csv.writer(fp, delimiter=',')
  68. writer.writerow(["epoch", "min_acc", "max_acc", "mean_acc", "mean_loss"])
  69. for epoch, row in enumerate(data):
  70. if len(row) < len(history_files):
  71. break
  72. max_range = max(max_range, max(row) - min(row))
  73. writer.writerow([epoch, min(row), max(row), np.array(row).mean(),
  74. np.array(loss[epoch]).mean()])
  75. print("max_range={}".format(max_range))