wrappers.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import math
  2. import operator
  3. from functools import reduce
  4. import numpy as np
  5. import gym
  6. from gym import error, spaces, utils
  7. class ActionBonus(gym.core.Wrapper):
  8. """
  9. Wrapper which adds an exploration bonus.
  10. This is a reward to encourage exploration of less
  11. visited (state,action) pairs.
  12. """
  13. def __init__(self, env):
  14. super().__init__(env)
  15. self.counts = {}
  16. def step(self, action):
  17. obs, reward, done, info = self.env.step(action)
  18. env = self.unwrapped
  19. tup = (env.agentPos, env.agentDir, action)
  20. # Get the count for this (s,a) pair
  21. preCnt = 0
  22. if tup in self.counts:
  23. preCnt = self.counts[tup]
  24. # Update the count for this (s,a) pair
  25. newCnt = preCnt + 1
  26. self.counts[tup] = newCnt
  27. bonus = 1 / math.sqrt(newCnt)
  28. reward += bonus
  29. return obs, reward, done, info
  30. class StateBonus(gym.core.Wrapper):
  31. """
  32. Adds an exploration bonus based on which positions
  33. are visited on the grid.
  34. """
  35. def __init__(self, env):
  36. super().__init__(env)
  37. self.counts = {}
  38. def step(self, action):
  39. obs, reward, done, info = self.env.step(action)
  40. # Tuple based on which we index the counts
  41. # We use the position after an update
  42. env = self.unwrapped
  43. tup = (env.agentPos)
  44. # Get the count for this key
  45. preCnt = 0
  46. if tup in self.counts:
  47. preCnt = self.counts[tup]
  48. # Update the count for this key
  49. newCnt = preCnt + 1
  50. self.counts[tup] = newCnt
  51. bonus = 1 / math.sqrt(newCnt)
  52. reward += bonus
  53. return obs, reward, done, info
  54. class FlatObsWrapper(gym.core.ObservationWrapper):
  55. """
  56. Encode mission strings using a one-hot scheme,
  57. and combine these with observed images into one flat array
  58. """
  59. def __init__(self, env, maxStrLen=64):
  60. super().__init__(env)
  61. self.maxStrLen = maxStrLen
  62. self.numCharCodes = 27
  63. imgSpace = env.observation_space.spaces['image']
  64. imgSize = reduce(operator.mul, imgSpace.shape, 1)
  65. self.observation_space = spaces.Box(
  66. low=0,
  67. high=255,
  68. shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
  69. dtype='uint8'
  70. )
  71. self.cachedStr = None
  72. self.cachedArray = None
  73. def observation(self, obs):
  74. image = obs['image']
  75. mission = obs['mission']
  76. # Cache the last-encoded mission string
  77. if mission != self.cachedStr:
  78. assert len(mission) <= self.maxStrLen, "mission string too long"
  79. mission = mission.lower()
  80. strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
  81. for idx, ch in enumerate(mission):
  82. if ch >= 'a' and ch <= 'z':
  83. chNo = ord(ch) - ord('a')
  84. elif ch == ' ':
  85. chNo = ord('z') - ord('a') + 1
  86. assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
  87. strArray[idx, chNo] = 1
  88. self.cachedStr = mission
  89. self.cachedArray = strArray
  90. obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))
  91. return obs