@@ -15,31 +15,33 @@ class ActionBonus(gym.core.Wrapper):
     """
 
     def __init__(self, env):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
         self.counts = {}
 
     def step(self, action):
-
         obs, reward, done, info = self.env.step(action)
 
         env = self.unwrapped
-        tup = (env.agentPos, env.agentDir, action)
+        tup = (tuple(env.agent_pos), env.agent_dir, action)
 
         # Get the count for this (s,a) pair
-        preCnt = 0
+        pre_count = 0
         if tup in self.counts:
-            preCnt = self.counts[tup]
+            pre_count = self.counts[tup]
 
         # Update the count for this (s,a) pair
-        newCnt = preCnt + 1
-        self.counts[tup] = newCnt
-
-        bonus = 1 / math.sqrt(newCnt)
+        new_count = pre_count + 1
+        self.counts[tup] = new_count
 
+        bonus = 1 / math.sqrt(new_count)
         reward += bonus
 
         return obs, reward, done, info
 
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
 class StateBonus(gym.core.Wrapper):
     """
     Adds an exploration bonus based on which positions
@@ -47,42 +49,44 @@ class StateBonus(gym.core.Wrapper):
     """
 
     def __init__(self, env):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
         self.counts = {}
 
     def step(self, action):
-
         obs, reward, done, info = self.env.step(action)
 
         # Tuple based on which we index the counts
         # We use the position after an update
         env = self.unwrapped
-        tup = (env.agentPos)
+        tup = (tuple(env.agent_pos))
 
         # Get the count for this key
-        preCnt = 0
+        pre_count = 0
         if tup in self.counts:
-            preCnt = self.counts[tup]
+            pre_count = self.counts[tup]
 
         # Update the count for this key
-        newCnt = preCnt + 1
-        self.counts[tup] = newCnt
-
-        bonus = 1 / math.sqrt(newCnt)
+        new_count = pre_count + 1
+        self.counts[tup] = new_count
 
+        bonus = 1 / math.sqrt(new_count)
         reward += bonus
 
         return obs, reward, done, info
 
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
 class ImgObsWrapper(gym.core.ObservationWrapper):
     """
     Use the image as the only observation output, no language/mission.
     """
 
     def __init__(self, env):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
-        # Hack to pass values to super wrapper
-        self.__dict__.update(vars(env))
+
         self.observation_space = env.observation_space.spaces['image']
 
     def observation(self, obs):
@@ -94,8 +98,9 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
     """
 
     def __init__(self, env):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
-        self.__dict__.update(vars(env))  # hack to pass values to super wrapper
+
         self.observation_space = spaces.Box(
             low=0,
             high=255,
@@ -104,8 +109,9 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
         )
 
     def observation(self, obs):
-        full_grid = self.env.grid.encode()
-        full_grid[self.env.agent_pos[0]][self.env.agent_pos[1]] = np.array([255, self.env.agent_dir, 0])
+        env = self.unwrapped
+        full_grid = env.grid.encode()
+        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([255, env.agent_dir, 0])
         return full_grid
 
 class FlatObsWrapper(gym.core.ObservationWrapper):
@@ -115,6 +121,7 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
     """
 
     def __init__(self, env, maxStrLen=64):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super().__init__(env)
 
         self.maxStrLen = maxStrLen
@@ -166,11 +173,11 @@ class AgentViewWrapper(gym.core.Wrapper):
     """
 
     def __init__(self, env, agent_view_size=7):
+        self.__dict__.update(vars(env))  # Pass values to super wrapper
         super(AgentViewWrapper, self).__init__(env)
-        self.__dict__.update(vars(env))  # Hack to pass values to super wrapper
 
         # Override default view size
-        env.agent_view_size = agent_view_size
+        env.unwrapped.agent_view_size = agent_view_size
 
         # Compute observation space with specified view size
         observation_space = gym.spaces.Box(
@@ -184,3 +191,9 @@ class AgentViewWrapper(gym.core.Wrapper):
         self.observation_space = spaces.Dict({
             'image': observation_space
         })
+
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
+    def step(self, action):
+        return self.env.step(action)
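
A quick smoke test for the renamed attributes and the new reset passthrough. This is a minimal sketch, not part of the patch: it assumes the wrappers live in gym_minigrid.wrappers and that importing gym_minigrid registers an environment id such as 'MiniGrid-Empty-8x8-v0'.

    import gym
    import gym_minigrid  # noqa: F401 -- assumed to register the MiniGrid-* env ids
    from gym_minigrid.wrappers import ActionBonus, StateBonus

    # Stack both bonus wrappers; each keeps its own visit-count table
    env = ActionBonus(StateBonus(gym.make('MiniGrid-Empty-8x8-v0')))
    obs = env.reset()  # forwarded to the wrapped env by the new reset()
    obs, reward, done, info = env.step(env.action_space.sample())
    # reward now includes the 1 / sqrt(count) exploration bonuses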