Browse Source

Added GRU to policy, made model larger.

Maxime Chevalier-Boisvert 7 years ago
parent
commit
5a6461ff2e
3 changed files with 28 additions and 176 deletions
  1. pytorch_rl/arguments.py (+1 -1)
  2. pytorch_rl/main.py (+3 -15)
  3. pytorch_rl/model.py (+24 -160)

+ 1 - 1
pytorch_rl/arguments.py

@@ -55,7 +55,7 @@ def get_args():
                         help='directory to save agent logs (default: ./trained_models/)')
     parser.add_argument('--no-cuda', action='store_true', default=False,
                         help='disables CUDA training')
-    parser.add_argument('--recurrent-policy', action='store_true', default=False,
+    parser.add_argument('--recurrent-policy', action='store_true', default=True,
                         help='use a recurrent policy')
     parser.add_argument('--no-vis', action='store_true', default=False,
                         help='disables visdom visualization')

+ 3 - 15
pytorch_rl/main.py

@@ -18,7 +18,7 @@ from vec_env.dummy_vec_env import DummyVecEnv
 from vec_env.subproc_vec_env import SubprocVecEnv
 from envs import make_env
 from kfac import KFACOptimizer
-from model import RecMLPPolicy, MLPPolicy, CNNPolicy
+from model import Policy
 from storage import RolloutStorage
 from visualize import visdom_plot
 
@@ -53,30 +53,18 @@ def main():
         viz = Visdom()
         win = None
 
-    envs = [make_env(args.env_name, args.seed, i, args.log_dir)
-                for i in range(args.num_processes)]
+    envs = [make_env(args.env_name, args.seed, i, args.log_dir) for i in range(args.num_processes)]
 
     if args.num_processes > 1:
         envs = SubprocVecEnv(envs)
     else:
         envs = DummyVecEnv(envs)
 
-    # Maxime: commented this out because it very much changes the behavior
-    # of the code for seemingly arbitrary reasons
-    #if len(envs.observation_space.shape) == 1:
-    #    envs = VecNormalize(envs)
-
     obs_shape = envs.observation_space.shape
     obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
-
     obs_numel = reduce(operator.mul, obs_shape, 1)
 
-    if len(obs_shape) == 3 and obs_numel > 1024:
-        actor_critic = CNNPolicy(obs_shape[0], envs.action_space, args.recurrent_policy)
-    elif args.recurrent_policy:
-        actor_critic = RecMLPPolicy(obs_numel, envs.action_space)
-    else:
-        actor_critic = MLPPolicy(obs_numel, envs.action_space)
+    actor_critic = Policy(obs_numel, envs.action_space)
 
     # Maxime: log some info about the model and its size
     modelSize = 0

+ 24 - 160
pytorch_rl/model.py

@@ -8,7 +8,7 @@ from utils import orthogonal
 
 class FFPolicy(nn.Module):
     def __init__(self):
-        super(FFPolicy, self).__init__()
+        super().__init__()
 
     def forward(self, inputs, states, masks):
         raise NotImplementedError
@@ -27,30 +27,31 @@ class FFPolicy(nn.Module):
 def weights_init_mlp(m):
     classname = m.__class__.__name__
     if classname.find('Linear') != -1:
-        m.weight.data.normal_(0, 1)
-        m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
+        nn.init.xavier_normal(m.weight)
         if m.bias is not None:
             m.bias.data.fill_(0)
 
-class RecMLPPolicy(FFPolicy):
+class Policy(FFPolicy):
     def __init__(self, num_inputs, action_space):
-        super(RecMLPPolicy, self).__init__()
+        super().__init__()
 
         self.action_space = action_space
         assert action_space.__class__.__name__ == "Discrete"
         num_outputs = action_space.n
 
-        self.a_fc1 = nn.Linear(num_inputs, 64)
-        self.a_fc2 = nn.Linear(64, 64)
+        self.fc1 = nn.Linear(num_inputs, 128)
+        self.fc2 = nn.Linear(128, 128)
 
-        self.v_fc1 = nn.Linear(num_inputs, 64)
-        self.v_fc2 = nn.Linear(64, 64)
-        self.v_fc3 = nn.Linear(64, 1)
+        # Input size, hidden state size
+        self.gru = nn.GRUCell(128, 128)
 
-        # Input size, hidden size
-        self.gru = nn.GRUCell(64, 64)
+        self.a_fc1 = nn.Linear(128, 128)
+        self.a_fc2 = nn.Linear(128, 128)
+        self.dist = Categorical(128, num_outputs)
 
-        self.dist = Categorical(64, num_outputs)
+        self.v_fc1 = nn.Linear(128, 128)
+        self.v_fc2 = nn.Linear(128, 128)
+        self.v_fc3 = nn.Linear(128, 1)
 
         self.train()
         self.reset_parameters()
@@ -58,9 +59,9 @@ class RecMLPPolicy(FFPolicy):
     @property
     def state_size(self):
         """
-        Size of the recurrent state of the model (propagated between steps
+        Size of the recurrent state of the model (propagated between steps)
         """
-        return 64
+        return 128
 
     def reset_parameters(self):
         self.apply(weights_init_mlp)
@@ -77,161 +78,24 @@ class RecMLPPolicy(FFPolicy):
         batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
         inputs = inputs.view(-1, batch_numel)
 
-        x = self.a_fc1(inputs)
+        x = self.fc1(inputs)
         x = F.tanh(x)
-        x = self.a_fc2(x)
+        x = self.fc2(x)
         x = F.tanh(x)
 
         assert inputs.size(0) == states.size(0)
-        x = states = self.gru(x, states * masks)
-        actions = x
+        states = self.gru(x, states * masks)
 
-        x = self.v_fc1(inputs)
+        x = self.a_fc1(states)
         x = F.tanh(x)
-        x = self.v_fc2(x)
-        x = F.tanh(x)
-        x = self.v_fc3(x)
-        value = x
-
-        return value, actions, states
-
-class MLPPolicy(FFPolicy):
-    def __init__(self, num_inputs, action_space):
-        super(MLPPolicy, self).__init__()
-
-        self.action_space = action_space
-
-        self.a_fc1 = nn.Linear(num_inputs, 64)
-        self.a_fc2 = nn.Linear(64, 64)
-
-        self.v_fc1 = nn.Linear(num_inputs, 64)
-        self.v_fc2 = nn.Linear(64, 64)
-        self.v_fc3 = nn.Linear(64, 1)
-
-        if action_space.__class__.__name__ == "Discrete":
-            num_outputs = action_space.n
-            self.dist = Categorical(64, num_outputs)
-        elif action_space.__class__.__name__ == "Box":
-            num_outputs = action_space.shape[0]
-            self.dist = DiagGaussian(64, num_outputs)
-        else:
-            raise NotImplementedError
-
-        self.train()
-        self.reset_parameters()
-
-    @property
-    def state_size(self):
-        return 1
-
-    def reset_parameters(self):
-        self.apply(weights_init_mlp)
-
-        """
-        tanh_gain = nn.init.calculate_gain('tanh')
-        self.a_fc1.weight.data.mul_(tanh_gain)
-        self.a_fc2.weight.data.mul_(tanh_gain)
-        self.v_fc1.weight.data.mul_(tanh_gain)
-        self.v_fc2.weight.data.mul_(tanh_gain)
-        """
-
-        if self.dist.__class__.__name__ == "DiagGaussian":
-            self.dist.fc_mean.weight.data.mul_(0.01)
-
-    def forward(self, inputs, states, masks):
-        batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
-        inputs = inputs.view(-1, batch_numel)
+        x = self.a_fc2(x)
+        actions = x
 
-        x = self.v_fc1(inputs)
+        x = self.v_fc1(states)
         x = F.tanh(x)
-
         x = self.v_fc2(x)
         x = F.tanh(x)
-
         x = self.v_fc3(x)
         value = x
 
-        x = self.a_fc1(inputs)
-        x = F.tanh(x)
-
-        x = self.a_fc2(x)
-        x = F.tanh(x)
-
-        return value, x, states
-
-def weights_init_cnn(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
-        orthogonal(m.weight.data)
-        if m.bias is not None:
-            m.bias.data.fill_(0)
-
-class CNNPolicy(FFPolicy):
-    def __init__(self, num_inputs, action_space, use_gru):
-        super(CNNPolicy, self).__init__()
-        self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
-        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
-        self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
-
-        self.linear1 = nn.Linear(32 * 7 * 7, 512)
-
-        if use_gru:
-            self.gru = nn.GRUCell(512, 512)
-
-        self.critic_linear = nn.Linear(512, 1)
-
-        if action_space.__class__.__name__ == "Discrete":
-            num_outputs = action_space.n
-            self.dist = Categorical(512, num_outputs)
-        elif action_space.__class__.__name__ == "Box":
-            num_outputs = action_space.shape[0]
-            self.dist = DiagGaussian(512, num_outputs)
-        else:
-            raise NotImplementedError
-
-        self.train()
-        self.reset_parameters()
-
-    @property
-    def state_size(self):
-        if hasattr(self, 'gru'):
-            return 512
-        else:
-            return 1
-
-    def reset_parameters(self):
-        self.apply(weights_init_cnn)
-
-        relu_gain = nn.init.calculate_gain('relu')
-        self.conv1.weight.data.mul_(relu_gain)
-        self.conv2.weight.data.mul_(relu_gain)
-        self.conv3.weight.data.mul_(relu_gain)
-        self.linear1.weight.data.mul_(relu_gain)
-
-        if hasattr(self, 'gru'):
-            orthogonal(self.gru.weight_ih.data)
-            orthogonal(self.gru.weight_hh.data)
-            self.gru.bias_ih.data.fill_(0)
-            self.gru.bias_hh.data.fill_(0)
-
-        if self.dist.__class__.__name__ == "DiagGaussian":
-            self.dist.fc_mean.weight.data.mul_(0.01)
-
-    def forward(self, inputs, states, masks):
-        x = self.conv1(inputs / 255.0)
-        x = F.relu(x)
-
-        x = self.conv2(x)
-        x = F.relu(x)
-
-        x = self.conv3(x)
-        x = F.relu(x)
-
-        x = x.view(-1, 32 * 7 * 7)
-        x = self.linear1(x)
-        x = F.relu(x)
-
-        if hasattr(self, 'gru'):
-            x = states = self.gru(x, states * masks)
-
-        return self.critic_linear(x), x, states
+        return value, actions, states