|
@@ -3,8 +3,7 @@ from functools import reduce
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
-from distributions import Categorical, DiagGaussian
|
|
|
-from utils import orthogonal
|
|
|
+from distributions import Categorical
|
|
|
|
|
|
class FFPolicy(nn.Module):
|
|
|
def __init__(self):
|
|
@@ -39,10 +38,13 @@ class Policy(FFPolicy):
|
|
|
assert action_space.__class__.__name__ == "Discrete"
|
|
|
num_outputs = action_space.n
|
|
|
|
|
|
- self.fc1 = nn.Linear(num_inputs, 128)
|
|
|
+ # Input size, hidden size, num layers
|
|
|
+ self.textGRU = nn.GRU(27, 128, 1)
|
|
|
+
|
|
|
+ self.fc1 = nn.Linear(8339, 128)
|
|
|
self.fc2 = nn.Linear(128, 128)
|
|
|
|
|
|
- # Input size, hidden state size
|
|
|
+ # Input size, hidden size
|
|
|
self.gru = nn.GRUCell(128, 128)
|
|
|
|
|
|
self.a_fc1 = nn.Linear(128, 128)
|
|
@@ -66,8 +68,8 @@ class Policy(FFPolicy):
|
|
|
def reset_parameters(self):
|
|
|
self.apply(weights_init_mlp)
|
|
|
|
|
|
- orthogonal(self.gru.weight_ih.data)
|
|
|
- orthogonal(self.gru.weight_hh.data)
|
|
|
+ nn.init.orthogonal(self.gru.weight_ih.data)
|
|
|
+ nn.init.orthogonal(self.gru.weight_hh.data)
|
|
|
self.gru.bias_ih.data.fill_(0)
|
|
|
self.gru.bias_hh.data.fill_(0)
|
|
|
|
|
@@ -75,8 +77,19 @@ class Policy(FFPolicy):
|
|
|
self.dist.fc_mean.weight.data.mul_(0.01)
|
|
|
|
|
|
def forward(self, inputs, states, masks):
|
|
|
- batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
|
|
|
- inputs = inputs.view(-1, batch_numel)
|
|
|
+ batch_size = inputs.size()[0]
|
|
|
+ image = inputs[:, 0, :147]
|
|
|
+ text = inputs[:, 0, 147:]
|
|
|
+
|
|
|
+ # input (seq_len, batch, input_size)
|
|
|
+ # output (seq_len, batch, hidden_size * num_directions)
|
|
|
+ text = text.contiguous().view(batch_size, -1, 27)
|
|
|
+ text = text.transpose(0, 1)
|
|
|
+ output, hn = self.textGRU(text)
|
|
|
+ output = output.transpose(0, 1)
|
|
|
+ output = output.contiguous().view(batch_size, -1)
|
|
|
+
|
|
|
+ inputs = torch.cat((image, output), dim=1)
|
|
|
|
|
|
x = self.fc1(inputs)
|
|
|
x = F.tanh(x)
|