import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler


class RolloutStorage(object):
    """Buffer of rollouts collected from parallel environments for A2C/PPO updates."""

    def __init__(self, num_steps, num_processes, obs_shape, action_space, state_size):
        # Buffers sized num_steps + 1 also hold the entry for the state that
        # follows the last step of the rollout (used for bootstrapping).
        self.observations = torch.zeros(num_steps + 1, num_processes, *obs_shape)
        self.states = torch.zeros(num_steps + 1, num_processes, state_size)
        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
        self.returns = torch.zeros(num_steps + 1, num_processes, 1)
        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)

        # Discrete actions are stored as a single integer index; continuous
        # actions use one column per action dimension.
        if action_space.__class__.__name__ == 'Discrete':
            action_shape = 1
        else:
            action_shape = action_space.shape[0]
        self.actions = torch.zeros(num_steps, num_processes, action_shape)
        if action_space.__class__.__name__ == 'Discrete':
            self.actions = self.actions.long()

        # masks[t] is 0.0 if the episode ended at step t - 1, otherwise 1.0.
        self.masks = torch.ones(num_steps + 1, num_processes, 1)
 
    def cuda(self):
        # Move every storage tensor to the GPU.
        self.observations = self.observations.cuda()
        self.states = self.states.cuda()
        self.rewards = self.rewards.cuda()
        self.value_preds = self.value_preds.cuda()
        self.returns = self.returns.cuda()
        self.action_log_probs = self.action_log_probs.cuda()
        self.actions = self.actions.cuda()
        self.masks = self.masks.cuda()
 
    def insert(self, step, current_obs, state, action, action_log_prob, value_pred, reward, mask):
        # Record the transition taken at `step`; the observation, recurrent
        # state and mask describe the situation *after* the action, hence step + 1.
        self.observations[step + 1].copy_(current_obs)
        self.states[step + 1].copy_(state)
        self.actions[step].copy_(action)
        self.action_log_probs[step].copy_(action_log_prob)
        self.value_preds[step].copy_(value_pred)
        self.rewards[step].copy_(reward)
        self.masks[step + 1].copy_(mask)
 
    def after_update(self):
        # Carry the final observation, recurrent state and mask over as the
        # starting point of the next rollout.
        self.observations[0].copy_(self.observations[-1])
        self.states[0].copy_(self.states[-1])
        self.masks[0].copy_(self.masks[-1])
 
    def compute_returns(self, next_value, use_gae, gamma, tau):
        if use_gae:
            # Generalized Advantage Estimation: accumulate discounted TD errors
            # backwards in time, zeroing the accumulator at episode boundaries.
            self.value_preds[-1] = next_value
            gae = 0
            for step in reversed(range(self.rewards.size(0))):
                delta = self.rewards[step] + gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]
                gae = delta + gamma * tau * self.masks[step + 1] * gae
                self.returns[step] = gae + self.value_preds[step]
        else:
            # Plain discounted returns, bootstrapped from the value of the state
            # reached after the last step.
            self.returns[-1] = next_value
            for step in reversed(range(self.rewards.size(0))):
                self.returns[step] = self.returns[step + 1] * \
                    gamma * self.masks[step + 1] + self.rewards[step]
 
    def feed_forward_generator(self, advantages, num_mini_batch):
        # Flatten the (num_steps, num_processes) dimensions and sample random
        # mini-batches; suitable for feed-forward (non-recurrent) policies.
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        mini_batch_size = batch_size // num_mini_batch
        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)),
                               mini_batch_size, drop_last=False)
        for indices in sampler:
            indices = torch.LongTensor(indices)
            if advantages.is_cuda:
                indices = indices.cuda()
            observations_batch = self.observations[:-1].view(
                -1, *self.observations.size()[2:])[indices]
            states_batch = self.states[:-1].view(-1, self.states.size(-1))[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_targ = advantages.view(-1, 1)[indices]

            yield observations_batch, states_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ
 
    def recurrent_generator(self, advantages, num_mini_batch):
        # Keep each process's trajectory intact (time dimension preserved) so a
        # recurrent policy can be unrolled over it; mini-batches are formed by
        # grouping whole processes together.
        num_processes = self.rewards.size(1)
        num_envs_per_batch = num_processes // num_mini_batch
        perm = torch.randperm(num_processes)
        for start_ind in range(0, num_processes, num_envs_per_batch):
            observations_batch = []
            states_batch = []
            actions_batch = []
            return_batch = []
            masks_batch = []
            old_action_log_probs_batch = []
            adv_targ = []

            for offset in range(num_envs_per_batch):
                ind = perm[start_ind + offset]
                observations_batch.append(self.observations[:-1, ind])
                # Only the initial recurrent state is kept; later states are
                # recomputed while the policy is unrolled over the trajectory.
                states_batch.append(self.states[0:1, ind])
                actions_batch.append(self.actions[:, ind])
                return_batch.append(self.returns[:-1, ind])
                masks_batch.append(self.masks[:-1, ind])
                old_action_log_probs_batch.append(self.action_log_probs[:, ind])
                adv_targ.append(advantages[:, ind])

            observations_batch = torch.cat(observations_batch, 0)
            states_batch = torch.cat(states_batch, 0)
            actions_batch = torch.cat(actions_batch, 0)
            return_batch = torch.cat(return_batch, 0)
            masks_batch = torch.cat(masks_batch, 0)
            old_action_log_probs_batch = torch.cat(old_action_log_probs_batch, 0)
            adv_targ = torch.cat(adv_targ, 0)

            yield observations_batch, states_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ
 
 
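# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original module):
# fill the buffer with dummy data for a discrete action space, compute GAE
# returns, and draw PPO-style mini-batches. The tiny `Discrete` class below is
# a hypothetical stand-in that only mimics gym.spaces.Discrete by name, since
# RolloutStorage dispatches on the class name alone.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class Discrete(object):
        def __init__(self, n):
            self.n = n

    num_steps, num_processes, obs_dim = 5, 4, 8
    rollouts = RolloutStorage(num_steps, num_processes, obs_shape=(obs_dim,),
                              action_space=Discrete(3), state_size=16)
    for step in range(num_steps):
        rollouts.insert(step,
                        torch.randn(num_processes, obs_dim),       # current_obs
                        torch.zeros(num_processes, 16),            # state
                        torch.randint(0, 3, (num_processes, 1)),   # action
                        torch.zeros(num_processes, 1),             # action_log_prob
                        torch.zeros(num_processes, 1),             # value_pred
                        torch.rand(num_processes, 1),              # reward
                        torch.ones(num_processes, 1))              # mask
    rollouts.compute_returns(next_value=torch.zeros(num_processes, 1),
                             use_gae=True, gamma=0.99, tau=0.95)
    advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
    for batch in rollouts.feed_forward_generator(advantages, num_mini_batch=2):
        obs_b, states_b, actions_b, returns_b, masks_b, old_log_probs_b, adv_b = batch
        print(obs_b.shape, actions_b.shape, adv_b.shape)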