import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler


class RolloutStorage(object):
    """Fixed-size buffer holding one rollout of experience for each of
    ``num_processes`` parallel environments, laid out with time first:
    (num_steps, num_processes, ...)."""

    def __init__(self, num_steps, num_processes, obs_shape, action_space, state_size):
        # Observations, recurrent states, value predictions, returns and
        # masks keep one extra time step so that index 0 can hold the state
        # carried over from the previous update.
        self.observations = torch.zeros(num_steps + 1, num_processes, *obs_shape)
        self.states = torch.zeros(num_steps + 1, num_processes, state_size)
        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
        self.returns = torch.zeros(num_steps + 1, num_processes, 1)
        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)

        # Discrete action spaces store a single integer index per action;
        # continuous spaces store one float per action dimension.
        if action_space.__class__.__name__ == 'Discrete':
            self.actions = torch.zeros(num_steps, num_processes, 1).long()
        else:
            self.actions = torch.zeros(num_steps, num_processes, action_space.shape[0])

        # masks[t] is 0.0 where the episode ended at step t, 1.0 otherwise.
        self.masks = torch.ones(num_steps + 1, num_processes, 1)

    def cuda(self):
        self.observations = self.observations.cuda()
        self.states = self.states.cuda()
        self.rewards = self.rewards.cuda()
        self.value_preds = self.value_preds.cuda()
        self.returns = self.returns.cuda()
        self.action_log_probs = self.action_log_probs.cuda()
        self.actions = self.actions.cuda()
        self.masks = self.masks.cuda()

    def insert(self, step, current_obs, state, action, action_log_prob, value_pred, reward, mask):
        # The action taken at `step` and its reward are stored at `step`;
        # the observation, recurrent state and mask they produce describe
        # time `step + 1`, hence the offset.
        self.observations[step + 1].copy_(current_obs)
        self.states[step + 1].copy_(state)
        self.actions[step].copy_(action)
        self.action_log_probs[step].copy_(action_log_prob)
        self.value_preds[step].copy_(value_pred)
        self.rewards[step].copy_(reward)
        self.masks[step + 1].copy_(mask)

    def after_update(self):
        # Carry the final observation, recurrent state and mask over to
        # slot 0 so the next rollout continues where this one stopped.
        self.observations[0].copy_(self.observations[-1])
        self.states[0].copy_(self.states[-1])
        self.masks[0].copy_(self.masks[-1])

    def compute_returns(self, next_value, use_gae, gamma, tau):
        if use_gae:
            # Generalized Advantage Estimation (Schulman et al., 2016);
            # `tau` is the lambda parameter of the paper.
            self.value_preds[-1] = next_value
            gae = 0
            for step in reversed(range(self.rewards.size(0))):
                delta = self.rewards[step] + gamma * self.value_preds[step + 1] \
                    * self.masks[step + 1] - self.value_preds[step]
                gae = delta + gamma * tau * self.masks[step + 1] * gae
                self.returns[step] = gae + self.value_preds[step]
        else:
            # Plain discounted returns, bootstrapped from next_value and
            # cut off at episode boundaries by the masks.
            self.returns[-1] = next_value
            for step in reversed(range(self.rewards.size(0))):
                self.returns[step] = self.returns[step + 1] * \
                    gamma * self.masks[step + 1] + self.rewards[step]
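    # For reference, the GAE branch above implements the recursion
    #   delta_t = r_t + gamma * V(s_{t+1}) * mask_{t+1} - V(s_t)
    #   A_t     = delta_t + gamma * tau * mask_{t+1} * A_{t+1}
    # which unrolls to the truncated sum A_t = sum_l (gamma * tau)^l * delta_{t+l},
    # and stores returns[t] = A_t + V(s_t) as the value-function target.
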
    def feed_forward_generator(self, advantages, num_mini_batch):
        # Flatten the (num_steps, num_processes) dimensions and sample
        # random mini-batches; suitable for feed-forward policies, where
        # transitions can be shuffled freely.
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        mini_batch_size = batch_size // num_mini_batch
        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)),
                               mini_batch_size, drop_last=False)
        for indices in sampler:
            indices = torch.LongTensor(indices)
            if advantages.is_cuda:
                indices = indices.cuda()
            observations_batch = self.observations[:-1].view(
                -1, *self.observations.size()[2:])[indices]
            states_batch = self.states[:-1].view(-1, self.states.size(-1))[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_targ = advantages.view(-1, 1)[indices]

            yield observations_batch, states_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ

    def recurrent_generator(self, advantages, num_mini_batch):
        # Sample whole per-process trajectories instead of individual
        # transitions, so recurrent states stay consistent in time.
        num_processes = self.rewards.size(1)
        num_envs_per_batch = num_processes // num_mini_batch
        perm = torch.randperm(num_processes)
        for start_ind in range(0, num_processes, num_envs_per_batch):
            observations_batch = []
            states_batch = []
            actions_batch = []
            return_batch = []
            masks_batch = []
            old_action_log_probs_batch = []
            adv_targ = []

            for offset in range(num_envs_per_batch):
                ind = perm[start_ind + offset]
                observations_batch.append(self.observations[:-1, ind])
                # Only the initial recurrent state is needed; the policy
                # re-runs the RNN over the rest of the trajectory.
                states_batch.append(self.states[0:1, ind])
                actions_batch.append(self.actions[:, ind])
                return_batch.append(self.returns[:-1, ind])
                masks_batch.append(self.masks[:-1, ind])
                old_action_log_probs_batch.append(self.action_log_probs[:, ind])
                adv_targ.append(advantages[:, ind])

            observations_batch = torch.cat(observations_batch, 0)
            states_batch = torch.cat(states_batch, 0)
            actions_batch = torch.cat(actions_batch, 0)
            return_batch = torch.cat(return_batch, 0)
            masks_batch = torch.cat(masks_batch, 0)
            old_action_log_probs_batch = torch.cat(old_action_log_probs_batch, 0)
            adv_targ = torch.cat(adv_targ, 0)

            yield observations_batch, states_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ
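

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It fills the
# storage with random tensors instead of real policy/environment outputs,
# and the `Discrete` class below is a hypothetical stand-in that only mimics
# the class-name check used in __init__ (a real run would pass e.g.
# gym.spaces.Discrete). All shapes and hyperparameters are illustrative
# assumptions.
if __name__ == '__main__':
    class Discrete(object):  # stand-in for gym.spaces.Discrete
        def __init__(self, n):
            self.n = n

    num_steps, num_processes, obs_shape = 8, 4, (16,)
    rollouts = RolloutStorage(num_steps, num_processes, obs_shape,
                              Discrete(6), state_size=32)

    for step in range(num_steps):
        rollouts.insert(step,
                        torch.randn(num_processes, *obs_shape),   # current_obs
                        torch.randn(num_processes, 32),           # state
                        torch.randint(0, 6, (num_processes, 1)),  # action
                        torch.randn(num_processes, 1),            # action_log_prob
                        torch.randn(num_processes, 1),            # value_pred
                        torch.randn(num_processes, 1),            # reward
                        torch.ones(num_processes, 1))             # mask

    rollouts.compute_returns(torch.randn(num_processes, 1),
                             use_gae=True, gamma=0.99, tau=0.95)
    advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]

    for batch in rollouts.feed_forward_generator(advantages, num_mini_batch=2):
        obs_b, states_b, actions_b, ret_b, masks_b, old_logp_b, adv_b = batch
        print(obs_b.shape, adv_b.shape)

    rollouts.after_update()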