@@ -128,9 +128,11 @@ def main():
     for j in range(num_updates):
         for step in range(args.num_steps):
             # Sample actions
-            value, action, action_log_prob, states = actor_critic.act(Variable(rollouts.observations[step], volatile=True),
-                                                                      Variable(rollouts.states[step], volatile=True),
-                                                                      Variable(rollouts.masks[step], volatile=True))
+            value, action, action_log_prob, states = actor_critic.act(
+                Variable(rollouts.observations[step], volatile=True),
+                Variable(rollouts.states[step], volatile=True),
+                Variable(rollouts.masks[step], volatile=True)
+            )
             cpu_actions = action.data.squeeze(1).cpu().numpy()

             # Obser reward and next obs
@@ -155,17 +157,21 @@ def main():
             update_current_obs(obs)
             rollouts.insert(step, current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)

-        next_value = actor_critic(Variable(rollouts.observations[-1], volatile=True),
-                                  Variable(rollouts.states[-1], volatile=True),
-                                  Variable(rollouts.masks[-1], volatile=True))[0].data
+        next_value = actor_critic(
+            Variable(rollouts.observations[-1], volatile=True),
+            Variable(rollouts.states[-1], volatile=True),
+            Variable(rollouts.masks[-1], volatile=True)
+        )[0].data

         rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)

         if args.algo in ['a2c', 'acktr']:
-            values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
-                                                                                           Variable(rollouts.states[0].view(-1, actor_critic.state_size)),
-                                                                                           Variable(rollouts.masks[:-1].view(-1, 1)),
-                                                                                           Variable(rollouts.actions.view(-1, action_shape)))
+            values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
+                Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
+                Variable(rollouts.states[0].view(-1, actor_critic.state_size)),
+                Variable(rollouts.masks[:-1].view(-1, 1)),
+                Variable(rollouts.actions.view(-1, action_shape))
+            )

             values = values.view(args.num_steps, args.num_processes, 1)
             action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1)
@@ -205,11 +211,9 @@ def main():

             for e in range(args.ppo_epoch):
                 if args.recurrent_policy:
-                    data_generator = rollouts.recurrent_generator(advantages,
-                                                                  args.num_mini_batch)
+                    data_generator = rollouts.recurrent_generator(advantages, args.num_mini_batch)
                 else:
-                    data_generator = rollouts.feed_forward_generator(advantages,
-                                                                     args.num_mini_batch)
+                    data_generator = rollouts.feed_forward_generator(advantages, args.num_mini_batch)

                 for sample in data_generator:
                     observations_batch, states_batch, actions_batch, \
@@ -217,10 +221,12 @@ def main():
                             adv_targ = sample

                     # Reshape to do in a single forward pass for all steps
-                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(observations_batch),
-                                                                                                   Variable(states_batch),
-                                                                                                   Variable(masks_batch),
-                                                                                                   Variable(actions_batch))
+                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
+                        Variable(observations_batch),
+                        Variable(states_batch),
+                        Variable(masks_batch),
+                        Variable(actions_batch)
+                    )

                     adv_targ = Variable(adv_targ)
                     ratio = torch.exp(action_log_probs - Variable(old_action_log_probs_batch))