model.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import operator
  2. from functools import reduce
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from distributions import Categorical, DiagGaussian
  7. from utils import orthogonal
  8. class FFPolicy(nn.Module):
  9. def __init__(self):
  10. super().__init__()
  11. def forward(self, inputs, states, masks):
  12. raise NotImplementedError
  13. def act(self, inputs, states, masks, deterministic=False):
  14. value, x, states = self(inputs, states, masks)
  15. action = self.dist.sample(x, deterministic=deterministic)
  16. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, action)
  17. return value, action, action_log_probs, states
  18. def evaluate_actions(self, inputs, states, masks, actions):
  19. value, x, states = self(inputs, states, masks)
  20. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
  21. return value, action_log_probs, dist_entropy, states
  22. def weights_init_mlp(m):
  23. classname = m.__class__.__name__
  24. if classname.find('Linear') != -1:
  25. nn.init.xavier_normal(m.weight)
  26. if m.bias is not None:
  27. m.bias.data.fill_(0)
  28. class Policy(FFPolicy):
  29. def __init__(self, num_inputs, action_space):
  30. super().__init__()
  31. self.action_space = action_space
  32. assert action_space.__class__.__name__ == "Discrete"
  33. num_outputs = action_space.n
  34. self.fc1 = nn.Linear(num_inputs, 128)
  35. self.fc2 = nn.Linear(128, 128)
  36. # Input size, hidden state size
  37. self.gru = nn.GRUCell(128, 128)
  38. self.a_fc1 = nn.Linear(128, 128)
  39. self.a_fc2 = nn.Linear(128, 128)
  40. self.dist = Categorical(128, num_outputs)
  41. self.v_fc1 = nn.Linear(128, 128)
  42. self.v_fc2 = nn.Linear(128, 128)
  43. self.v_fc3 = nn.Linear(128, 1)
  44. self.train()
  45. self.reset_parameters()
  46. @property
  47. def state_size(self):
  48. """
  49. Size of the recurrent state of the model (propagated between steps)
  50. """
  51. return 128
  52. def reset_parameters(self):
  53. self.apply(weights_init_mlp)
  54. orthogonal(self.gru.weight_ih.data)
  55. orthogonal(self.gru.weight_hh.data)
  56. self.gru.bias_ih.data.fill_(0)
  57. self.gru.bias_hh.data.fill_(0)
  58. if self.dist.__class__.__name__ == "DiagGaussian":
  59. self.dist.fc_mean.weight.data.mul_(0.01)
  60. def forward(self, inputs, states, masks):
  61. batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
  62. inputs = inputs.view(-1, batch_numel)
  63. x = self.fc1(inputs)
  64. x = F.tanh(x)
  65. x = self.fc2(x)
  66. x = F.tanh(x)
  67. assert inputs.size(0) == states.size(0)
  68. states = self.gru(x, states * masks)
  69. x = self.a_fc1(states)
  70. x = F.tanh(x)
  71. x = self.a_fc2(x)
  72. actions = x
  73. x = self.v_fc1(states)
  74. x = F.tanh(x)
  75. x = self.v_fc2(x)
  76. x = F.tanh(x)
  77. x = self.v_fc3(x)
  78. value = x
  79. return value, actions, states