model.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import operator
  2. from functools import reduce
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from distributions import Categorical, DiagGaussian
  7. from utils import orthogonal
  8. def weights_init(m):
  9. classname = m.__class__.__name__
  10. if classname.find('Conv') != -1 or classname.find('Linear') != -1:
  11. orthogonal(m.weight.data)
  12. if m.bias is not None:
  13. m.bias.data.fill_(0)
  14. class FFPolicy(nn.Module):
  15. def __init__(self):
  16. super(FFPolicy, self).__init__()
  17. def forward(self, inputs, states, masks):
  18. raise NotImplementedError
  19. def act(self, inputs, states, masks, deterministic=False):
  20. value, x, states = self(inputs, states, masks)
  21. action = self.dist.sample(x, deterministic=deterministic)
  22. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, action)
  23. return value, action, action_log_probs, states
  24. def evaluate_actions(self, inputs, states, masks, actions):
  25. value, x, states = self(inputs, states, masks)
  26. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
  27. return value, action_log_probs, dist_entropy, states
  28. class CNNPolicy(FFPolicy):
  29. def __init__(self, num_inputs, action_space, use_gru):
  30. super(CNNPolicy, self).__init__()
  31. self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
  32. self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
  33. self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
  34. self.linear1 = nn.Linear(32 * 7 * 7, 512)
  35. if use_gru:
  36. self.gru = nn.GRUCell(512, 512)
  37. self.critic_linear = nn.Linear(512, 1)
  38. if action_space.__class__.__name__ == "Discrete":
  39. num_outputs = action_space.n
  40. self.dist = Categorical(512, num_outputs)
  41. elif action_space.__class__.__name__ == "Box":
  42. num_outputs = action_space.shape[0]
  43. self.dist = DiagGaussian(512, num_outputs)
  44. else:
  45. raise NotImplementedError
  46. self.train()
  47. self.reset_parameters()
  48. @property
  49. def state_size(self):
  50. if hasattr(self, 'gru'):
  51. return 512
  52. else:
  53. return 1
  54. def reset_parameters(self):
  55. self.apply(weights_init)
  56. relu_gain = nn.init.calculate_gain('relu')
  57. self.conv1.weight.data.mul_(relu_gain)
  58. self.conv2.weight.data.mul_(relu_gain)
  59. self.conv3.weight.data.mul_(relu_gain)
  60. self.linear1.weight.data.mul_(relu_gain)
  61. if hasattr(self, 'gru'):
  62. orthogonal(self.gru.weight_ih.data)
  63. orthogonal(self.gru.weight_hh.data)
  64. self.gru.bias_ih.data.fill_(0)
  65. self.gru.bias_hh.data.fill_(0)
  66. if self.dist.__class__.__name__ == "DiagGaussian":
  67. self.dist.fc_mean.weight.data.mul_(0.01)
  68. def forward(self, inputs, states, masks):
  69. x = self.conv1(inputs / 255.0)
  70. x = F.relu(x)
  71. x = self.conv2(x)
  72. x = F.relu(x)
  73. x = self.conv3(x)
  74. x = F.relu(x)
  75. x = x.view(-1, 32 * 7 * 7)
  76. x = self.linear1(x)
  77. x = F.relu(x)
  78. if hasattr(self, 'gru'):
  79. if inputs.size(0) == states.size(0):
  80. x = states = self.gru(x, states * masks)
  81. else:
  82. x = x.view(-1, states.size(0), x.size(1))
  83. masks = masks.view(-1, states.size(0), 1)
  84. outputs = []
  85. for i in range(x.size(0)):
  86. hx = states = self.gru(x[i], states * masks[i])
  87. outputs.append(hx)
  88. x = torch.cat(outputs, 0)
  89. return self.critic_linear(x), x, states
  90. def weights_init_mlp(m):
  91. classname = m.__class__.__name__
  92. if classname.find('Linear') != -1:
  93. m.weight.data.normal_(0, 1)
  94. m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
  95. if m.bias is not None:
  96. m.bias.data.fill_(0)
  97. class MLPPolicy(FFPolicy):
  98. def __init__(self, num_inputs, action_space):
  99. super(MLPPolicy, self).__init__()
  100. self.action_space = action_space
  101. self.a_fc1 = nn.Linear(num_inputs, 64)
  102. self.a_fc2 = nn.Linear(64, 64)
  103. self.v_fc1 = nn.Linear(num_inputs, 64)
  104. self.v_fc2 = nn.Linear(64, 64)
  105. self.v_fc3 = nn.Linear(64, 1)
  106. if action_space.__class__.__name__ == "Discrete":
  107. num_outputs = action_space.n
  108. self.dist = Categorical(64, num_outputs)
  109. elif action_space.__class__.__name__ == "Box":
  110. num_outputs = action_space.shape[0]
  111. self.dist = DiagGaussian(64, num_outputs)
  112. else:
  113. raise NotImplementedError
  114. self.train()
  115. self.reset_parameters()
  116. @property
  117. def state_size(self):
  118. return 1
  119. def reset_parameters(self):
  120. self.apply(weights_init_mlp)
  121. """
  122. tanh_gain = nn.init.calculate_gain('tanh')
  123. self.a_fc1.weight.data.mul_(tanh_gain)
  124. self.a_fc2.weight.data.mul_(tanh_gain)
  125. self.v_fc1.weight.data.mul_(tanh_gain)
  126. self.v_fc2.weight.data.mul_(tanh_gain)
  127. """
  128. if self.dist.__class__.__name__ == "DiagGaussian":
  129. self.dist.fc_mean.weight.data.mul_(0.01)
  130. def forward(self, inputs, states, masks):
  131. batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
  132. inputs = inputs.view(-1, batch_numel)
  133. x = self.v_fc1(inputs)
  134. x = F.tanh(x)
  135. x = self.v_fc2(x)
  136. x = F.tanh(x)
  137. x = self.v_fc3(x)
  138. value = x
  139. x = self.a_fc1(inputs)
  140. x = F.tanh(x)
  141. x = self.a_fc2(x)
  142. x = F.tanh(x)
  143. return value, x, states