"""Load a trained actor-critic policy and run it in a rendered environment.

The script reads ``<load-dir>/<env-name>.pt``, builds a single-environment
``DummyVecEnv``, and steps the policy deterministically in an endless loop
while rendering each frame.
"""
import argparse
import os
import sys
import types
import time

import numpy as np
import torch
from torch.autograd import Variable

from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize

from envs import make_env


parser = argparse.ArgumentParser(description='RL')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--num-stack', type=int, default=4,
                    help='number of frames to stack (default: 4)')
parser.add_argument('--log-interval', type=int, default=10,
                    help='log interval, one log per n updates (default: 10)')
parser.add_argument('--env-name', default='PongNoFrameskip-v4',
                    help='environment to train on (default: PongNoFrameskip-v4)')
# The directory is only read from in this script (see torch.load below), so
# the help text describes loading rather than saving.
parser.add_argument('--load-dir', default='./trained_models/',
                    help='directory to load the trained agent from '
                         '(default: ./trained_models/)')
args = parser.parse_args()

env = make_env(args.env_name, args.seed, 0, None)
env = DummyVecEnv([env])

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only point --load-dir at checkpoints from a trusted source.
actor_critic, ob_rms = torch.load(
    os.path.join(args.load_dir, args.env_name + ".pt"))

if len(env.observation_space.shape) == 1:
    # One-dimensional observations: wrap the env so observations are
    # normalized with the statistics stored alongside the model.
    env = VecNormalize(env, ret=False)
    env.ob_rms = ob_rms

    # An ugly hack to remove updates: replace the wrapper's observation
    # filter with one that only applies the stored statistics.
    def _obfilt(self, obs):
        """Normalize and clip ``obs`` with ``self.ob_rms`` when it is set."""
        if self.ob_rms:
            obs = np.clip(
                (obs - self.ob_rms.mean)
                / np.sqrt(self.ob_rms.var + self.epsilon),
                -self.clipob, self.clipob)
        return obs

    env._obfilt = types.MethodType(_obfilt, env)
    render_func = env.venv.envs[0].render
else:
    render_func = env.envs[0].render

# The stacked observation multiplies the first observation dimension by the
# number of stacked frames; the remaining dimensions are unchanged.
obs_shape = env.observation_space.shape
obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
current_obs = torch.zeros(1, *obs_shape)
states = torch.zeros(1, actor_critic.state_size)
masks = torch.zeros(1, 1)


def update_current_obs(obs):
    """Push ``obs`` into the module-level ``current_obs`` frame stack.

    ``obs`` is converted from a NumPy array to a float tensor.  When more
    than one frame is stacked, the existing contents are shifted towards the
    front of dimension 1 by one observation's worth of entries, and the new
    observation is written into the trailing entries.
    """
    shape_dim0 = env.observation_space.shape[0]
    obs = torch.from_numpy(obs).float()
    if args.num_stack > 1:
        current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
    current_obs[:, -shape_dim0:] = obs


render_func('human')
obs = env.reset()
update_current_obs(obs)

is_bullet = 'Bullet' in args.env_name
if is_bullet:
    import pybullet as p

    # Find the body named "torso" so the debug camera can follow it.  When
    # several bodies match, the last one wins (no early exit).
    torsoId = -1
    for i in range(p.getNumBodies()):
        if p.getBodyInfo(i)[0].decode() == "torso":
            torsoId = i

while True:
    # NOTE(review): ``volatile=True`` is the pre-0.4 PyTorch way of disabling
    # autograd for inference; on newer releases the flag is ignored and
    # ``torch.no_grad()`` is the replacement.  Left unchanged because the
    # targeted torch version is not visible here.
    value, action, _, states = actor_critic.act(
        Variable(current_obs, volatile=True),
        Variable(states, volatile=True),
        Variable(masks, volatile=True),
        deterministic=True)
    states = states.data
    cpu_actions = action.data.squeeze(1).cpu().numpy()

    # Observe reward and next obs.
    obs, reward, done, _ = env.step(cpu_actions)
    time.sleep(0.05)

    # Zero the frame stack when the episode has ended so frames from the
    # previous episode do not leak into the next one.
    masks.fill_(0.0 if done else 1.0)
    if current_obs.dim() == 4:
        current_obs *= masks.unsqueeze(2).unsqueeze(2)
    else:
        current_obs *= masks
    update_current_obs(obs)

    if is_bullet and torsoId > -1:
        distance = 5
        yaw = 0
        humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId)
        p.resetDebugVisualizerCamera(distance, yaw, -20, humanPos)

    renderer = render_func('human')
    # Exit once the viewer reports a falsy ``window``.  The attribute is
    # looked up defensively: if the render call returns something without a
    # ``window`` attribute, keep running instead of raising AttributeError.
    if not getattr(renderer, 'window', True):
        sys.exit(0)