main.py

import copy
import glob
import os
import time
import operator
from functools import reduce

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from arguments import get_args
from vec_env.dummy_vec_env import DummyVecEnv
from vec_env.subproc_vec_env import SubprocVecEnv
from envs import make_env
from kfac import KFACOptimizer
from model import Policy
from storage import RolloutStorage
from visualize import visdom_plot

args = get_args()

assert args.algo in ['a2c', 'ppo', 'acktr']
if args.recurrent_policy:
    assert args.algo in ['a2c', 'ppo'], 'Recurrent policy is not implemented for ACKTR'

num_updates = int(args.num_frames) // args.num_steps // args.num_processes

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

try:
    os.makedirs(args.log_dir)
except OSError:
    files = glob.glob(os.path.join(args.log_dir, '*.monitor.csv'))
    for f in files:
        os.remove(f)


def main():
    os.environ['OMP_NUM_THREADS'] = '1'

    envs = [make_env(args.env_name, args.seed, i, args.log_dir) for i in range(args.num_processes)]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
    obs_numel = reduce(operator.mul, obs_shape, 1)

    actor_critic = Policy(obs_numel, envs.action_space)

    # Maxime: log some info about the model and its size
    modelSize = 0
    for p in actor_critic.parameters():
        pSize = reduce(operator.mul, p.size(), 1)
        modelSize += pSize
    print(str(actor_critic))
    print('Total model size: %d' % modelSize)
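
    # Discrete action spaces are stored as a single action index per step;
    # continuous (Box) spaces keep their full action dimensionality.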
    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic.cuda()

    if args.algo == 'a2c':
        optimizer = optim.RMSprop(actor_critic.parameters(), args.lr, eps=args.eps, alpha=args.alpha)
    elif args.algo == 'ppo':
        optimizer = optim.Adam(actor_critic.parameters(), args.lr, eps=args.eps)
    elif args.algo == 'acktr':
        optimizer = KFACOptimizer(actor_critic)
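
    # RolloutStorage keeps num_steps transitions for each of the parallel
    # processes; current_obs holds the latest (frame-stacked) observation
    # for every environment.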
    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space, actor_critic.state_size)
    current_obs = torch.zeros(args.num_processes, *obs_shape)
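
    # Shift the frame stack left by one frame and write the newest
    # observation into the last slot.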
    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
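
    # Main training loop: every update collects num_steps transitions from
    # each process, then performs one A2C/ACKTR gradient step or several
    # PPO epochs over the collected batch.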
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            value, action, action_log_prob, states = actor_critic.act(
                Variable(rollouts.observations[step], volatile=True),
                Variable(rollouts.states[step], volatile=True),
                Variable(rollouts.masks[step], volatile=True)
            )
            cpu_actions = action.data.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            elif current_obs.dim() == 3:
                current_obs *= masks.unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(step, current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)
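
        # Bootstrap the value of the last observation, then compute the
        # discounted returns (optionally with GAE).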
        next_value = actor_critic(
            Variable(rollouts.observations[-1], volatile=True),
            Variable(rollouts.states[-1], volatile=True),
            Variable(rollouts.masks[-1], volatile=True)
        )[0].data

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
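
        # A2C / ACKTR: re-evaluate the whole on-policy batch in a single
        # forward pass and take one gradient step.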
        if args.algo in ['a2c', 'acktr']:
            values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
                Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
                Variable(rollouts.states[:-1].view(-1, actor_critic.state_size)),
                Variable(rollouts.masks[:-1].view(-1, 1)),
                Variable(rollouts.actions.view(-1, action_shape))
            )

            values = values.view(args.num_steps, args.num_processes, 1)
            action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1)

            advantages = Variable(rollouts.returns[:-1]) - values
            value_loss = advantages.pow(2).mean()
            action_loss = -(Variable(advantages.data) * action_log_probs).mean()

            if args.algo == 'acktr' and optimizer.steps % optimizer.Ts == 0:
                # Sampled fisher, see Martens 2014
                actor_critic.zero_grad()
                pg_fisher_loss = -action_log_probs.mean()

                value_noise = Variable(torch.randn(values.size()))
                if args.cuda:
                    value_noise = value_noise.cuda()

                sample_values = values + value_noise
                vf_fisher_loss = -(values - Variable(sample_values.data)).pow(2).mean()

                fisher_loss = pg_fisher_loss + vf_fisher_loss
                optimizer.acc_stats = True
                fisher_loss.backward(retain_graph=True)
                optimizer.acc_stats = False

            optimizer.zero_grad()
            (value_loss * args.value_loss_coef + action_loss - dist_entropy * args.entropy_coef).backward()

            if args.algo == 'a2c':
                nn.utils.clip_grad_norm(actor_critic.parameters(), args.max_grad_norm)

            optimizer.step()
        elif args.algo == 'ppo':
            advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)
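
            # Run ppo_epoch passes over the batch, sampling minibatches and
            # optimizing the clipped surrogate objective on each one.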
            for e in range(args.ppo_epoch):
                if args.recurrent_policy:
                    data_generator = rollouts.recurrent_generator(advantages, args.num_mini_batch)
                else:
                    data_generator = rollouts.feed_forward_generator(advantages, args.num_mini_batch)

                for sample in data_generator:
                    observations_batch, states_batch, actions_batch, \
                        return_batch, masks_batch, old_action_log_probs_batch, \
                        adv_targ = sample

                    # Reshape to do in a single forward pass for all steps
                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
                        Variable(observations_batch),
                        Variable(states_batch),
                        Variable(masks_batch),
                        Variable(actions_batch)
                    )

                    adv_targ = Variable(adv_targ)
                    ratio = torch.exp(action_log_probs - Variable(old_action_log_probs_batch))
                    surr1 = ratio * adv_targ
                    surr2 = torch.clamp(ratio, 1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ
                    action_loss = -torch.min(surr1, surr2).mean()  # PPO's pessimistic surrogate (L^CLIP)

                    value_loss = (Variable(return_batch) - values).pow(2).mean()

                    optimizer.zero_grad()
                    (value_loss + action_loss - dist_entropy * args.entropy_coef).backward()
                    nn.utils.clip_grad_norm(actor_critic.parameters(), args.max_grad_norm)
                    optimizer.step()
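
        # Carry the final observation, recurrent state and mask over as the
        # start of the next rollout.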
        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [save_model,
                          hasattr(envs, 'ob_rms') and envs.ob_rms or None]

            torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))
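
        # Console logging: throughput (FPS) and reward / loss statistics
        # aggregated over all processes.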
        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print(
                "Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                format(
                    j,
                    total_num_steps,
                    int(total_num_steps / (end - start)),
                    final_rewards.mean(),
                    final_rewards.median(),
                    final_rewards.min(),
                    final_rewards.max(),
                    dist_entropy.data[0],
                    value_loss.data[0],
                    action_loss.data[0]
                )
            )
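
        # Optionally update the visdom plot with the running mean reward.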
        if args.vis and j % args.vis_interval == 0:
            win = visdom_plot(
                total_num_steps,
                final_rewards.mean()
            )


if __name__ == "__main__":
    main()
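
# Example invocation (a sketch: the exact flag names are defined in arguments.py
# and may differ in your checkout):
#   python main.py --algo a2c --env-name <gym-env-id> --num-processes 16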