# enjoy.py (~2.3 KB): render a trained agent acting in its environment.
  1. import argparse
  2. import os
  3. import sys
  4. import types
  5. import time
  6. import numpy as np
  7. import torch
  8. from torch.autograd import Variable
  9. from vec_env.dummy_vec_env import DummyVecEnv
  10. from envs import make_env
  11. parser = argparse.ArgumentParser(description='RL')
  12. parser.add_argument('--seed', type=int, default=1,
  13. help='random seed (default: 1)')
  14. parser.add_argument('--num-stack', type=int, default=1,
  15. help='number of frames to stack (default: 1)')
  16. parser.add_argument('--log-interval', type=int, default=10,
  17. help='log interval, one log per n updates (default: 10)')
  18. parser.add_argument('--env-name', default='PongNoFrameskip-v4',
  19. help='environment to train on (default: PongNoFrameskip-v4)')
  20. parser.add_argument('--load-dir', default='./trained_models/',
  21. help='directory to save agent logs (default: ./trained_models/)')
  22. args = parser.parse_args()
  23. env = make_env(args.env_name, args.seed, 0, None)
  24. env = DummyVecEnv([env])
  25. actor_critic, ob_rms = torch.load(os.path.join(args.load_dir, args.env_name + ".pt"))
  26. render_func = env.envs[0].render
  27. obs_shape = env.observation_space.shape
  28. obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
  29. current_obs = torch.zeros(1, *obs_shape)
  30. states = torch.zeros(1, actor_critic.state_size)
  31. masks = torch.zeros(1, 1)
  32. def update_current_obs(obs):
  33. shape_dim0 = env.observation_space.shape[0]
  34. obs = torch.from_numpy(obs).float()
  35. if args.num_stack > 1:
  36. current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
  37. current_obs[:, -shape_dim0:] = obs
  38. render_func('human')
  39. obs = env.reset()
  40. update_current_obs(obs)
  41. while True:
  42. value, action, _, states = actor_critic.act(
  43. Variable(current_obs, volatile=True),
  44. Variable(states, volatile=True),
  45. Variable(masks, volatile=True),
  46. deterministic=True
  47. )
  48. states = states.data
  49. cpu_actions = action.data.squeeze(1).cpu().numpy()
  50. # Observation, reward and next obs
  51. obs, reward, done, _ = env.step(cpu_actions)
  52. time.sleep(0.05)
  53. masks.fill_(0.0 if done else 1.0)
  54. if current_obs.dim() == 4:
  55. current_obs *= masks.unsqueeze(2).unsqueeze(2)
  56. else:
  57. current_obs *= masks
  58. update_current_obs(obs)
  59. renderer = render_func('human')
  60. if not renderer.window:
  61. sys.exit(0)