main.py

import copy
import glob
import os
import time
import operator
from functools import reduce

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from arguments import get_args
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize
from envs import make_env
from kfac import KFACOptimizer
from model import CNNPolicy, MLPPolicy
from storage import RolloutStorage
from visualize import visdom_plot

args = get_args()

assert args.algo in ['a2c', 'ppo', 'acktr']
if args.recurrent_policy:
    assert args.algo in ['a2c', 'ppo'], \
        'Recurrent policy is not implemented for ACKTR'

num_updates = int(args.num_frames) // args.num_steps // args.num_processes
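# Each parameter update consumes num_processes * num_steps frames of experience,
# hence the integer division above.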

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

try:
    os.makedirs(args.log_dir)
except OSError:
    files = glob.glob(os.path.join(args.log_dir, '*.monitor.csv'))
    for f in files:
        os.remove(f)


def main():
    print("#######")
    print("WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
    print("#######")

    os.environ['OMP_NUM_THREADS'] = '1'

    if args.vis:
        from visdom import Visdom
        viz = Visdom()
        win = None

    envs = [make_env(args.env_name, args.seed, i, args.log_dir)
            for i in range(args.num_processes)]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs)
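    # VecNormalize (from baselines) keeps running mean/std estimates of observations
    # and returns; it is only applied to low-dimensional state vectors, not images.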

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
    obs_numel = reduce(operator.mul, obs_shape, 1)
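    # Frame stacking: the first (channel) dimension is multiplied by num_stack, so the
    # policy sees the num_stack most recent observations concatenated together.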

    if len(obs_shape) == 3 and obs_numel > 1024:
        actor_critic = CNNPolicy(obs_shape[0], envs.action_space, args.recurrent_policy)
    else:
        assert not args.recurrent_policy, \
            "Recurrent policy is not implemented for the MLP controller"
        actor_critic = MLPPolicy(obs_numel, envs.action_space)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic.cuda()

    if args.algo == 'a2c':
        optimizer = optim.RMSprop(actor_critic.parameters(), args.lr, eps=args.eps, alpha=args.alpha)
    elif args.algo == 'ppo':
        optimizer = optim.Adam(actor_critic.parameters(), args.lr, eps=args.eps)
    elif args.algo == 'acktr':
        optimizer = KFACOptimizer(actor_critic)
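    # Optimizer choice follows the algorithm: RMSprop for A2C, Adam for PPO, and the
    # K-FAC natural-gradient optimizer (kfac.py) for ACKTR.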

    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space, actor_critic.state_size)
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs
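    # update_current_obs shifts the stacked frame buffer by one frame (dropping the
    # oldest) and writes the newest observation into the last shape_dim0 channels.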

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            value, action, action_log_prob, states = actor_critic.act(
                Variable(rollouts.observations[step], volatile=True),
                Variable(rollouts.states[step], volatile=True),
                Variable(rollouts.masks[step], volatile=True))
            cpu_actions = action.data.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(step, current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)

        next_value = actor_critic(
            Variable(rollouts.observations[-1], volatile=True),
            Variable(rollouts.states[-1], volatile=True),
            Variable(rollouts.masks[-1], volatile=True))[0].data

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
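        # compute_returns bootstraps from next_value; with use_gae it applies generalized
        # advantage estimation (lambda = args.tau), otherwise plain discounted returns.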

        if args.algo in ['a2c', 'acktr']:
            values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
                Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
                Variable(rollouts.states[0].view(-1, actor_critic.state_size)),
                Variable(rollouts.masks[:-1].view(-1, 1)),
                Variable(rollouts.actions.view(-1, action_shape)))

            values = values.view(args.num_steps, args.num_processes, 1)
            action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1)

            advantages = Variable(rollouts.returns[:-1]) - values
            value_loss = advantages.pow(2).mean()

            action_loss = -(Variable(advantages.data) * action_log_probs).mean()
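            # value_loss is the mean squared error between returns and value predictions;
            # action_loss is the policy-gradient term with the advantages detached.
            # For ACKTR, the block below periodically accumulates K-FAC curvature
            # statistics by backpropagating a sampled-Fisher loss before the real update.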

            if args.algo == 'acktr' and optimizer.steps % optimizer.Ts == 0:
                # Sampled Fisher, see Martens 2014
                actor_critic.zero_grad()
                pg_fisher_loss = -action_log_probs.mean()

                value_noise = Variable(torch.randn(values.size()))
                if args.cuda:
                    value_noise = value_noise.cuda()

                sample_values = values + value_noise
                vf_fisher_loss = -(values - Variable(sample_values.data)).pow(2).mean()

                fisher_loss = pg_fisher_loss + vf_fisher_loss
                optimizer.acc_stats = True
                fisher_loss.backward(retain_graph=True)
                optimizer.acc_stats = False

            optimizer.zero_grad()
            (value_loss * args.value_loss_coef + action_loss - dist_entropy * args.entropy_coef).backward()

            if args.algo == 'a2c':
                nn.utils.clip_grad_norm(actor_critic.parameters(), args.max_grad_norm)

            optimizer.step()
        elif args.algo == 'ppo':
            advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)
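            # Advantages are normalized to zero mean / unit variance (1e-5 avoids
            # division by zero) before running args.ppo_epoch passes over the rollout.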

            for e in range(args.ppo_epoch):
                if args.recurrent_policy:
                    data_generator = rollouts.recurrent_generator(advantages,
                                                                  args.num_mini_batch)
                else:
                    data_generator = rollouts.feed_forward_generator(advantages,
                                                                     args.num_mini_batch)

                for sample in data_generator:
                    observations_batch, states_batch, actions_batch, \
                        return_batch, masks_batch, old_action_log_probs_batch, \
                        adv_targ = sample

                    # Reshape to do in a single forward pass for all steps
                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
                        Variable(observations_batch),
                        Variable(states_batch),
                        Variable(masks_batch),
                        Variable(actions_batch))

                    adv_targ = Variable(adv_targ)
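                    # Importance ratio r = exp(log pi_new - log pi_old); the clipped
                    # surrogate below takes min(r * A, clip(r, 1 - eps, 1 + eps) * A).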
                    ratio = torch.exp(action_log_probs - Variable(old_action_log_probs_batch))
                    surr1 = ratio * adv_targ
                    surr2 = torch.clamp(ratio, 1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ
                    action_loss = -torch.min(surr1, surr2).mean()  # PPO's pessimistic surrogate (L^CLIP)

                    value_loss = (Variable(return_batch) - values).pow(2).mean()

                    optimizer.zero_grad()
                    (value_loss + action_loss - dist_entropy * args.entropy_coef).backward()
                    nn.utils.clip_grad_norm(actor_critic.parameters(), args.max_grad_norm)
                    optimizer.step()

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [save_model,
                          hasattr(envs, 'ob_rms') and envs.ob_rms or None]

            torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))
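            # ob_rms (the VecNormalize observation statistics, if present) is saved
            # alongside the model, presumably so the same normalization can be
            # re-applied when the checkpoint is loaded for evaluation.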

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print("Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                  format(j, total_num_steps,
                         int(total_num_steps / (end - start)),
                         final_rewards.mean(),
                         final_rewards.median(),
                         final_rewards.min(),
                         final_rewards.max(), dist_entropy.data[0],
                         value_loss.data[0], action_loss.data[0]))

        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name, args.algo)
            except IOError:
                pass


if __name__ == "__main__":
    main()