model.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import operator
  2. from functools import reduce
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from distributions import Categorical
  7. class FFPolicy(nn.Module):
  8. def __init__(self):
  9. super().__init__()
  10. def forward(self, inputs, states, masks):
  11. raise NotImplementedError
  12. def act(self, inputs, states, masks, deterministic=False):
  13. value, x, states = self(inputs, states, masks)
  14. action = self.dist.sample(x, deterministic=deterministic)
  15. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, action)
  16. return value, action, action_log_probs, states
  17. def evaluate_actions(self, inputs, states, masks, actions):
  18. value, x, states = self(inputs, states, masks)
  19. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
  20. return value, action_log_probs, dist_entropy, states
  21. def weights_init_mlp(m):
  22. classname = m.__class__.__name__
  23. if classname.find('Linear') != -1:
  24. nn.init.xavier_normal(m.weight)
  25. if m.bias is not None:
  26. m.bias.data.fill_(0)
  27. class Policy(FFPolicy):
  28. def __init__(self, num_inputs, action_space):
  29. super().__init__()
  30. self.action_space = action_space
  31. assert action_space.__class__.__name__ == "Discrete"
  32. num_outputs = action_space.n
  33. # Input size, hidden size, num layers
  34. self.textGRU = nn.GRU(27, 128, 1)
  35. self.fc1 = nn.Linear(8339, 128)
  36. self.fc2 = nn.Linear(128, 128)
  37. # Input size, hidden size
  38. self.gru = nn.GRUCell(128, 128)
  39. self.a_fc1 = nn.Linear(128, 128)
  40. self.a_fc2 = nn.Linear(128, 128)
  41. self.dist = Categorical(128, num_outputs)
  42. self.v_fc1 = nn.Linear(128, 128)
  43. self.v_fc2 = nn.Linear(128, 128)
  44. self.v_fc3 = nn.Linear(128, 1)
  45. self.train()
  46. self.reset_parameters()
  47. @property
  48. def state_size(self):
  49. """
  50. Size of the recurrent state of the model (propagated between steps)
  51. """
  52. return 128
  53. def reset_parameters(self):
  54. self.apply(weights_init_mlp)
  55. nn.init.orthogonal(self.gru.weight_ih.data)
  56. nn.init.orthogonal(self.gru.weight_hh.data)
  57. self.gru.bias_ih.data.fill_(0)
  58. self.gru.bias_hh.data.fill_(0)
  59. if self.dist.__class__.__name__ == "DiagGaussian":
  60. self.dist.fc_mean.weight.data.mul_(0.01)
  61. def forward(self, inputs, states, masks):
  62. batch_size = inputs.size()[0]
  63. image = inputs[:, 0, :147]
  64. text = inputs[:, 0, 147:]
  65. # input (seq_len, batch, input_size)
  66. # output (seq_len, batch, hidden_size * num_directions)
  67. text = text.contiguous().view(batch_size, -1, 27)
  68. text = text.transpose(0, 1)
  69. output, hn = self.textGRU(text)
  70. output = output.transpose(0, 1)
  71. output = output.contiguous().view(batch_size, -1)
  72. inputs = torch.cat((image, output), dim=1)
  73. x = self.fc1(inputs)
  74. x = F.tanh(x)
  75. x = self.fc2(x)
  76. x = F.tanh(x)
  77. assert inputs.size(0) == states.size(0)
  78. states = self.gru(x, states * masks)
  79. x = self.a_fc1(states)
  80. x = F.tanh(x)
  81. x = self.a_fc2(x)
  82. actions = x
  83. x = self.v_fc1(states)
  84. x = F.tanh(x)
  85. x = self.v_fc2(x)
  86. x = F.tanh(x)
  87. x = self.v_fc3(x)
  88. value = x
  89. return value, actions, states