storage.py

import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler


class RolloutStorage(object):
    def __init__(self, num_steps, num_processes, obs_shape, action_space, state_size):
        # Buffers with num_steps + 1 entries also hold the bootstrap value for the final step.
        self.observations = torch.zeros(num_steps + 1, num_processes, *obs_shape)
        self.states = torch.zeros(num_steps + 1, num_processes, state_size)
        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
        self.returns = torch.zeros(num_steps + 1, num_processes, 1)
        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
        if action_space.__class__.__name__ == 'Discrete':
            action_shape = 1
        else:
            action_shape = action_space.shape[0]
        self.actions = torch.zeros(num_steps, num_processes, action_shape)
        if action_space.__class__.__name__ == 'Discrete':
            self.actions = self.actions.long()
        # masks[t] is 0.0 if the episode ended at step t - 1, else 1.0.
        self.masks = torch.ones(num_steps + 1, num_processes, 1)

    def cuda(self):
        # Move all rollout buffers to the GPU.
        self.observations = self.observations.cuda()
        self.states = self.states.cuda()
        self.rewards = self.rewards.cuda()
        self.value_preds = self.value_preds.cuda()
        self.returns = self.returns.cuda()
        self.action_log_probs = self.action_log_probs.cuda()
        self.actions = self.actions.cuda()
        self.masks = self.masks.cuda()

    def insert(self, step, current_obs, state, action, action_log_prob, value_pred, reward, mask):
        # Store one transition for every parallel process at the given step.
        self.observations[step + 1].copy_(current_obs)
        self.states[step + 1].copy_(state)
        self.actions[step].copy_(action)
        self.action_log_probs[step].copy_(action_log_prob)
        self.value_preds[step].copy_(value_pred)
        self.rewards[step].copy_(reward)
        self.masks[step + 1].copy_(mask)

    def after_update(self):
        # Carry the last observation/state/mask over as the start of the next rollout.
        self.observations[0].copy_(self.observations[-1])
        self.states[0].copy_(self.states[-1])
        self.masks[0].copy_(self.masks[-1])

    def compute_returns(self, next_value, use_gae, gamma, tau):
        if use_gae:
            # Generalized Advantage Estimation: accumulate TD errors backwards in time,
            # zeroing the recursion across episode boundaries via the masks.
            self.value_preds[-1] = next_value
            gae = 0
            for step in reversed(range(self.rewards.size(0))):
                delta = self.rewards[step] + gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]
                gae = delta + gamma * tau * self.masks[step + 1] * gae
                self.returns[step] = gae + self.value_preds[step]
        else:
            # Plain discounted returns, bootstrapped from next_value.
            self.returns[-1] = next_value
            for step in reversed(range(self.rewards.size(0))):
                self.returns[step] = self.returns[step + 1] * \
                    gamma * self.masks[step + 1] + self.rewards[step]

    def feed_forward_generator(self, advantages, num_mini_batch):
        # Sample random mini-batches over the flattened (step, process) dimension.
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_processes * num_steps
        mini_batch_size = batch_size // num_mini_batch
        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
        for indices in sampler:
            indices = torch.LongTensor(indices)
            if advantages.is_cuda:
                indices = indices.cuda()

            observations_batch = self.observations[:-1].view(-1, *self.observations.size()[2:])[indices]
            states_batch = self.states[:-1].view(-1, self.states.size(-1))[indices]
            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_targ = advantages.view(-1, 1)[indices]

            yield observations_batch, states_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ

    def recurrent_generator(self, advantages, num_mini_batch):
        # Keep each process's trajectory contiguous so recurrent states stay valid:
        # mini-batches are whole environments, and only the initial state is passed along.
        num_processes = self.rewards.size(1)
        num_envs_per_batch = num_processes // num_mini_batch
        perm = torch.randperm(num_processes)
        for start_ind in range(0, num_processes, num_envs_per_batch):
            observations_batch = []
            states_batch = []
            actions_batch = []
            return_batch = []
            masks_batch = []
            old_action_log_probs_batch = []
            adv_targ = []

            for offset in range(num_envs_per_batch):
                ind = perm[start_ind + offset]
                observations_batch.append(self.observations[:-1, ind])
                states_batch.append(self.states[0:1, ind])
                actions_batch.append(self.actions[:, ind])
                return_batch.append(self.returns[:-1, ind])
                masks_batch.append(self.masks[:-1, ind])
                old_action_log_probs_batch.append(self.action_log_probs[:, ind])
                adv_targ.append(advantages[:, ind])

            observations_batch = torch.cat(observations_batch, 0)
            states_batch = torch.cat(states_batch, 0)
            actions_batch = torch.cat(actions_batch, 0)
            return_batch = torch.cat(return_batch, 0)
            masks_batch = torch.cat(masks_batch, 0)
            old_action_log_probs_batch = torch.cat(old_action_log_probs_batch, 0)
            adv_targ = torch.cat(adv_targ, 0)

            yield observations_batch, states_batch, actions_batch, \
                return_batch, masks_batch, old_action_log_probs_batch, adv_targ
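A minimal sketch of how this class is typically driven from a training loop. The `envs` and `policy` objects, their `step`/`act`/`get_value` interfaces, and the hyperparameter values below are illustrative placeholders, not part of this file.

import torch

# Hypothetical setup: `envs` is a vectorized environment, `policy` an actor-critic model.
num_steps, num_processes = 5, 8
rollouts = RolloutStorage(num_steps, num_processes,
                          obs_shape=(4, 84, 84),
                          action_space=envs.action_space,
                          state_size=512)
rollouts.observations[0].copy_(envs.reset())

# Collect one rollout of num_steps transitions from all processes.
for step in range(num_steps):
    with torch.no_grad():
        value, action, action_log_prob, state = policy.act(
            rollouts.observations[step], rollouts.states[step], rollouts.masks[step])
    obs, reward, done, _ = envs.step(action)
    mask = torch.FloatTensor([[0.0] if d else [1.0] for d in done])
    rollouts.insert(step, obs, state, action, action_log_prob, value, reward, mask)

# Bootstrap from the value of the last observation, then compute GAE returns.
with torch.no_grad():
    next_value = policy.get_value(
        rollouts.observations[-1], rollouts.states[-1], rollouts.masks[-1])
rollouts.compute_returns(next_value, use_gae=True, gamma=0.99, tau=0.95)

# Iterate mini-batches for a PPO/A2C-style update, then recycle the storage.
advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
for batch in rollouts.feed_forward_generator(advantages, num_mini_batch=4):
    obs_b, states_b, actions_b, returns_b, masks_b, old_log_probs_b, adv_b = batch
    # ...update the policy with this mini-batch...
rollouts.after_update()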