
Fixed visdom visualization code

Maxime Chevalier-Boisvert, 7 years ago
commit 0e496fe636
3 changed files with 70 additions and 165 deletions
  1. pytorch_rl/arguments.py (+2, -2)
  2. pytorch_rl/main.py (+18, -22)
  3. pytorch_rl/visualize.py (+50, -141)

+ 2 - 2
pytorch_rl/arguments.py

@@ -43,8 +43,8 @@ def get_args():
                         help='log interval, one log per n updates (default: 10)')
     parser.add_argument('--save-interval', type=int, default=100,
                         help='save interval, one save per n updates (default: 10)')
-    parser.add_argument('--vis-interval', type=int, default=100,
-                        help='vis interval, one log per n updates (default: 100)')
+    parser.add_argument('--vis-interval', type=int, default=10,
+                        help='vis interval, one log per n updates')
     parser.add_argument('--num-frames', type=int, default=10e6,
                         help='number of frames to train (default: 10e6)')
     parser.add_argument('--env-name', default='PongNoFrameskip-v4',
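
Note that the untouched --save-interval line above still says "(default: 10)" in its help string even though its default is 100, the same kind of drift this hunk fixes for --vis-interval. One way to rule that out, sketched below and not part of this commit, is to let argparse generate the default annotations itself via ArgumentDefaultsHelpFormatter:

    import argparse

    # Sketch (not in this commit): argparse appends "(default: ...)" to each
    # help string automatically, so it can never fall out of sync with the
    # actual default value.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--vis-interval', type=int, default=10,
                        help='vis interval, one log per n updates')
    parser.add_argument('--save-interval', type=int, default=100,
                        help='save interval, one save per n updates')
    args = parser.parse_args()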

+ 18 - 22
pytorch_rl/main.py

@@ -42,17 +42,8 @@ except OSError:
         os.remove(f)
 
 def main():
-    print("#######")
-    print("WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
-    print("#######")
-
     os.environ['OMP_NUM_THREADS'] = '1'
 
-    if args.vis:
-        from visdom import Visdom
-        viz = Visdom()
-        win = None
-
     envs = [make_env(args.env_name, args.seed, i, args.log_dir) for i in range(args.num_processes)]
 
     if args.num_processes > 1:
@@ -252,20 +243,25 @@ def main():
         if j % args.log_interval == 0:
             end = time.time()
             total_num_steps = (j + 1) * args.num_processes * args.num_steps
-            print("Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
-                format(j, total_num_steps,
-                       int(total_num_steps / (end - start)),
-                       final_rewards.mean(),
-                       final_rewards.median(),
-                       final_rewards.min(),
-                       final_rewards.max(), dist_entropy.data[0],
-                       value_loss.data[0], action_loss.data[0]))
+            print(
+                "Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
+                format(
+                    j,
+                    total_num_steps,
+                    int(total_num_steps / (end - start)),
+                    final_rewards.mean(),
+                    final_rewards.median(),
+                    final_rewards.min(),
+                    final_rewards.max(), dist_entropy.data[0],
+                    value_loss.data[0], action_loss.data[0]
+                )
+            )
+
         if args.vis and j % args.vis_interval == 0:
-            try:
-                # Sometimes monitor doesn't properly flush the outputs
-                win = visdom_plot(viz, win, args.log_dir, args.env_name, args.algo)
-            except IOError:
-                pass
+            # Recompute so the x-axis uses this update, not the last log step
+            total_num_steps = (j + 1) * args.num_processes * args.num_steps
+            visdom_plot(total_num_steps,
+                        final_rewards.mean())
 
 if __name__ == "__main__":
     main()
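
Since visdom_plot now takes only the step count and the mean reward, and manages the Visdom connection and window itself, the old try/except IOError guard around monitor-file reads is no longer needed: nothing here touches the filesystem. A minimal smoke test for the new interface, assuming it is run from the pytorch_rl directory with a visdom server already listening (python -m visdom.server):

    import numpy as np

    # Assumes pytorch_rl/ is the working directory and a visdom server is up.
    from visualize import visdom_plot

    # Feed the plot a synthetic, slowly improving mean episode reward.
    for step in range(0, 100000, 1000):
        mean_reward = step / 1000.0 + float(np.random.randn())
        visdom_plot(step, mean_reward)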

+ 50 - 141
pytorch_rl/visualize.py

@@ -1,142 +1,51 @@
-# Copied from https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/visualize_atari.py
-# and https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/load.py
-# Thanks to the author and OpenAI team!
-
-import glob
-import json
-import os
-
-import matplotlib
-matplotlib.use('Agg')
-import matplotlib.pyplot as plt
 import numpy as np
-from scipy.signal import medfilt
-matplotlib.rcParams.update({'font.size': 8})
-
-
-def smooth_reward_curve(x, y):
-    # Halfwidth of our smoothing convolution
-    halfwidth = min(31, int(np.ceil(len(x) / 30)))
-    k = halfwidth
-    xsmoo = x[k:-k]
-    ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='valid') / \
-        np.convolve(np.ones_like(y), np.ones(2 * k + 1), mode='valid')
-    downsample = max(int(np.floor(len(xsmoo) / 1e3)), 1)
-    return xsmoo[::downsample], ysmoo[::downsample]
-
-
-def fix_point(x, y, interval):
-    np.insert(x, 0, 0)
-    np.insert(y, 0, 0)
-
-    fx, fy = [], []
-    pointer = 0
-
-    ninterval = int(max(x) / interval + 1)
-
-    for i in range(ninterval):
-        tmpx = interval * i
-
-        while pointer + 1 < len(x) and tmpx > x[pointer + 1]:
-            pointer += 1
-
-        if pointer + 1 < len(x):
-            alpha = (y[pointer + 1] - y[pointer]) / \
-                (x[pointer + 1] - x[pointer])
-            tmpy = y[pointer] + alpha * (tmpx - x[pointer])
-            fx.append(tmpx)
-            fy.append(tmpy)
-
-    return fx, fy
-
-
-def load_data(indir, smooth, bin_size):
-    datas = []
-    infiles = glob.glob(os.path.join(indir, '*.monitor.csv'))
-
-    for inf in infiles:
-        with open(inf, 'r') as f:
-            f.readline()
-            f.readline()
-            for line in f:
-                tmp = line.split(',')
-                t_time = float(tmp[2])
-                tmp = [t_time, int(tmp[1]), float(tmp[0])]
-                datas.append(tmp)
-
-    datas = sorted(datas, key=lambda d_entry: d_entry[0])
-    result = []
-    timesteps = 0
-    for i in range(len(datas)):
-        result.append([timesteps, datas[i][-1]])
-        timesteps += datas[i][1]
-
-    if len(result) < bin_size:
-        return [None, None]
-
-    x, y = np.array(result)[:, 0], np.array(result)[:, 1]
-
-    if smooth == 1:
-        x, y = smooth_reward_curve(x, y)
-
-    if smooth == 2:
-        y = medfilt(y, kernel_size=9)
-
-    x, y = fix_point(x, y, bin_size)
-    return [x, y]
-
-
-color_defaults = [
-    '#1f77b4',  # muted blue
-    '#ff7f0e',  # safety orange
-    '#2ca02c',  # cooked asparagus green
-    '#d62728',  # brick red
-    '#9467bd',  # muted purple
-    '#8c564b',  # chestnut brown
-    '#e377c2',  # raspberry yogurt pink
-    '#7f7f7f',  # middle gray
-    '#bcbd22',  # curry yellow-green
-    '#17becf'  # blue-teal
-]
-
-
-def visdom_plot(viz, win, folder, game, name, bin_size=100, smooth=1):
-    tx, ty = load_data(folder, smooth, bin_size)
-    if tx is None or ty is None:
-        return win
-
-    fig = plt.figure()
-    plt.plot(tx, ty, label="{}".format(name))
-
-    # Ugly hack to detect atari
-    if game.find('NoFrameskip') > -1:
-        plt.xticks([1e6, 2e6, 4e6, 6e6, 8e6, 10e6],
-                   ["1M", "2M", "4M", "6M", "8M", "10M"])
-        plt.xlim(0, 10e6)
-    else:
-        plt.xticks([1e5, 2e5, 4e5, 6e5, 8e5, 1e5],
-                   ["0.1M", "0.2M", "0.4M", "0.6M", "0.8M", "1M"])
-        plt.xlim(0, 1e6)
-
-    plt.xlabel('Number of Timesteps')
-    plt.ylabel('Rewards')
-
-
-    plt.title(game)
-    plt.legend(loc=4)
-    plt.show()
-    plt.draw()
-
-    image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
-    plt.close(fig)
-
-    # Show it in visdom
-    image = np.transpose(image, (2, 0, 1))
-    return viz.image(image, win=win)
-
-
-if __name__ == "__main__":
-    from visdom import Visdom
-    viz = Visdom()
-    visdom_plot(viz, None, '/tmp/gym/', 'BreakOut', 'a2c', bin_size=100, smooth=1)
+from visdom import Visdom
+
+# Lazily-created Visdom connection and plot window handle
+vis = None
+win = None
+
+# Running-average smoothing state and accumulated curve points
+avg_reward = 0
+X = []
+Y = []
+
+def visdom_plot(
+    total_num_steps,
+    mean_reward
+):
+    global vis
+    global win
+    global avg_reward
+
+    if vis is None:
+        vis = Visdom()
+        assert vis.check_connection()
+
+        # Close all existing plots
+        vis.close()
+
+    # Running average for curve smoothing
+    avg_reward = avg_reward * 0.9 + 0.1 * mean_reward
+
+    X.append(total_num_steps)
+    Y.append(avg_reward)
+
+    # The plot with the handle 'win' is updated each time this is called
+    win = vis.line(
+        X=np.array(X),
+        Y=np.array(Y),
+        opts=dict(
+            # title='All Environments',
+            xlabel='Total time steps',
+            ylabel='Reward per episode',
+            ytickmin=0,
+            # ytickmax=1,
+            # ytickstep=0.1,
+            # legend=legend,
+            # showlegend=True,
+            width=900,
+            height=500
+        ),
+        win=win
+    )
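
The smoothing above is an exponential moving average seeded at zero, so the first handful of points are dragged toward zero before the average warms up. If that startup transient matters, a standard fix (the same bias correction Adam uses) is to divide by 1 - beta**t. A sketch under that assumption, not part of this commit:

    # Sketch (not in this commit): bias-corrected exponential moving average.
    # Dividing by (1 - beta**t) makes early outputs unbiased estimates of the
    # running mean instead of values pulled toward the zero initialization.
    def smoothed(x, state, beta=0.9):
        ema, t = state
        ema = beta * ema + (1 - beta) * x
        t += 1
        return ema / (1 - beta ** t), (ema, t)

    state = (0.0, 0)
    value, state = smoothed(10.0, state)  # value == 10.0 on the first call,
                                          # not 1.0 as with the raw EMA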