Browse Source

Added code to use Gym Monitor for video recording

Maxime Chevalier-Boisvert 7 years ago
parent
commit
6ded215f0b
2 changed files with 119 additions and 2 deletions
  1. 9 2
      basicrl/envs.py
  2. 110 0
      basicrl/record_video.py

+ 9 - 2
basicrl/envs.py

@@ -19,7 +19,7 @@ except:
     pass
 
 
-def make_env(env_id, seed, rank, log_dir, size=None):
+def make_env(env_id, seed, rank, log_dir, size=None, video=False):
     def _thunk():
 
         env = gym.make(env_id)
@@ -36,11 +36,18 @@ def make_env(env_id, seed, rank, log_dir, size=None):
 
         #env = StateBonus(env)
 
+        if video:
+            env = gym.wrappers.Monitor(
+                env,
+                "./monitor",
+                video_callable=lambda episode_id: True,
+                force=True
+            )
+
         return env
 
     return _thunk
 
-
 class WrapPyTorch(gym.ObservationWrapper):
     def __init__(self, env=None):
         super(WrapPyTorch, self).__init__(env)

+ 110 - 0
basicrl/record_video.py

@@ -0,0 +1,110 @@
+import argparse
+import os
+import sys
+import types
+import time
+
+import numpy as np
+import torch
+from torch.autograd import Variable
+from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
+from baselines.common.vec_env.vec_normalize import VecNormalize
+
+from envs import make_env
+
+
+parser = argparse.ArgumentParser(description='RL')
+parser.add_argument('--seed', type=int, default=1,
+                    help='random seed (default: 1)')
+parser.add_argument('--num-stack', type=int, default=4,
+                    help='number of frames to stack (default: 4)')
+parser.add_argument('--log-interval', type=int, default=10,
+                    help='log interval, one log per n updates (default: 10)')
+parser.add_argument('--env-name', default='PongNoFrameskip-v4',
+                    help='environment to train on (default: PongNoFrameskip-v4)')
+parser.add_argument('--load-dir', default='./trained_models/',
+                    help='directory to save agent logs (default: ./trained_models/)')
+args = parser.parse_args()
+
+
+env = make_env(args.env_name, args.seed, 0, None, size=None, video=True)
+env = DummyVecEnv([env])
+
+actor_critic, ob_rms = \
+            torch.load(os.path.join(args.load_dir, args.env_name + ".pt"))
+
+
+if len(env.observation_space.shape) == 1:
+    env = VecNormalize(env, ret=False)
+    env.ob_rms = ob_rms
+
+    # An ugly hack to remove updates
+    def _obfilt(self, obs):
+        if self.ob_rms:
+            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
+            return obs
+        else:
+            return obs
+    env._obfilt = types.MethodType(_obfilt, env)
+    render_func = env.venv.envs[0].render
+else:
+    render_func = env.envs[0].render
+
+obs_shape = env.observation_space.shape
+obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
+current_obs = torch.zeros(1, *obs_shape)
+states = torch.zeros(1, actor_critic.state_size)
+masks = torch.zeros(1, 1)
+
+
+def update_current_obs(obs):
+    shape_dim0 = env.observation_space.shape[0]
+    obs = torch.from_numpy(obs).float()
+    if args.num_stack > 1:
+        current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
+    current_obs[:, -shape_dim0:] = obs
+
+
+render_func('human')
+obs = env.reset()
+update_current_obs(obs)
+
+if args.env_name.find('Bullet') > -1:
+    import pybullet as p
+
+    torsoId = -1
+    for i in range(p.getNumBodies()):
+        if (p.getBodyInfo(i)[0].decode() == "torso"):
+            torsoId = i
+
+while True:
+    value, action, _, states = actor_critic.act(Variable(current_obs, volatile=True),
+                                                Variable(states, volatile=True),
+                                                Variable(masks, volatile=True),
+                                                deterministic=True)
+    states = states.data
+    cpu_actions = action.data.squeeze(1).cpu().numpy()
+    # Obser reward and next obs
+    obs, reward, done, _ = env.step(cpu_actions)
+
+    time.sleep(0.05)
+
+    masks.fill_(0.0 if done else 1.0)
+
+    if current_obs.dim() == 4:
+        current_obs *= masks.unsqueeze(2).unsqueeze(2)
+    else:
+        current_obs *= masks
+    update_current_obs(obs)
+
+    if args.env_name.find('Bullet') > -1:
+        if torsoId > -1:
+            distance = 5
+            yaw = 0
+            humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId)
+            p.resetDebugVisualizerCamera(distance, yaw, -20, humanPos)
+
+    renderer = render_func('human')
+
+    if not renderer.window:
+        sys.exit(0)