| 
					
				 | 
			
			
				@@ -0,0 +1,110 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import argparse 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import os 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import sys 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import types 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import time 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import torch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from torch.autograd import Variable 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from baselines.common.vec_env.vec_normalize import VecNormalize 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from envs import make_env 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+parser = argparse.ArgumentParser(description='RL') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+parser.add_argument('--seed', type=int, default=1, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    help='random seed (default: 1)') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+parser.add_argument('--num-stack', type=int, default=4, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    help='number of frames to stack (default: 4)') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+parser.add_argument('--log-interval', type=int, default=10, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    help='log interval, one log per n updates (default: 10)') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+parser.add_argument('--env-name', default='PongNoFrameskip-v4', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    help='environment to train on (default: PongNoFrameskip-v4)') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+parser.add_argument('--load-dir', default='./trained_models/', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    help='directory to save agent logs (default: ./trained_models/)') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+args = parser.parse_args() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+env = make_env(args.env_name, args.seed, 0, None, size=None, video=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+env = DummyVecEnv([env]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+actor_critic, ob_rms = \ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            torch.load(os.path.join(args.load_dir, args.env_name + ".pt")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+if len(env.observation_space.shape) == 1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    env = VecNormalize(env, ret=False) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    env.ob_rms = ob_rms 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # An ugly hack to remove updates 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _obfilt(self, obs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if self.ob_rms: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return obs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return obs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    env._obfilt = types.MethodType(_obfilt, env) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    render_func = env.venv.envs[0].render 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    render_func = env.envs[0].render 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+obs_shape = env.observation_space.shape 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+current_obs = torch.zeros(1, *obs_shape) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+states = torch.zeros(1, actor_critic.state_size) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+masks = torch.zeros(1, 1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def update_current_obs(obs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    shape_dim0 = env.observation_space.shape[0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    obs = torch.from_numpy(obs).float() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if args.num_stack > 1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    current_obs[:, -shape_dim0:] = obs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+render_func('human') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+obs = env.reset() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+update_current_obs(obs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+if args.env_name.find('Bullet') > -1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    import pybullet as p 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    torsoId = -1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for i in range(p.getNumBodies()): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if (p.getBodyInfo(i)[0].decode() == "torso"): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            torsoId = i 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+while True: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    value, action, _, states = actor_critic.act(Variable(current_obs, volatile=True), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                Variable(states, volatile=True), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                Variable(masks, volatile=True), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                deterministic=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    states = states.data 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    cpu_actions = action.data.squeeze(1).cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Obser reward and next obs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    obs, reward, done, _ = env.step(cpu_actions) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    time.sleep(0.05) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    masks.fill_(0.0 if done else 1.0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if current_obs.dim() == 4: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        current_obs *= masks.unsqueeze(2).unsqueeze(2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        current_obs *= masks 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    update_current_obs(obs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if args.env_name.find('Bullet') > -1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if torsoId > -1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            distance = 5 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            yaw = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            p.resetDebugVisualizerCamera(distance, yaw, -20, humanPos) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    renderer = render_func('human') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if not renderer.window: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sys.exit(0) 
			 |