
Refactored handling of recurrent policies for simplicity

Maxime Chevalier-Boisvert, 7 years ago
commit 16085191ab
2 changed files with 26 additions and 31 deletions
  1. pytorch_rl/main.py (+24 -18)
  2. pytorch_rl/model.py (+2 -13)

pytorch_rl/main.py (+24 -18)

@@ -128,9 +128,11 @@ def main():
     for j in range(num_updates):
         for step in range(args.num_steps):
             # Sample actions
-            value, action, action_log_prob, states = actor_critic.act(Variable(rollouts.observations[step], volatile=True),
-                                                                      Variable(rollouts.states[step], volatile=True),
-                                                                      Variable(rollouts.masks[step], volatile=True))
+            value, action, action_log_prob, states = actor_critic.act(
+                Variable(rollouts.observations[step], volatile=True),
+                Variable(rollouts.states[step], volatile=True),
+                Variable(rollouts.masks[step], volatile=True)
+            )
             cpu_actions = action.data.squeeze(1).cpu().numpy()
 
             # Observe reward and next obs
@@ -155,17 +157,21 @@ def main():
             update_current_obs(obs)
             rollouts.insert(step, current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)
 
-        next_value = actor_critic(Variable(rollouts.observations[-1], volatile=True),
-                                  Variable(rollouts.states[-1], volatile=True),
-                                  Variable(rollouts.masks[-1], volatile=True))[0].data
+        next_value = actor_critic(
+            Variable(rollouts.observations[-1], volatile=True),
+            Variable(rollouts.states[-1], volatile=True),
+            Variable(rollouts.masks[-1], volatile=True)
+        )[0].data
 
         rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
 
         if args.algo in ['a2c', 'acktr']:
-            values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
-                                                                                           Variable(rollouts.states[0].view(-1, actor_critic.state_size)),
-                                                                                           Variable(rollouts.masks[:-1].view(-1, 1)),
-                                                                                           Variable(rollouts.actions.view(-1, action_shape)))
+            values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
+                Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
+                Variable(rollouts.states[:-1].view(-1, actor_critic.state_size)),
+                Variable(rollouts.masks[:-1].view(-1, 1)),
+                Variable(rollouts.actions.view(-1, action_shape))
+            )
 
             values = values.view(args.num_steps, args.num_processes, 1)
             action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1)
@@ -205,11 +211,9 @@ def main():
 
             for e in range(args.ppo_epoch):
                 if args.recurrent_policy:
-                    data_generator = rollouts.recurrent_generator(advantages,
-                                                            args.num_mini_batch)
+                    data_generator = rollouts.recurrent_generator(advantages, args.num_mini_batch)
                 else:
-                    data_generator = rollouts.feed_forward_generator(advantages,
-                                                            args.num_mini_batch)
+                    data_generator = rollouts.feed_forward_generator(advantages, args.num_mini_batch)
 
                 for sample in data_generator:
                     observations_batch, states_batch, actions_batch, \
@@ -217,10 +221,12 @@ def main():
                             adv_targ = sample
 
                     # Reshape to do in a single forward pass for all steps
-                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(observations_batch),
-                                                                                                   Variable(states_batch),
-                                                                                                   Variable(masks_batch),
-                                                                                                   Variable(actions_batch))
+                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
+                        Variable(observations_batch),
+                        Variable(states_batch),
+                        Variable(masks_batch),
+                        Variable(actions_batch)
+                    )
 
                     adv_targ = Variable(adv_targ)
                     ratio = torch.exp(action_log_probs - Variable(old_action_log_probs_batch))
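
The reshape in the a2c/acktr branch above (and the matching view calls on values and action_log_probs) is what lets evaluate_actions process every stored step in a single forward pass; with the simplified recurrent path, states[:-1] flattens to the same number of rows as observations[:-1], so the batch sizes satisfy the new assert in model.py. A rough standalone sketch of that flattening, using hypothetical shapes and plain tensors rather than the deprecated Variable wrapper:

    import torch

    # Hypothetical rollout dimensions, chosen only for illustration.
    num_steps, num_processes = 5, 16
    obs_shape = (4, 84, 84)

    # The rollout storage keeps num_steps + 1 observations; the extra entry is
    # the bootstrap observation used to compute next_value.
    observations = torch.randn(num_steps + 1, num_processes, *obs_shape)
    actions = torch.zeros(num_steps, num_processes, 1, dtype=torch.long)

    # Drop the bootstrap entry, then merge the step and process dimensions
    # into one batch dimension so a single forward pass covers all steps.
    flat_obs = observations[:-1].view(-1, *obs_shape)   # [num_steps * num_processes, 4, 84, 84]
    flat_actions = actions.view(-1, 1)                  # [num_steps * num_processes, 1]

    # The critic output comes back flat and is reshaped before the loss,
    # mirroring values.view(args.num_steps, args.num_processes, 1) above.
    values = torch.randn(flat_obs.size(0), 1)           # stand-in for the critic output
    values = values.view(num_steps, num_processes, 1)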

pytorch_rl/model.py (+2 -13)

@@ -82,19 +82,8 @@ class RecMLPPolicy(FFPolicy):
         x = self.a_fc2(x)
         x = F.tanh(x)
 
-        if hasattr(self, 'gru'):
-            if inputs.size(0) == states.size(0):
-                x = states = self.gru(x, states * masks)
-            else:
-                x = x.view(-1, states.size(0), x.size(1))
-                masks = masks.view(-1, states.size(0), 1)
-                outputs = []
-                # For every element in the batch
-                for i in range(x.size(0)):
-                    hx = states = self.gru(x[i], states * masks[i])
-                    outputs.append(hx)
-                x = torch.cat(outputs, 0)
-
+        assert inputs.size(0) == states.size(0)
+        x = states = self.gru(x, states * masks)
         actions = x
 
         x = self.v_fc1(inputs)
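
A minimal standalone sketch (not from this repo; sizes are hypothetical) of the masking trick the kept branch relies on: multiplying the carried-over hidden state by the done-mask zeroes it for any process whose episode just ended, so that process restarts from a fresh recurrent state while the others continue.

    import torch
    import torch.nn as nn

    # Hypothetical sizes: 4 parallel processes, 64 hidden units.
    gru = nn.GRUCell(input_size=64, hidden_size=64)

    x = torch.randn(4, 64)       # features for one timestep, one row per process
    states = torch.randn(4, 64)  # hidden states carried over from the previous step
    masks = torch.tensor([[1.0], [0.0], [1.0], [1.0]])  # 0 where an episode just ended

    # states * masks broadcasts the [4, 1] mask over the [4, 64] hidden state,
    # resetting row 1 to zeros before the recurrent update.
    states = gru(x, states * masks)
    print(states.shape)  # torch.Size([4, 64])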