Browse Source

Lazily importing visdom

Maxime Chevalier-Boisvert 7 years ago
parent
commit
ab3db1ba0d
1 changed file with 47 additions and 135 deletions
  1. pytorch_rl/visualize.py (+47 -135)

+ 47 - 135
pytorch_rl/visualize.py

@@ -1,142 +1,54 @@
-# Copied from https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/visualize_atari.py
-# and https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/load.py
-# Thanks to the author and OpenAI team!
-
-import glob
-import json
-import os
-
-import matplotlib
-matplotlib.use('Agg')
-import matplotlib.pyplot as plt
 import numpy as np
-from scipy.signal import medfilt
-matplotlib.rcParams.update({'font.size': 8})
-
-
-def smooth_reward_curve(x, y):
-    # Halfwidth of our smoothing convolution
-    halfwidth = min(31, int(np.ceil(len(x) / 30)))
-    k = halfwidth
-    xsmoo = x[k:-k]
-    ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='valid') / \
-        np.convolve(np.ones_like(y), np.ones(2 * k + 1), mode='valid')
-    downsample = max(int(np.floor(len(xsmoo) / 1e3)), 1)
-    return xsmoo[::downsample], ysmoo[::downsample]
-
-
-def fix_point(x, y, interval):
-    np.insert(x, 0, 0)
-    np.insert(y, 0, 0)
-
-    fx, fy = [], []
-    pointer = 0
-
-    ninterval = int(max(x) / interval + 1)
-
-    for i in range(ninterval):
-        tmpx = interval * i
-
-        while pointer + 1 < len(x) and tmpx > x[pointer + 1]:
-            pointer += 1
-
-        if pointer + 1 < len(x):
-            alpha = (y[pointer + 1] - y[pointer]) / \
-                (x[pointer + 1] - x[pointer])
-            tmpy = y[pointer] + alpha * (tmpx - x[pointer])
-            fx.append(tmpx)
-            fy.append(tmpy)
-
-    return fx, fy
-
-
-def load_data(indir, smooth, bin_size):
-    datas = []
-    infiles = glob.glob(os.path.join(indir, '*.monitor.csv'))
-
-    for inf in infiles:
-        with open(inf, 'r') as f:
-            f.readline()
-            f.readline()
-            for line in f:
-                tmp = line.split(',')
-                t_time = float(tmp[2])
-                tmp = [t_time, int(tmp[1]), float(tmp[0])]
-                datas.append(tmp)
 
-    datas = sorted(datas, key=lambda d_entry: d_entry[0])
-    result = []
-    timesteps = 0
-    for i in range(len(datas)):
-        result.append([timesteps, datas[i][-1]])
-        timesteps += datas[i][1]
+vis = None
 
-    if len(result) < bin_size:
-        return [None, None]
+win = None
 
-    x, y = np.array(result)[:, 0], np.array(result)[:, 1]
+avg_reward = 0
 
-    if smooth == 1:
-        x, y = smooth_reward_curve(x, y)
+X = []
+Y = []
 
-    if smooth == 2:
-        y = medfilt(y, kernel_size=9)
-
-    x, y = fix_point(x, y, bin_size)
-    return [x, y]
-
-
-color_defaults = [
-    '#1f77b4',  # muted blue
-    '#ff7f0e',  # safety orange
-    '#2ca02c',  # cooked asparagus green
-    '#d62728',  # brick red
-    '#9467bd',  # muted purple
-    '#8c564b',  # chestnut brown
-    '#e377c2',  # raspberry yogurt pink
-    '#7f7f7f',  # middle gray
-    '#bcbd22',  # curry yellow-green
-    '#17becf'  # blue-teal
-]
-
-
-def visdom_plot(viz, win, folder, game, name, bin_size=100, smooth=1):
-    tx, ty = load_data(folder, smooth, bin_size)
-    if tx is None or ty is None:
-        return win
-
-    fig = plt.figure()
-    plt.plot(tx, ty, label="{}".format(name))
-
-    # Ugly hack to detect atari
-    if game.find('NoFrameskip') > -1:
-        plt.xticks([1e6, 2e6, 4e6, 6e6, 8e6, 10e6],
-                   ["1M", "2M", "4M", "6M", "8M", "10M"])
-        plt.xlim(0, 10e6)
-    else:
-        plt.xticks([1e5, 2e5, 4e5, 6e5, 8e5, 1e5],
-                   ["0.1M", "0.2M", "0.4M", "0.6M", "0.8M", "1M"])
-        plt.xlim(0, 1e6)
-
-    plt.xlabel('Number of Timesteps')
-    plt.ylabel('Rewards')
-
-
-    plt.title(game)
-    plt.legend(loc=4)
-    plt.show()
-    plt.draw()
-
-    image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
-    plt.close(fig)
-
-    # Show it in visdom
-    image = np.transpose(image, (2, 0, 1))
-    return viz.image(image, win=win)
-
-
-if __name__ == "__main__":
+def visdom_plot(
+    total_num_steps,
+    mean_reward
+):
+    # Lazily import visdom so that people don't need to install visdom
+    # if they're not actually using it
     from visdom import Visdom
-    viz = Visdom()
-    visdom_plot(viz, None, '/tmp/gym/', 'BreakOut', 'a2c', bin_size=100, smooth=1)
+
+    global vis
+    global win
+    global avg_reward
+
+    if vis is None:
+        vis = Visdom()
+        assert vis.check_connection()
+
+        # Close all existing plots
+        vis.close()
+
+    # Running average for curve smoothing
+    avg_reward = avg_reward * 0.9 + 0.1 * mean_reward
+
+    X.append(total_num_steps)
+    Y.append(avg_reward)
+
+    # The plot with the handle 'win' is updated each time this is called
+    win = vis.line(
+        X = np.array(X),
+        Y = np.array(Y),
+        opts = dict(
+            #title = 'All Environments',
+            xlabel='Total time steps',
+            ylabel='Reward per episode',
+            ytickmin=0,
+            #ytickmax=1,
+            #ytickstep=0.1,
+            #legend=legend,
+            #showlegend=True,
+            width=900,
+            height=500
+        ),
+        win = win
+    )