model.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import operator
  2. from functools import reduce
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from distributions import Categorical, DiagGaussian
  7. from utils import orthogonal
  8. class FFPolicy(nn.Module):
  9. def __init__(self):
  10. super(FFPolicy, self).__init__()
  11. def forward(self, inputs, states, masks):
  12. raise NotImplementedError
  13. def act(self, inputs, states, masks, deterministic=False):
  14. value, x, states = self(inputs, states, masks)
  15. action = self.dist.sample(x, deterministic=deterministic)
  16. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, action)
  17. return value, action, action_log_probs, states
  18. def evaluate_actions(self, inputs, states, masks, actions):
  19. value, x, states = self(inputs, states, masks)
  20. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
  21. return value, action_log_probs, dist_entropy, states
  22. def weights_init_mlp(m):
  23. classname = m.__class__.__name__
  24. if classname.find('Linear') != -1:
  25. m.weight.data.normal_(0, 1)
  26. m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
  27. if m.bias is not None:
  28. m.bias.data.fill_(0)
  29. class RecMLPPolicy(FFPolicy):
  30. def __init__(self, num_inputs, action_space):
  31. super(RecMLPPolicy, self).__init__()
  32. self.action_space = action_space
  33. assert action_space.__class__.__name__ == "Discrete"
  34. num_outputs = action_space.n
  35. self.a_fc1 = nn.Linear(num_inputs, 64)
  36. self.a_fc2 = nn.Linear(64, 64)
  37. self.v_fc1 = nn.Linear(num_inputs, 64)
  38. self.v_fc2 = nn.Linear(64, 64)
  39. self.v_fc3 = nn.Linear(64, 1)
  40. # Input size, hidden size
  41. self.gru = nn.GRUCell(64, 64)
  42. self.dist = Categorical(64, num_outputs)
  43. self.train()
  44. self.reset_parameters()
  45. @property
  46. def state_size(self):
  47. """
  48. Size of the recurrent state of the model (propagated between steps
  49. """
  50. return 64
  51. def reset_parameters(self):
  52. self.apply(weights_init_mlp)
  53. orthogonal(self.gru.weight_ih.data)
  54. orthogonal(self.gru.weight_hh.data)
  55. self.gru.bias_ih.data.fill_(0)
  56. self.gru.bias_hh.data.fill_(0)
  57. if self.dist.__class__.__name__ == "DiagGaussian":
  58. self.dist.fc_mean.weight.data.mul_(0.01)
  59. def forward(self, inputs, states, masks):
  60. batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
  61. inputs = inputs.view(-1, batch_numel)
  62. x = self.a_fc1(inputs)
  63. x = F.tanh(x)
  64. x = self.a_fc2(x)
  65. x = F.tanh(x)
  66. if hasattr(self, 'gru'):
  67. if inputs.size(0) == states.size(0):
  68. x = states = self.gru(x, states * masks)
  69. else:
  70. x = x.view(-1, states.size(0), x.size(1))
  71. masks = masks.view(-1, states.size(0), 1)
  72. outputs = []
  73. # For every element in the batch
  74. for i in range(x.size(0)):
  75. hx = states = self.gru(x[i], states * masks[i])
  76. outputs.append(hx)
  77. x = torch.cat(outputs, 0)
  78. actions = x
  79. x = self.v_fc1(inputs)
  80. x = F.tanh(x)
  81. x = self.v_fc2(x)
  82. x = F.tanh(x)
  83. x = self.v_fc3(x)
  84. value = x
  85. return value, actions, states
  86. class MLPPolicy(FFPolicy):
  87. def __init__(self, num_inputs, action_space):
  88. super(MLPPolicy, self).__init__()
  89. self.action_space = action_space
  90. self.a_fc1 = nn.Linear(num_inputs, 64)
  91. self.a_fc2 = nn.Linear(64, 64)
  92. self.v_fc1 = nn.Linear(num_inputs, 64)
  93. self.v_fc2 = nn.Linear(64, 64)
  94. self.v_fc3 = nn.Linear(64, 1)
  95. if action_space.__class__.__name__ == "Discrete":
  96. num_outputs = action_space.n
  97. self.dist = Categorical(64, num_outputs)
  98. elif action_space.__class__.__name__ == "Box":
  99. num_outputs = action_space.shape[0]
  100. self.dist = DiagGaussian(64, num_outputs)
  101. else:
  102. raise NotImplementedError
  103. self.train()
  104. self.reset_parameters()
  105. @property
  106. def state_size(self):
  107. return 1
  108. def reset_parameters(self):
  109. self.apply(weights_init_mlp)
  110. """
  111. tanh_gain = nn.init.calculate_gain('tanh')
  112. self.a_fc1.weight.data.mul_(tanh_gain)
  113. self.a_fc2.weight.data.mul_(tanh_gain)
  114. self.v_fc1.weight.data.mul_(tanh_gain)
  115. self.v_fc2.weight.data.mul_(tanh_gain)
  116. """
  117. if self.dist.__class__.__name__ == "DiagGaussian":
  118. self.dist.fc_mean.weight.data.mul_(0.01)
  119. def forward(self, inputs, states, masks):
  120. batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
  121. inputs = inputs.view(-1, batch_numel)
  122. x = self.v_fc1(inputs)
  123. x = F.tanh(x)
  124. x = self.v_fc2(x)
  125. x = F.tanh(x)
  126. x = self.v_fc3(x)
  127. value = x
  128. x = self.a_fc1(inputs)
  129. x = F.tanh(x)
  130. x = self.a_fc2(x)
  131. x = F.tanh(x)
  132. return value, x, states
  133. def weights_init_cnn(m):
  134. classname = m.__class__.__name__
  135. if classname.find('Conv') != -1 or classname.find('Linear') != -1:
  136. orthogonal(m.weight.data)
  137. if m.bias is not None:
  138. m.bias.data.fill_(0)
  139. class CNNPolicy(FFPolicy):
  140. def __init__(self, num_inputs, action_space, use_gru):
  141. super(CNNPolicy, self).__init__()
  142. self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
  143. self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
  144. self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
  145. self.linear1 = nn.Linear(32 * 7 * 7, 512)
  146. if use_gru:
  147. self.gru = nn.GRUCell(512, 512)
  148. self.critic_linear = nn.Linear(512, 1)
  149. if action_space.__class__.__name__ == "Discrete":
  150. num_outputs = action_space.n
  151. self.dist = Categorical(512, num_outputs)
  152. elif action_space.__class__.__name__ == "Box":
  153. num_outputs = action_space.shape[0]
  154. self.dist = DiagGaussian(512, num_outputs)
  155. else:
  156. raise NotImplementedError
  157. self.train()
  158. self.reset_parameters()
  159. @property
  160. def state_size(self):
  161. if hasattr(self, 'gru'):
  162. return 512
  163. else:
  164. return 1
  165. def reset_parameters(self):
  166. self.apply(weights_init_cnn)
  167. relu_gain = nn.init.calculate_gain('relu')
  168. self.conv1.weight.data.mul_(relu_gain)
  169. self.conv2.weight.data.mul_(relu_gain)
  170. self.conv3.weight.data.mul_(relu_gain)
  171. self.linear1.weight.data.mul_(relu_gain)
  172. if hasattr(self, 'gru'):
  173. orthogonal(self.gru.weight_ih.data)
  174. orthogonal(self.gru.weight_hh.data)
  175. self.gru.bias_ih.data.fill_(0)
  176. self.gru.bias_hh.data.fill_(0)
  177. if self.dist.__class__.__name__ == "DiagGaussian":
  178. self.dist.fc_mean.weight.data.mul_(0.01)
  179. def forward(self, inputs, states, masks):
  180. x = self.conv1(inputs / 255.0)
  181. x = F.relu(x)
  182. x = self.conv2(x)
  183. x = F.relu(x)
  184. x = self.conv3(x)
  185. x = F.relu(x)
  186. x = x.view(-1, 32 * 7 * 7)
  187. x = self.linear1(x)
  188. x = F.relu(x)
  189. if hasattr(self, 'gru'):
  190. if inputs.size(0) == states.size(0):
  191. x = states = self.gru(x, states * masks)
  192. else:
  193. x = x.view(-1, states.size(0), x.size(1))
  194. masks = masks.view(-1, states.size(0), 1)
  195. outputs = []
  196. for i in range(x.size(0)):
  197. hx = states = self.gru(x[i], states * masks[i])
  198. outputs.append(hx)
  199. x = torch.cat(outputs, 0)
  200. return self.critic_linear(x), x, states