
Added recurrent MLP policy

Maxime Chevalier-Boisvert, 7 years ago
commit ca85d1086d
2 changed files with 138 additions and 66 deletions
  1. pytorch_rl/main.py (+4, -6)
  2. pytorch_rl/model.py (+134, -60)
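
In short, this commit adds RecMLPPolicy, a recurrent (GRU-based) actor-critic MLP for flat observations with discrete actions, and lets main.py select it when args.recurrent_policy is set on a non-image observation space. model.py is also reordered: weights_init_mlp moves above the new class, and the functionally unchanged CNNPolicy (with its init helper now named weights_init_cnn) moves below the MLP policies.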

+ 4 - 6
pytorch_rl/main.py

@@ -19,7 +19,7 @@ from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
 from baselines.common.vec_env.vec_normalize import VecNormalize
 from envs import make_env
 from kfac import KFACOptimizer
-from model import CNNPolicy, MLPPolicy
+from model import RecMLPPolicy, MLPPolicy, CNNPolicy
 from storage import RolloutStorage
 from visualize import visdom_plot
 
@@ -27,8 +27,7 @@ args = get_args()
 
 assert args.algo in ['a2c', 'ppo', 'acktr']
 if args.recurrent_policy:
-    assert args.algo in ['a2c', 'ppo'], \
-        'Recurrent policy is not implemented for ACKTR'
+    assert args.algo in ['a2c', 'ppo'], 'Recurrent policy is not implemented for ACKTR'
 
 num_updates = int(args.num_frames) // args.num_steps // args.num_processes
 
@@ -43,7 +42,6 @@ except OSError:
     for f in files:
         os.remove(f)
 
-
 def main():
     print("#######")
     print("WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
@@ -76,9 +74,9 @@ def main():
 
     if len(obs_shape) == 3 and obs_numel > 1024:
         actor_critic = CNNPolicy(obs_shape[0], envs.action_space, args.recurrent_policy)
+    elif args.recurrent_policy:
+        actor_critic = RecMLPPolicy(obs_numel, envs.action_space)
     else:
-        assert not args.recurrent_policy, \
-            "Recurrent policy is not implemented for the MLP controller"
         actor_critic = MLPPolicy(obs_numel, envs.action_space)
 
     # Maxime: log some info about the model and its size
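
For readability, the selection branch in main() after this change reads roughly as below. This is a sketch assembled from the hunk above; obs_shape, obs_numel, envs, and args are defined earlier in main() and assumed here.

    # Post-commit policy selection (sketch): large 3-D image observations keep the
    # convolutional policy; a recurrent policy on flat observations now maps to the
    # new RecMLPPolicy instead of failing an assertion.
    if len(obs_shape) == 3 and obs_numel > 1024:
        actor_critic = CNNPolicy(obs_shape[0], envs.action_space, args.recurrent_policy)
    elif args.recurrent_policy:
        actor_critic = RecMLPPolicy(obs_numel, envs.action_space)
    else:
        actor_critic = MLPPolicy(obs_numel, envs.action_space)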

+ 134 - 60
pytorch_rl/model.py

@@ -6,13 +6,6 @@ import torch.nn.functional as F
 from distributions import Categorical, DiagGaussian
 from utils import orthogonal
 
-def weights_init(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
-        orthogonal(m.weight.data)
-        if m.bias is not None:
-            m.bias.data.fill_(0)
-
 class FFPolicy(nn.Module):
     def __init__(self):
         super(FFPolicy, self).__init__()
@@ -31,71 +24,63 @@ class FFPolicy(nn.Module):
         action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
         return value, action_log_probs, dist_entropy, states
 
+def weights_init_mlp(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        m.weight.data.normal_(0, 1)
+        m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
+        if m.bias is not None:
+            m.bias.data.fill_(0)
 
-class CNNPolicy(FFPolicy):
-    def __init__(self, num_inputs, action_space, use_gru):
-        super(CNNPolicy, self).__init__()
-        self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
-        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
-        self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
+class RecMLPPolicy(FFPolicy):
+    def __init__(self, num_inputs, action_space):
+        super(RecMLPPolicy, self).__init__()
 
-        self.linear1 = nn.Linear(32 * 7 * 7, 512)
+        self.action_space = action_space
+        assert action_space.__class__.__name__ == "Discrete"
+        num_outputs = action_space.n
 
-        if use_gru:
-            self.gru = nn.GRUCell(512, 512)
+        self.a_fc1 = nn.Linear(num_inputs, 64)
+        self.a_fc2 = nn.Linear(64, 64)
 
-        self.critic_linear = nn.Linear(512, 1)
+        self.v_fc1 = nn.Linear(num_inputs, 64)
+        self.v_fc2 = nn.Linear(64, 64)
+        self.v_fc3 = nn.Linear(64, 1)
 
-        if action_space.__class__.__name__ == "Discrete":
-            num_outputs = action_space.n
-            self.dist = Categorical(512, num_outputs)
-        elif action_space.__class__.__name__ == "Box":
-            num_outputs = action_space.shape[0]
-            self.dist = DiagGaussian(512, num_outputs)
-        else:
-            raise NotImplementedError
+        # Input size, hidden size
+        self.gru = nn.GRUCell(64, 64)
+
+        self.dist = Categorical(64, num_outputs)
 
         self.train()
         self.reset_parameters()
 
     @property
     def state_size(self):
-        if hasattr(self, 'gru'):
-            return 512
-        else:
-            return 1
+        """
+        Size of the recurrent state of the model (propagated between steps).
+        """
+        return 64
 
     def reset_parameters(self):
-        self.apply(weights_init)
-
-        relu_gain = nn.init.calculate_gain('relu')
-        self.conv1.weight.data.mul_(relu_gain)
-        self.conv2.weight.data.mul_(relu_gain)
-        self.conv3.weight.data.mul_(relu_gain)
-        self.linear1.weight.data.mul_(relu_gain)
+        self.apply(weights_init_mlp)
 
-        if hasattr(self, 'gru'):
-            orthogonal(self.gru.weight_ih.data)
-            orthogonal(self.gru.weight_hh.data)
-            self.gru.bias_ih.data.fill_(0)
-            self.gru.bias_hh.data.fill_(0)
+        orthogonal(self.gru.weight_ih.data)
+        orthogonal(self.gru.weight_hh.data)
+        self.gru.bias_ih.data.fill_(0)
+        self.gru.bias_hh.data.fill_(0)
 
         if self.dist.__class__.__name__ == "DiagGaussian":
             self.dist.fc_mean.weight.data.mul_(0.01)
 
     def forward(self, inputs, states, masks):
-        x = self.conv1(inputs / 255.0)
-        x = F.relu(x)
-
-        x = self.conv2(x)
-        x = F.relu(x)
-
-        x = self.conv3(x)
-        x = F.relu(x)
+        batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
+        inputs = inputs.view(-1, batch_numel)
 
-        x = x.view(-1, 32 * 7 * 7)
-        x = self.linear1(x)
-        x = F.relu(x)
+        x = self.a_fc1(inputs)
+        x = F.tanh(x)
+        x = self.a_fc2(x)
+        x = F.tanh(x)
 
         if hasattr(self, 'gru'):
             if inputs.size(0) == states.size(0):
@@ -104,19 +89,22 @@ class CNNPolicy(FFPolicy):
                 x = x.view(-1, states.size(0), x.size(1))
                 masks = masks.view(-1, states.size(0), 1)
                 outputs = []
+                # Step through the sequence one timestep at a time
                 for i in range(x.size(0)):
                     hx = states = self.gru(x[i], states * masks[i])
                     outputs.append(hx)
                 x = torch.cat(outputs, 0)
-        return self.critic_linear(x), x, states
 
-def weights_init_mlp(m):
-    classname = m.__class__.__name__
-    if classname.find('Linear') != -1:
-        m.weight.data.normal_(0, 1)
-        m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
-        if m.bias is not None:
-            m.bias.data.fill_(0)
+        actions = x
+
+        x = self.v_fc1(inputs)
+        x = F.tanh(x)
+        x = self.v_fc2(x)
+        x = F.tanh(x)
+        x = self.v_fc3(x)
+        value = x
+
+        return value, actions, states
 
 class MLPPolicy(FFPolicy):
     def __init__(self, num_inputs, action_space):
@@ -181,3 +169,89 @@ class MLPPolicy(FFPolicy):
         x = F.tanh(x)
 
         return value, x, states
+
+def weights_init_cnn(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
+        orthogonal(m.weight.data)
+        if m.bias is not None:
+            m.bias.data.fill_(0)
+
+class CNNPolicy(FFPolicy):
+    def __init__(self, num_inputs, action_space, use_gru):
+        super(CNNPolicy, self).__init__()
+        self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
+        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
+        self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
+
+        self.linear1 = nn.Linear(32 * 7 * 7, 512)
+
+        if use_gru:
+            self.gru = nn.GRUCell(512, 512)
+
+        self.critic_linear = nn.Linear(512, 1)
+
+        if action_space.__class__.__name__ == "Discrete":
+            num_outputs = action_space.n
+            self.dist = Categorical(512, num_outputs)
+        elif action_space.__class__.__name__ == "Box":
+            num_outputs = action_space.shape[0]
+            self.dist = DiagGaussian(512, num_outputs)
+        else:
+            raise NotImplementedError
+
+        self.train()
+        self.reset_parameters()
+
+    @property
+    def state_size(self):
+        if hasattr(self, 'gru'):
+            return 512
+        else:
+            return 1
+
+    def reset_parameters(self):
+        self.apply(weights_init_cnn)
+
+        relu_gain = nn.init.calculate_gain('relu')
+        self.conv1.weight.data.mul_(relu_gain)
+        self.conv2.weight.data.mul_(relu_gain)
+        self.conv3.weight.data.mul_(relu_gain)
+        self.linear1.weight.data.mul_(relu_gain)
+
+        if hasattr(self, 'gru'):
+            orthogonal(self.gru.weight_ih.data)
+            orthogonal(self.gru.weight_hh.data)
+            self.gru.bias_ih.data.fill_(0)
+            self.gru.bias_hh.data.fill_(0)
+
+        if self.dist.__class__.__name__ == "DiagGaussian":
+            self.dist.fc_mean.weight.data.mul_(0.01)
+
+    def forward(self, inputs, states, masks):
+        x = self.conv1(inputs / 255.0)
+        x = F.relu(x)
+
+        x = self.conv2(x)
+        x = F.relu(x)
+
+        x = self.conv3(x)
+        x = F.relu(x)
+
+        x = x.view(-1, 32 * 7 * 7)
+        x = self.linear1(x)
+        x = F.relu(x)
+
+        if hasattr(self, 'gru'):
+            if inputs.size(0) == states.size(0):
+                x = states = self.gru(x, states * masks)
+            else:
+                x = x.view(-1, states.size(0), x.size(1))
+                masks = masks.view(-1, states.size(0), 1)
+                outputs = []
+                for i in range(x.size(0)):
+                    hx = states = self.gru(x[i], states * masks[i])
+                    outputs.append(hx)
+                x = torch.cat(outputs, 0)
+
+        return self.critic_linear(x), x, states
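
A quick way to sanity-check the new policy in isolation is a single forward pass, as sketched below. This is not part of the commit: it assumes gym's Discrete space is available, that model.py's top-level imports (functools.reduce, operator, torch, etc.) are present as in the repository, that the snippet is run from inside pytorch_rl/, and a PyTorch version where plain tensors can be passed to modules (on the 0.3-era API this code targets, they would be wrapped in Variable).

    # Minimal smoke test for RecMLPPolicy (illustrative sketch, not repository code).
    import torch
    from gym.spaces import Discrete

    from model import RecMLPPolicy

    num_inputs = 7 * 7 * 3                        # e.g. a small flattened grid observation
    policy = RecMLPPolicy(num_inputs, Discrete(4))

    obs    = torch.randn(8, num_inputs)           # one observation per parallel process
    states = torch.zeros(8, policy.state_size)    # GRU hidden state, 64 units per process
    masks  = torch.ones(8, 1)                     # 1 = episode continues, 0 = reset the state

    value, hidden, states = policy(obs, states, masks)
    print(value.size(), hidden.size(), states.size())   # (8, 1), (8, 64), (8, 64)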