model.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import operator
  2. from functools import reduce
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from distributions import Categorical, DiagGaussian
  7. from utils import orthogonal
  8. class FFPolicy(nn.Module):
  9. def __init__(self):
  10. super(FFPolicy, self).__init__()
  11. def forward(self, inputs, states, masks):
  12. raise NotImplementedError
  13. def act(self, inputs, states, masks, deterministic=False):
  14. value, x, states = self(inputs, states, masks)
  15. action = self.dist.sample(x, deterministic=deterministic)
  16. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, action)
  17. return value, action, action_log_probs, states
  18. def evaluate_actions(self, inputs, states, masks, actions):
  19. value, x, states = self(inputs, states, masks)
  20. action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
  21. return value, action_log_probs, dist_entropy, states
  22. def weights_init_mlp(m):
  23. classname = m.__class__.__name__
  24. if classname.find('Linear') != -1:
  25. m.weight.data.normal_(0, 1)
  26. m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
  27. if m.bias is not None:
  28. m.bias.data.fill_(0)
  29. class RecMLPPolicy(FFPolicy):
  30. def __init__(self, num_inputs, action_space):
  31. super(RecMLPPolicy, self).__init__()
  32. self.action_space = action_space
  33. assert action_space.__class__.__name__ == "Discrete"
  34. num_outputs = action_space.n
  35. self.a_fc1 = nn.Linear(num_inputs, 64)
  36. self.a_fc2 = nn.Linear(64, 64)
  37. self.v_fc1 = nn.Linear(num_inputs, 64)
  38. self.v_fc2 = nn.Linear(64, 64)
  39. self.v_fc3 = nn.Linear(64, 1)
  40. # Input size, hidden size
  41. self.gru = nn.GRUCell(64, 64)
  42. self.dist = Categorical(64, num_outputs)
  43. self.train()
  44. self.reset_parameters()
  45. @property
  46. def state_size(self):
  47. """
  48. Size of the recurrent state of the model (propagated between steps
  49. """
  50. return 64
  51. def reset_parameters(self):
  52. self.apply(weights_init_mlp)
  53. orthogonal(self.gru.weight_ih.data)
  54. orthogonal(self.gru.weight_hh.data)
  55. self.gru.bias_ih.data.fill_(0)
  56. self.gru.bias_hh.data.fill_(0)
  57. if self.dist.__class__.__name__ == "DiagGaussian":
  58. self.dist.fc_mean.weight.data.mul_(0.01)
  59. def forward(self, inputs, states, masks):
  60. batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
  61. inputs = inputs.view(-1, batch_numel)
  62. x = self.a_fc1(inputs)
  63. x = F.tanh(x)
  64. x = self.a_fc2(x)
  65. x = F.tanh(x)
  66. assert inputs.size(0) == states.size(0)
  67. x = states = self.gru(x, states * masks)
  68. actions = x
  69. x = self.v_fc1(inputs)
  70. x = F.tanh(x)
  71. x = self.v_fc2(x)
  72. x = F.tanh(x)
  73. x = self.v_fc3(x)
  74. value = x
  75. return value, actions, states
  76. class MLPPolicy(FFPolicy):
  77. def __init__(self, num_inputs, action_space):
  78. super(MLPPolicy, self).__init__()
  79. self.action_space = action_space
  80. self.a_fc1 = nn.Linear(num_inputs, 64)
  81. self.a_fc2 = nn.Linear(64, 64)
  82. self.v_fc1 = nn.Linear(num_inputs, 64)
  83. self.v_fc2 = nn.Linear(64, 64)
  84. self.v_fc3 = nn.Linear(64, 1)
  85. if action_space.__class__.__name__ == "Discrete":
  86. num_outputs = action_space.n
  87. self.dist = Categorical(64, num_outputs)
  88. elif action_space.__class__.__name__ == "Box":
  89. num_outputs = action_space.shape[0]
  90. self.dist = DiagGaussian(64, num_outputs)
  91. else:
  92. raise NotImplementedError
  93. self.train()
  94. self.reset_parameters()
  95. @property
  96. def state_size(self):
  97. return 1
  98. def reset_parameters(self):
  99. self.apply(weights_init_mlp)
  100. """
  101. tanh_gain = nn.init.calculate_gain('tanh')
  102. self.a_fc1.weight.data.mul_(tanh_gain)
  103. self.a_fc2.weight.data.mul_(tanh_gain)
  104. self.v_fc1.weight.data.mul_(tanh_gain)
  105. self.v_fc2.weight.data.mul_(tanh_gain)
  106. """
  107. if self.dist.__class__.__name__ == "DiagGaussian":
  108. self.dist.fc_mean.weight.data.mul_(0.01)
  109. def forward(self, inputs, states, masks):
  110. batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
  111. inputs = inputs.view(-1, batch_numel)
  112. x = self.v_fc1(inputs)
  113. x = F.tanh(x)
  114. x = self.v_fc2(x)
  115. x = F.tanh(x)
  116. x = self.v_fc3(x)
  117. value = x
  118. x = self.a_fc1(inputs)
  119. x = F.tanh(x)
  120. x = self.a_fc2(x)
  121. x = F.tanh(x)
  122. return value, x, states
  123. def weights_init_cnn(m):
  124. classname = m.__class__.__name__
  125. if classname.find('Conv') != -1 or classname.find('Linear') != -1:
  126. orthogonal(m.weight.data)
  127. if m.bias is not None:
  128. m.bias.data.fill_(0)
  129. class CNNPolicy(FFPolicy):
  130. def __init__(self, num_inputs, action_space, use_gru):
  131. super(CNNPolicy, self).__init__()
  132. self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
  133. self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
  134. self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
  135. self.linear1 = nn.Linear(32 * 7 * 7, 512)
  136. if use_gru:
  137. self.gru = nn.GRUCell(512, 512)
  138. self.critic_linear = nn.Linear(512, 1)
  139. if action_space.__class__.__name__ == "Discrete":
  140. num_outputs = action_space.n
  141. self.dist = Categorical(512, num_outputs)
  142. elif action_space.__class__.__name__ == "Box":
  143. num_outputs = action_space.shape[0]
  144. self.dist = DiagGaussian(512, num_outputs)
  145. else:
  146. raise NotImplementedError
  147. self.train()
  148. self.reset_parameters()
  149. @property
  150. def state_size(self):
  151. if hasattr(self, 'gru'):
  152. return 512
  153. else:
  154. return 1
  155. def reset_parameters(self):
  156. self.apply(weights_init_cnn)
  157. relu_gain = nn.init.calculate_gain('relu')
  158. self.conv1.weight.data.mul_(relu_gain)
  159. self.conv2.weight.data.mul_(relu_gain)
  160. self.conv3.weight.data.mul_(relu_gain)
  161. self.linear1.weight.data.mul_(relu_gain)
  162. if hasattr(self, 'gru'):
  163. orthogonal(self.gru.weight_ih.data)
  164. orthogonal(self.gru.weight_hh.data)
  165. self.gru.bias_ih.data.fill_(0)
  166. self.gru.bias_hh.data.fill_(0)
  167. if self.dist.__class__.__name__ == "DiagGaussian":
  168. self.dist.fc_mean.weight.data.mul_(0.01)
  169. def forward(self, inputs, states, masks):
  170. x = self.conv1(inputs / 255.0)
  171. x = F.relu(x)
  172. x = self.conv2(x)
  173. x = F.relu(x)
  174. x = self.conv3(x)
  175. x = F.relu(x)
  176. x = x.view(-1, 32 * 7 * 7)
  177. x = self.linear1(x)
  178. x = F.relu(x)
  179. if hasattr(self, 'gru'):
  180. x = states = self.gru(x, states * masks)
  181. return self.critic_linear(x), x, states