# distributions.py (2.4 KB)
  1. import math
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch.autograd import Variable
  6. from utils import AddBias
  7. class Categorical(nn.Module):
  8. def __init__(self, num_inputs, num_outputs):
  9. super(Categorical, self).__init__()
  10. self.linear = nn.Linear(num_inputs, num_outputs)
  11. def forward(self, x):
  12. x = self.linear(x)
  13. return x
  14. def sample(self, x, deterministic):
  15. x = self(x)
  16. probs = F.softmax(x, dim=1)
  17. if deterministic is False:
  18. action = probs.multinomial()
  19. else:
  20. action = probs.max(1, keepdim=True)[1]
  21. return action
  22. def logprobs_and_entropy(self, x, actions):
  23. x = self(x)
  24. log_probs = F.log_softmax(x, dim=1)
  25. probs = F.softmax(x, dim=1)
  26. action_log_probs = log_probs.gather(1, actions)
  27. dist_entropy = -(log_probs * probs).sum(-1).mean()
  28. return action_log_probs, dist_entropy
  29. class DiagGaussian(nn.Module):
  30. def __init__(self, num_inputs, num_outputs):
  31. super(DiagGaussian, self).__init__()
  32. self.fc_mean = nn.Linear(num_inputs, num_outputs)
  33. self.logstd = AddBias(torch.zeros(num_outputs))
  34. def forward(self, x):
  35. action_mean = self.fc_mean(x)
  36. # An ugly hack for my KFAC implementation.
  37. zeros = Variable(torch.zeros(action_mean.size()), volatile=x.volatile)
  38. if x.is_cuda:
  39. zeros = zeros.cuda()
  40. action_logstd = self.logstd(zeros)
  41. return action_mean, action_logstd
  42. def sample(self, x, deterministic):
  43. action_mean, action_logstd = self(x)
  44. action_std = action_logstd.exp()
  45. if deterministic is False:
  46. noise = Variable(torch.randn(action_std.size()))
  47. if action_std.is_cuda:
  48. noise = noise.cuda()
  49. action = action_mean + action_std * noise
  50. else:
  51. action = action_mean
  52. return action
  53. def logprobs_and_entropy(self, x, actions):
  54. action_mean, action_logstd = self(x)
  55. action_std = action_logstd.exp()
  56. action_log_probs = -0.5 * ((actions - action_mean) / action_std).pow(2) - 0.5 * math.log(2 * math.pi) - action_logstd
  57. action_log_probs = action_log_probs.sum(-1, keepdim=True)
  58. dist_entropy = 0.5 + 0.5 * math.log(2 * math.pi) + action_logstd
  59. dist_entropy = dist_entropy.sum(-1).mean()
  60. return action_log_probs, dist_entropy