Prechádzať zdrojové kódy

RNN training working

Maxime Chevalier-Boisvert pred 7 rokmi
rodič
commit
ce5417474e
3 zmenené súbory, v ktorých bolo vykonaných 23 pridaní a 39 odobraní
  1. 2 2
      pytorch_rl/arguments.py
  2. 21 8
      pytorch_rl/model.py
  3. 0 29
      pytorch_rl/utils.py

+ 2 - 2
pytorch_rl/arguments.py

@@ -45,8 +45,8 @@ def get_args():
                         help='save interval, one save per n updates (default: 10)')
     parser.add_argument('--vis-interval', type=int, default=100,
                         help='vis interval, one log per n updates (default: 100)')
-    parser.add_argument('--num-frames', type=int, default=10e6,
-                        help='number of frames to train (default: 10e6)')
+    parser.add_argument('--num-frames', type=int, default=10e7,
+                        help='number of frames to train (default: 10e7)')
     parser.add_argument('--env-name', default='PongNoFrameskip-v4',
                         help='environment to train on (default: PongNoFrameskip-v4)')
     parser.add_argument('--log-dir', default='/tmp/gym/',

+ 21 - 8
pytorch_rl/model.py

@@ -3,8 +3,7 @@ from functools import reduce
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from distributions import Categorical, DiagGaussian
-from utils import orthogonal
+from distributions import Categorical
 
 class FFPolicy(nn.Module):
     def __init__(self):
@@ -39,10 +38,13 @@ class Policy(FFPolicy):
         assert action_space.__class__.__name__ == "Discrete"
         num_outputs = action_space.n
 
-        self.fc1 = nn.Linear(num_inputs, 128)
+        # Input size, hidden size, num layers
+        self.textGRU = nn.GRU(27, 128, 1)
+
+        self.fc1 = nn.Linear(8339, 128)
         self.fc2 = nn.Linear(128, 128)
 
-        # Input size, hidden state size
+        # Input size, hidden size
         self.gru = nn.GRUCell(128, 128)
 
         self.a_fc1 = nn.Linear(128, 128)
@@ -66,8 +68,8 @@ class Policy(FFPolicy):
     def reset_parameters(self):
         self.apply(weights_init_mlp)
 
-        orthogonal(self.gru.weight_ih.data)
-        orthogonal(self.gru.weight_hh.data)
+        nn.init.orthogonal(self.gru.weight_ih.data)
+        nn.init.orthogonal(self.gru.weight_hh.data)
         self.gru.bias_ih.data.fill_(0)
         self.gru.bias_hh.data.fill_(0)
 
@@ -75,8 +77,19 @@ class Policy(FFPolicy):
             self.dist.fc_mean.weight.data.mul_(0.01)
 
     def forward(self, inputs, states, masks):
-        batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
-        inputs = inputs.view(-1, batch_numel)
+        batch_size = inputs.size()[0]
+        image = inputs[:, 0, :147]
+        text = inputs[:, 0, 147:]
+
+        # input (seq_len, batch, input_size)
+        # output (seq_len, batch, hidden_size * num_directions)
+        text = text.contiguous().view(batch_size, -1, 27)
+        text = text.transpose(0, 1)
+        output, hn = self.textGRU(text)
+        output = output.transpose(0, 1)
+        output = output.contiguous().view(batch_size, -1)
+
+        inputs = torch.cat((image, output), dim=1)
 
         x = self.fc1(inputs)
         x = F.tanh(x)

+ 0 - 29
pytorch_rl/utils.py

@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 
-
 # Necessary for my KFAC implementation.
 class AddBias(nn.Module):
     def __init__(self, bias):
@@ -15,31 +14,3 @@ class AddBias(nn.Module):
             bias = self._bias.t().view(1, -1, 1, 1)
 
         return x + bias
-
-# A temporary solution from the master branch.
-# https://github.com/pytorch/pytorch/blob/7752fe5d4e50052b3b0bbc9109e599f8157febc0/torch/nn/init.py#L312
-# Remove after the next version of PyTorch gets released.
-def orthogonal(tensor, gain=1):
-    if tensor.ndimension() < 2:
-        raise ValueError("Only tensors with 2 or more dimensions are supported")
-
-    rows = tensor.size(0)
-    cols = tensor[0].numel()
-    flattened = torch.Tensor(rows, cols).normal_(0, 1)
-
-    if rows < cols:
-        flattened.t_()
-
-    # Compute the qr factorization
-    q, r = torch.qr(flattened)
-    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
-    d = torch.diag(r, 0)
-    ph = d.sign()
-    q *= ph.expand_as(q)
-
-    if rows < cols:
-        q.t_()
-
-    tensor.view_as(q).copy_(q)
-    tensor.mul_(gain)
-    return tensor