import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from utils import AddBias
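
# AddBias comes from this repo's utils module. A minimal sketch of what it is
# assumed to do, inferred from its usage below (a module holding a learnable
# bias that is added to its 2D input); the actual implementation may differ:
#
#     class AddBias(nn.Module):
#         def __init__(self, bias):
#             super(AddBias, self).__init__()
#             self._bias = nn.Parameter(bias.unsqueeze(1))
#
#         def forward(self, x):
#             return x + self._bias.t().view(1, -1)
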
class Categorical(nn.Module):
    """Categorical action head for discrete action spaces."""

    def __init__(self, num_inputs, num_outputs):
        super(Categorical, self).__init__()
        self.linear = nn.Linear(num_inputs, num_outputs)

    def forward(self, x):
        x = self.linear(x)
        return x

    def sample(self, x, deterministic):
        x = self(x)

        probs = F.softmax(x, dim=-1)
        if deterministic is False:
            # Stochastic: draw one action per row from the categorical distribution.
            action = probs.multinomial(num_samples=1)
        else:
            # Deterministic: pick the most probable action.
            action = probs.max(1, keepdim=True)[1]
        return action

    def logprobs_and_entropy(self, x, actions):
        x = self(x)

        log_probs = F.log_softmax(x, dim=-1)
        probs = F.softmax(x, dim=-1)

        # Log-probability of each taken action.
        action_log_probs = log_probs.gather(1, actions)

        dist_entropy = -(log_probs * probs).sum(-1).mean()
        return action_log_probs, dist_entropy
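
# A minimal usage sketch of the Categorical head (the feature size 64, action
# count 4, and batch size 8 are illustrative, not from the original code):
#
#     dist = Categorical(num_inputs=64, num_outputs=4)
#     features = Variable(torch.randn(8, 64))              # actor features
#     action = dist.sample(features, deterministic=False)  # LongTensor, (8, 1)
#     log_probs, entropy = dist.logprobs_and_entropy(features, action)
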
class DiagGaussian(nn.Module):
    """Diagonal Gaussian action head for continuous action spaces."""

    def __init__(self, num_inputs, num_outputs):
        super(DiagGaussian, self).__init__()
        self.fc_mean = nn.Linear(num_inputs, num_outputs)
        # State-independent log standard deviation, stored as a learnable bias.
        self.logstd = AddBias(torch.zeros(num_outputs))

    def forward(self, x):
        action_mean = self.fc_mean(x)

        # An ugly hack for my KFAC implementation.
        zeros = Variable(torch.zeros(action_mean.size()), volatile=x.volatile)
        if x.is_cuda:
            zeros = zeros.cuda()
        action_logstd = self.logstd(zeros)
        return action_mean, action_logstd

    def sample(self, x, deterministic):
        action_mean, action_logstd = self(x)

        action_std = action_logstd.exp()

        if deterministic is False:
            # Reparameterized sample: mean + std * standard normal noise.
            noise = Variable(torch.randn(action_std.size()))
            if action_std.is_cuda:
                noise = noise.cuda()
            action = action_mean + action_std * noise
        else:
            action = action_mean
        return action

    def logprobs_and_entropy(self, x, actions):
        action_mean, action_logstd = self(x)

        action_std = action_logstd.exp()

        # Per-dimension Gaussian log-density:
        # log N(a | mu, sigma) = -0.5 * ((a - mu) / sigma)^2
        #                        - 0.5 * log(2 * pi) - log(sigma)
        action_log_probs = -0.5 * ((actions - action_mean) / action_std).pow(2) \
            - 0.5 * math.log(2 * math.pi) - action_logstd
        action_log_probs = action_log_probs.sum(-1, keepdim=True)

        # Per-dimension entropy of a Gaussian: 0.5 * log(2 * pi * e * sigma^2).
        dist_entropy = 0.5 + 0.5 * math.log(2 * math.pi) + action_logstd
        dist_entropy = dist_entropy.sum(-1).mean()
        return action_log_probs, dist_entropy
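
# A minimal usage sketch of the DiagGaussian head (sizes are illustrative):
#
#     dist = DiagGaussian(num_inputs=64, num_outputs=6)
#     features = Variable(torch.randn(8, 64))
#     action = dist.sample(features, deterministic=False)  # shape (8, 6)
#     log_probs, entropy = dist.logprobs_and_entropy(features, action)
#     # log_probs has shape (8, 1): the per-dimension log-densities are summed
#     # over the 6 action dimensions; entropy is a scalar averaged over the batch.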