"""Replay a trained actor-critic agent with on-screen rendering.

Loads ``<load-dir>/<env-name>.pt`` -- a pickled ``(actor_critic, ob_rms)``
pair -- and runs the policy with deterministic actions in one environment.
"""
import argparse
import os
import sys
import types
import time

import numpy as np
import torch
from torch.autograd import Variable

from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize

from envs import make_env


parser = argparse.ArgumentParser(description='RL')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--num-stack', type=int, default=4,
                    help='number of frames to stack (default: 4)')
parser.add_argument('--log-interval', type=int, default=10,
                    help='log interval, one log per n updates (default: 10)')
parser.add_argument('--env-name', default='PongNoFrameskip-v4',
                    help='environment to run the agent in (default: PongNoFrameskip-v4)')
parser.add_argument('--load-dir', default='./trained_models/',
                    help='directory to load the trained agent from (default: ./trained_models/)')
args = parser.parse_args()

# DummyVecEnv takes a list of environment factories; make_env is assumed to
# return such a factory for a single (video-recording) environment.
env = make_env(args.env_name, args.seed, 0, None, size=None, video=True)
env = DummyVecEnv([env])

# NOTE(security): torch.load unpickles arbitrary objects -- only load
# checkpoints from trusted sources.
# map_location keeps every stored tensor on the CPU, which is where this
# script creates its observation / state / mask tensors.
actor_critic, ob_rms = torch.load(
    os.path.join(args.load_dir, args.env_name + ".pt"),
    map_location=lambda storage, loc: storage)

if len(env.observation_space.shape) == 1:
    # Vector observations are normalized with the saved running statistics
    # (ob_rms); return normalization is disabled.
    env = VecNormalize(env, ret=False)
    env.ob_rms = ob_rms

    # An ugly hack to remove updates: replace VecNormalize._obfilt with a
    # version that only *applies* the saved statistics and never updates
    # them while replaying.
    def _obfilt(self, obs):
        if self.ob_rms:
            obs = np.clip((obs - self.ob_rms.mean) /
                          np.sqrt(self.ob_rms.var + self.epsilon),
                          -self.clipob, self.clipob)
        return obs

    env._obfilt = types.MethodType(_obfilt, env)
    render_func = env.venv.envs[0].render
else:
    render_func = env.envs[0].render

# Frame stacking: the policy input is the last `num_stack` observations
# concatenated along the first observation dimension.
obs_shape = env.observation_space.shape
obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
current_obs = torch.zeros(1, *obs_shape)          # batch of one environment
states = torch.zeros(1, actor_critic.state_size)  # policy state fed to actor_critic.act
masks = torch.zeros(1, 1)


def update_current_obs(obs):
    """Push the newest observation into the frame stack ``current_obs``.

    Existing frames are shifted one slot towards the front (the oldest one is
    dropped) and ``obs`` (a NumPy array) is written into the last slot. The
    module-level ``current_obs`` tensor is modified in place.
    """
    shape_dim0 = env.observation_space.shape[0]
    obs = torch.from_numpy(obs).float()
    if args.num_stack > 1:
        # Source and destination slices overlap in memory; clone first so the
        # in-place copy is well defined (newer PyTorch rejects the overlap).
        current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:].clone()
    current_obs[:, -shape_dim0:] = obs


render_func('human')
obs = env.reset()
update_current_obs(obs)

# Debug-visualizer camera used to follow the agent in PyBullet environments.
CAMERA_DISTANCE = 5
CAMERA_YAW = 0
CAMERA_PITCH = -20

is_bullet_env = 'Bullet' in args.env_name
if is_bullet_env:
    import pybullet as p

    # Find the body whose base link is named "torso" so the camera can track
    # it. PyBullet body unique ids are not guaranteed to be 0..N-1, so each
    # serial index is mapped to its unique id first. If several bodies match,
    # the last one wins.
    torsoId = -1
    for serial_index in range(p.getNumBodies()):
        body_id = p.getBodyUniqueId(serial_index)
        if p.getBodyInfo(body_id)[0].decode() == "torso":
            torsoId = body_id

while True:
    # Deterministic action. volatile=True is the pre-0.4 PyTorch way of
    # running without building an autograd graph.
    _, action, _, states = actor_critic.act(Variable(current_obs, volatile=True),
                                            Variable(states, volatile=True),
                                            Variable(masks, volatile=True),
                                            deterministic=True)
    states = states.data
    cpu_actions = action.data.squeeze(1).cpu().numpy()

    # Observe reward and next obs (the reward is not used when replaying).
    obs, _, done, _ = env.step(cpu_actions)

    time.sleep(0.05)  # slow playback down so it can be watched

    # When the episode is done the mask becomes 0, which clears the frame
    # stack below; the mask is also passed to the policy on the next step.
    masks.fill_(0.0 if done else 1.0)

    if current_obs.dim() == 4:
        current_obs *= masks.unsqueeze(2).unsqueeze(2)
    else:
        current_obs *= masks
    update_current_obs(obs)

    if is_bullet_env and torsoId > -1:
        torso_pos, _ = p.getBasePositionAndOrientation(torsoId)
        p.resetDebugVisualizerCamera(CAMERA_DISTANCE, CAMERA_YAW,
                                     CAMERA_PITCH, torso_pos)

    renderer = render_func('human')

    # Stop once the render window has been closed.
    if not renderer.window:
        sys.exit(0)