visualize.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copied from https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/visualize_atari.py
  2. # and https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/load.py
  3. # Thanks to the author and OpenAI team!
  4. import glob
  5. import json
  6. import os
  7. import matplotlib
  8. matplotlib.use('Agg')
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. from scipy.signal import medfilt
  12. matplotlib.rcParams.update({'font.size': 8})
  13. def smooth_reward_curve(x, y):
  14. # Halfwidth of our smoothing convolution
  15. halfwidth = min(31, int(np.ceil(len(x) / 30)))
  16. k = halfwidth
  17. xsmoo = x[k:-k]
  18. ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='valid') / \
  19. np.convolve(np.ones_like(y), np.ones(2 * k + 1), mode='valid')
  20. downsample = max(int(np.floor(len(xsmoo) / 1e3)), 1)
  21. return xsmoo[::downsample], ysmoo[::downsample]
  22. def fix_point(x, y, interval):
  23. np.insert(x, 0, 0)
  24. np.insert(y, 0, 0)
  25. fx, fy = [], []
  26. pointer = 0
  27. ninterval = int(max(x) / interval + 1)
  28. for i in range(ninterval):
  29. tmpx = interval * i
  30. while pointer + 1 < len(x) and tmpx > x[pointer + 1]:
  31. pointer += 1
  32. if pointer + 1 < len(x):
  33. alpha = (y[pointer + 1] - y[pointer]) / \
  34. (x[pointer + 1] - x[pointer])
  35. tmpy = y[pointer] + alpha * (tmpx - x[pointer])
  36. fx.append(tmpx)
  37. fy.append(tmpy)
  38. return fx, fy
  39. def load_data(indir, smooth, bin_size):
  40. datas = []
  41. infiles = glob.glob(os.path.join(indir, '*.monitor.csv'))
  42. for inf in infiles:
  43. with open(inf, 'r') as f:
  44. f.readline()
  45. f.readline()
  46. for line in f:
  47. tmp = line.split(',')
  48. t_time = float(tmp[2])
  49. tmp = [t_time, int(tmp[1]), float(tmp[0])]
  50. datas.append(tmp)
  51. datas = sorted(datas, key=lambda d_entry: d_entry[0])
  52. result = []
  53. timesteps = 0
  54. for i in range(len(datas)):
  55. result.append([timesteps, datas[i][-1]])
  56. timesteps += datas[i][1]
  57. if len(result) < bin_size:
  58. return [None, None]
  59. x, y = np.array(result)[:, 0], np.array(result)[:, 1]
  60. if smooth == 1:
  61. x, y = smooth_reward_curve(x, y)
  62. if smooth == 2:
  63. y = medfilt(y, kernel_size=9)
  64. x, y = fix_point(x, y, bin_size)
  65. return [x, y]
  66. color_defaults = [
  67. '#1f77b4', # muted blue
  68. '#ff7f0e', # safety orange
  69. '#2ca02c', # cooked asparagus green
  70. '#d62728', # brick red
  71. '#9467bd', # muted purple
  72. '#8c564b', # chestnut brown
  73. '#e377c2', # raspberry yogurt pink
  74. '#7f7f7f', # middle gray
  75. '#bcbd22', # curry yellow-green
  76. '#17becf' # blue-teal
  77. ]
  78. def visdom_plot(viz, win, folder, game, name, bin_size=100, smooth=1):
  79. tx, ty = load_data(folder, smooth, bin_size)
  80. if tx is None or ty is None:
  81. return win
  82. fig = plt.figure()
  83. plt.plot(tx, ty, label="{}".format(name))
  84. # Ugly hack to detect atari
  85. if game.find('NoFrameskip') > -1:
  86. plt.xticks([1e6, 2e6, 4e6, 6e6, 8e6, 10e6],
  87. ["1M", "2M", "4M", "6M", "8M", "10M"])
  88. plt.xlim(0, 10e6)
  89. else:
  90. plt.xticks([1e5, 2e5, 4e5, 6e5, 8e5, 1e5],
  91. ["0.1M", "0.2M", "0.4M", "0.6M", "0.8M", "1M"])
  92. plt.xlim(0, 1e6)
  93. plt.xlabel('Number of Timesteps')
  94. plt.ylabel('Rewards')
  95. plt.title(game)
  96. plt.legend(loc=4)
  97. plt.show()
  98. plt.draw()
  99. image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
  100. image = image.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
  101. plt.close(fig)
  102. # Show it in visdom
  103. image = np.transpose(image, (2, 0, 1))
  104. return viz.image(image, win=win)
  105. if __name__ == "__main__":
  106. from visdom import Visdom
  107. viz = Visdom()
  108. visdom_plot(viz, None, '/tmp/gym/', 'BreakOut', 'a2c', bin_size=100, smooth=1)