# record_video.py
  1. import argparse
  2. import os
  3. import sys
  4. import types
  5. import time
  6. import numpy as np
  7. import torch
  8. from torch.autograd import Variable
  9. from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
  10. from baselines.common.vec_env.vec_normalize import VecNormalize
  11. from envs import make_env
  12. parser = argparse.ArgumentParser(description='RL')
  13. parser.add_argument('--seed', type=int, default=1,
  14. help='random seed (default: 1)')
  15. parser.add_argument('--num-stack', type=int, default=4,
  16. help='number of frames to stack (default: 4)')
  17. parser.add_argument('--log-interval', type=int, default=10,
  18. help='log interval, one log per n updates (default: 10)')
  19. parser.add_argument('--env-name', default='PongNoFrameskip-v4',
  20. help='environment to train on (default: PongNoFrameskip-v4)')
  21. parser.add_argument('--load-dir', default='./trained_models/',
  22. help='directory to save agent logs (default: ./trained_models/)')
  23. args = parser.parse_args()
  24. env = make_env(args.env_name, args.seed, 0, None, size=None, video=True)
  25. env = DummyVecEnv([env])
  26. actor_critic, ob_rms = \
  27. torch.load(os.path.join(args.load_dir, args.env_name + ".pt"))
  28. if len(env.observation_space.shape) == 1:
  29. env = VecNormalize(env, ret=False)
  30. env.ob_rms = ob_rms
  31. # An ugly hack to remove updates
  32. def _obfilt(self, obs):
  33. if self.ob_rms:
  34. obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
  35. return obs
  36. else:
  37. return obs
  38. env._obfilt = types.MethodType(_obfilt, env)
  39. render_func = env.venv.envs[0].render
  40. else:
  41. render_func = env.envs[0].render
  42. obs_shape = env.observation_space.shape
  43. obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
  44. current_obs = torch.zeros(1, *obs_shape)
  45. states = torch.zeros(1, actor_critic.state_size)
  46. masks = torch.zeros(1, 1)
  47. def update_current_obs(obs):
  48. shape_dim0 = env.observation_space.shape[0]
  49. obs = torch.from_numpy(obs).float()
  50. if args.num_stack > 1:
  51. current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
  52. current_obs[:, -shape_dim0:] = obs
  53. render_func('human')
  54. obs = env.reset()
  55. update_current_obs(obs)
  56. if args.env_name.find('Bullet') > -1:
  57. import pybullet as p
  58. torsoId = -1
  59. for i in range(p.getNumBodies()):
  60. if (p.getBodyInfo(i)[0].decode() == "torso"):
  61. torsoId = i
  62. while True:
  63. value, action, _, states = actor_critic.act(Variable(current_obs, volatile=True),
  64. Variable(states, volatile=True),
  65. Variable(masks, volatile=True),
  66. deterministic=True)
  67. states = states.data
  68. cpu_actions = action.data.squeeze(1).cpu().numpy()
  69. # Obser reward and next obs
  70. obs, reward, done, _ = env.step(cpu_actions)
  71. time.sleep(0.05)
  72. masks.fill_(0.0 if done else 1.0)
  73. if current_obs.dim() == 4:
  74. current_obs *= masks.unsqueeze(2).unsqueeze(2)
  75. else:
  76. current_obs *= masks
  77. update_current_obs(obs)
  78. if args.env_name.find('Bullet') > -1:
  79. if torsoId > -1:
  80. distance = 5
  81. yaw = 0
  82. humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId)
  83. p.resetDebugVisualizerCamera(distance, yaw, -20, humanPos)
  84. renderer = render_func('human')
  85. if not renderer.window:
  86. sys.exit(0)