|
@@ -7,7 +7,8 @@ import gym
|
|
from gym import error, spaces, utils
|
|
from gym import error, spaces, utils
|
|
from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal
|
|
from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal
|
|
|
|
|
|
-class ReseedWrapper(gym.core.Wrapper):
|
|
|
|
|
|
+
|
|
|
|
+class ReseedWrapper(gym.Wrapper):
|
|
"""
|
|
"""
|
|
Wrapper to always regenerate an environment with the same set of seeds.
|
|
Wrapper to always regenerate an environment with the same set of seeds.
|
|
This can be used to force an environment to always keep the same
|
|
This can be used to force an environment to always keep the same
|
|
@@ -28,7 +29,8 @@ class ReseedWrapper(gym.core.Wrapper):
|
|
obs, reward, done, info = self.env.step(action)
|
|
obs, reward, done, info = self.env.step(action)
|
|
return obs, reward, done, info
|
|
return obs, reward, done, info
|
|
|
|
|
|
-class ActionBonus(gym.core.Wrapper):
|
|
|
|
|
|
+
|
|
|
|
+class ActionBonus(gym.Wrapper):
|
|
"""
|
|
"""
|
|
Wrapper which adds an exploration bonus.
|
|
Wrapper which adds an exploration bonus.
|
|
This is a reward to encourage exploration of less
|
|
This is a reward to encourage exploration of less
|
|
@@ -62,7 +64,8 @@ class ActionBonus(gym.core.Wrapper):
|
|
def reset(self, **kwargs):
|
|
def reset(self, **kwargs):
|
|
return self.env.reset(**kwargs)
|
|
return self.env.reset(**kwargs)
|
|
|
|
|
|
-class StateBonus(gym.core.Wrapper):
|
|
|
|
|
|
+
|
|
|
|
+class StateBonus(gym.Wrapper):
|
|
"""
|
|
"""
|
|
Adds an exploration bonus based on which positions
|
|
Adds an exploration bonus based on which positions
|
|
are visited on the grid.
|
|
are visited on the grid.
|
|
@@ -97,7 +100,8 @@ class StateBonus(gym.core.Wrapper):
|
|
def reset(self, **kwargs):
|
|
def reset(self, **kwargs):
|
|
return self.env.reset(**kwargs)
|
|
return self.env.reset(**kwargs)
|
|
|
|
|
|
-class ImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
+
|
|
|
|
+class ImgObsWrapper(gym.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Use the image as the only observation output, no language/mission.
|
|
Use the image as the only observation output, no language/mission.
|
|
"""
|
|
"""
|
|
@@ -109,7 +113,8 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
|
|
def observation(self, obs):
|
|
def observation(self, obs):
|
|
return obs['image']
|
|
return obs['image']
|
|
|
|
|
|
-class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
+
|
|
|
|
+class OneHotPartialObsWrapper(gym.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Wrapper to get a one-hot encoding of a partially observable
|
|
Wrapper to get a one-hot encoding of a partially observable
|
|
agent view as observation.
|
|
agent view as observation.
|
|
@@ -131,11 +136,13 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
|
|
shape=(obs_shape[0], obs_shape[1], num_bits),
|
|
shape=(obs_shape[0], obs_shape[1], num_bits),
|
|
dtype='uint8'
|
|
dtype='uint8'
|
|
)
|
|
)
|
|
- self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
|
|
|
+ self.observation_space = spaces.Dict(
|
|
|
|
+ {**self.observation_space, 'image': new_image_space})
|
|
|
|
|
|
def observation(self, obs):
|
|
def observation(self, obs):
|
|
img = obs['image']
|
|
img = obs['image']
|
|
- out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')
|
|
|
|
|
|
+ out = np.zeros(
|
|
|
|
+ self.observation_space.spaces['image'].shape, dtype='uint8')
|
|
|
|
|
|
for i in range(img.shape[0]):
|
|
for i in range(img.shape[0]):
|
|
for j in range(img.shape[1]):
|
|
for j in range(img.shape[1]):
|
|
@@ -152,7 +159,8 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
|
|
'image': out
|
|
'image': out
|
|
}
|
|
}
|
|
|
|
|
|
-class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
+
|
|
|
|
+class RGBImgObsWrapper(gym.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Wrapper to use fully observable RGB image as observation,
|
|
Wrapper to use fully observable RGB image as observation,
|
|
This can be used to have the agent to solve the gridworld in pixel space.
|
|
This can be used to have the agent to solve the gridworld in pixel space.
|
|
@@ -171,7 +179,8 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
dtype='uint8'
|
|
dtype='uint8'
|
|
)
|
|
)
|
|
|
|
|
|
- self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
|
|
|
+ self.observation_space = spaces.Dict(
|
|
|
|
+ {**self.observation_space, 'image': new_image_space})
|
|
|
|
|
|
def observation(self, obs):
|
|
def observation(self, obs):
|
|
env = self.unwrapped
|
|
env = self.unwrapped
|
|
@@ -188,7 +197,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
+class RGBImgPartialObsWrapper(gym.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Wrapper to use partially observable RGB image as observation.
|
|
Wrapper to use partially observable RGB image as observation.
|
|
This can be used to have the agent to solve the gridworld in pixel space.
|
|
This can be used to have the agent to solve the gridworld in pixel space.
|
|
@@ -207,7 +216,8 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
dtype='uint8'
|
|
dtype='uint8'
|
|
)
|
|
)
|
|
|
|
|
|
- self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
|
|
|
+ self.observation_space = spaces.Dict(
|
|
|
|
+ {**self.observation_space, 'image': new_image_space})
|
|
|
|
|
|
def observation(self, obs):
|
|
def observation(self, obs):
|
|
env = self.unwrapped
|
|
env = self.unwrapped
|
|
@@ -222,7 +232,8 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
'image': rgb_img_partial
|
|
'image': rgb_img_partial
|
|
}
|
|
}
|
|
|
|
|
|
-class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
+
|
|
|
|
+class FullyObsWrapper(gym.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Fully observable gridworld using a compact grid encoding
|
|
Fully observable gridworld using a compact grid encoding
|
|
"""
|
|
"""
|
|
@@ -237,7 +248,8 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
dtype='uint8'
|
|
dtype='uint8'
|
|
)
|
|
)
|
|
|
|
|
|
- self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
|
|
|
+ self.observation_space = spaces.Dict(
|
|
|
|
+ {**self.observation_space, 'image': new_image_space})
|
|
|
|
|
|
def observation(self, obs):
|
|
def observation(self, obs):
|
|
env = self.unwrapped
|
|
env = self.unwrapped
|
|
@@ -253,7 +265,8 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
'image': full_grid
|
|
'image': full_grid
|
|
}
|
|
}
|
|
|
|
|
|
-class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
+
|
|
|
|
+class DictObservationSpaceWrapper(gym.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Use a Dict Obsevation Space encoding images, missions, and directions
|
|
Use a Dict Obsevation Space encoding images, missions, and directions
|
|
"""
|
|
"""
|
|
@@ -268,7 +281,7 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
if word_dict is None:
|
|
if word_dict is None:
|
|
word_dict = DictObservationSpaceWrapper.get_minigrid_words()
|
|
word_dict = DictObservationSpaceWrapper.get_minigrid_words()
|
|
-
|
|
|
|
|
|
+
|
|
self.max_words_in_mission = max_words_in_mission
|
|
self.max_words_in_mission = max_words_in_mission
|
|
self.word_dict = word_dict
|
|
self.word_dict = word_dict
|
|
|
|
|
|
@@ -282,24 +295,23 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
|
|
'image': image_observation_space,
|
|
'image': image_observation_space,
|
|
'direction': spaces.Discrete(4),
|
|
'direction': spaces.Discrete(4),
|
|
'mission': spaces.MultiDiscrete([len(self.word_dict.keys())]
|
|
'mission': spaces.MultiDiscrete([len(self.word_dict.keys())]
|
|
- * max_words_in_mission)
|
|
|
|
|
|
+ * max_words_in_mission)
|
|
})
|
|
})
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
def get_minigrid_words():
|
|
def get_minigrid_words():
|
|
colors = ['red', 'green', 'blue', 'yellow', 'purple', 'grey']
|
|
colors = ['red', 'green', 'blue', 'yellow', 'purple', 'grey']
|
|
objects = ['unseen', 'empty', 'wall', 'floor', 'box', 'key', 'ball',
|
|
objects = ['unseen', 'empty', 'wall', 'floor', 'box', 'key', 'ball',
|
|
- 'door', 'goal', 'agent', 'lava']
|
|
|
|
|
|
+ 'door', 'goal', 'agent', 'lava']
|
|
|
|
|
|
verbs = ['pick', 'avoid', 'get', 'find', 'put',
|
|
verbs = ['pick', 'avoid', 'get', 'find', 'put',
|
|
- 'use', 'open', 'go', 'fetch',
|
|
|
|
- 'reach', 'unlock', 'traverse']
|
|
|
|
|
|
+ 'use', 'open', 'go', 'fetch',
|
|
|
|
+ 'reach', 'unlock', 'traverse']
|
|
|
|
|
|
extra_words = ['up', 'the', 'a', 'at', ',', 'square',
|
|
extra_words = ['up', 'the', 'a', 'at', ',', 'square',
|
|
- 'and', 'then', 'to', 'of', 'rooms', 'near',
|
|
|
|
- 'opening', 'must', 'you', 'matching', 'end',
|
|
|
|
- 'hallway', 'object', 'from', 'room']
|
|
|
|
-
|
|
|
|
|
|
+ 'and', 'then', 'to', 'of', 'rooms', 'near',
|
|
|
|
+ 'opening', 'must', 'you', 'matching', 'end',
|
|
|
|
+ 'hallway', 'object', 'from', 'room']
|
|
|
|
|
|
all_words = colors + objects + verbs + extra_words
|
|
all_words = colors + objects + verbs + extra_words
|
|
assert len(all_words) == len(set(all_words))
|
|
assert len(all_words) == len(set(all_words))
|
|
@@ -310,22 +322,25 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
|
|
Convert a string to a list of indices.
|
|
Convert a string to a list of indices.
|
|
"""
|
|
"""
|
|
indices = []
|
|
indices = []
|
|
- string = string.replace(',', ' , ') # adding space before and after commas
|
|
|
|
|
|
+ # adding space before and after commas
|
|
|
|
+ string = string.replace(',', ' , ')
|
|
for word in string.split():
|
|
for word in string.split():
|
|
if word in self.word_dict.keys():
|
|
if word in self.word_dict.keys():
|
|
indices.append(self.word_dict[word] + offset)
|
|
indices.append(self.word_dict[word] + offset)
|
|
else:
|
|
else:
|
|
raise ValueError('Unknown word: {}'.format(word))
|
|
raise ValueError('Unknown word: {}'.format(word))
|
|
return indices
|
|
return indices
|
|
-
|
|
|
|
|
|
+
|
|
def observation(self, obs):
|
|
def observation(self, obs):
|
|
obs['mission'] = self.string_to_indices(obs['mission'])
|
|
obs['mission'] = self.string_to_indices(obs['mission'])
|
|
assert len(obs['mission']) < self.max_words_in_mission
|
|
assert len(obs['mission']) < self.max_words_in_mission
|
|
- obs['mission'] += [0] * (self.max_words_in_mission - len(obs['mission']))
|
|
|
|
|
|
+ obs['mission'] += [0] * \
|
|
|
|
+ (self.max_words_in_mission - len(obs['mission']))
|
|
|
|
|
|
return obs
|
|
return obs
|
|
-
|
|
|
|
-class FlatObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class FlatObsWrapper(gym.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Encode mission strings using a one-hot scheme,
|
|
Encode mission strings using a one-hot scheme,
|
|
and combine these with observed images into one flat array
|
|
and combine these with observed images into one flat array
|
|
@@ -356,10 +371,12 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
# Cache the last-encoded mission string
|
|
# Cache the last-encoded mission string
|
|
if mission != self.cachedStr:
|
|
if mission != self.cachedStr:
|
|
- assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
|
|
|
|
|
|
+ assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(
|
|
|
|
+ len(mission))
|
|
mission = mission.lower()
|
|
mission = mission.lower()
|
|
|
|
|
|
- strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
|
|
|
|
|
|
+ strArray = np.zeros(
|
|
|
|
+ shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
|
|
|
|
|
|
for idx, ch in enumerate(mission):
|
|
for idx, ch in enumerate(mission):
|
|
if ch >= 'a' and ch <= 'z':
|
|
if ch >= 'a' and ch <= 'z':
|
|
@@ -376,7 +393,8 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
return obs
|
|
return obs
|
|
|
|
|
|
-class ViewSizeWrapper(gym.core.Wrapper):
|
|
|
|
|
|
+
|
|
|
|
+class ViewSizeWrapper(gym.Wrapper):
|
|
"""
|
|
"""
|
|
Wrapper to customize the agent field of view size.
|
|
Wrapper to customize the agent field of view size.
|
|
This cannot be used with fully observable wrappers.
|
|
This cannot be used with fully observable wrappers.
|
|
@@ -399,7 +417,8 @@ class ViewSizeWrapper(gym.core.Wrapper):
|
|
)
|
|
)
|
|
|
|
|
|
# Override the environment's observation spaceexit
|
|
# Override the environment's observation spaceexit
|
|
- self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
|
|
|
+ self.observation_space = spaces.Dict(
|
|
|
|
+ {**self.observation_space, 'image': new_image_space})
|
|
|
|
|
|
def observation(self, obs):
|
|
def observation(self, obs):
|
|
env = self.unwrapped
|
|
env = self.unwrapped
|
|
@@ -409,18 +428,19 @@ class ViewSizeWrapper(gym.core.Wrapper):
|
|
# Encode the partially observable view into a numpy array
|
|
# Encode the partially observable view into a numpy array
|
|
image = grid.encode(vis_mask)
|
|
image = grid.encode(vis_mask)
|
|
|
|
|
|
-
|
|
|
|
return {
|
|
return {
|
|
**obs,
|
|
**obs,
|
|
'image': image
|
|
'image': image
|
|
}
|
|
}
|
|
|
|
|
|
-class DirectionObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
+
|
|
|
|
+class DirectionObsWrapper(gym.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
|
|
Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
|
|
type = {slope , angle}
|
|
type = {slope , angle}
|
|
"""
|
|
"""
|
|
- def __init__(self, env,type='slope'):
|
|
|
|
|
|
+
|
|
|
|
+ def __init__(self, env, type='slope'):
|
|
super().__init__(env)
|
|
super().__init__(env)
|
|
self.goal_position = None
|
|
self.goal_position = None
|
|
self.type = type
|
|
self.type = type
|
|
@@ -428,17 +448,23 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
|
|
def reset(self):
|
|
def reset(self):
|
|
obs = self.env.reset()
|
|
obs = self.env.reset()
|
|
if not self.goal_position:
|
|
if not self.goal_position:
|
|
- self.goal_position = [x for x,y in enumerate(self.grid.grid) if isinstance(y,(Goal) ) ]
|
|
|
|
- if len(self.goal_position) >= 1: # in case there are multiple goals , needs to be handled for other env types
|
|
|
|
- self.goal_position = (int(self.goal_position[0]/self.height) , self.goal_position[0]%self.width)
|
|
|
|
|
|
+ self.goal_position = [x for x, y in enumerate(
|
|
|
|
+ self.grid.grid) if isinstance(y, (Goal))]
|
|
|
|
+ # in case there are multiple goals , needs to be handled for other env types
|
|
|
|
+ if len(self.goal_position) >= 1:
|
|
|
|
+ self.goal_position = (
|
|
|
|
+ int(self.goal_position[0]/self.height), self.goal_position[0] % self.width)
|
|
return obs
|
|
return obs
|
|
|
|
|
|
def observation(self, obs):
|
|
def observation(self, obs):
|
|
- slope = np.divide( self.goal_position[1] - self.agent_pos[1] , self.goal_position[0] - self.agent_pos[0])
|
|
|
|
- obs['goal_direction'] = np.arctan( slope ) if self.type == 'angle' else slope
|
|
|
|
|
|
+ slope = np.divide(
|
|
|
|
+ self.goal_position[1] - self.agent_pos[1], self.goal_position[0] - self.agent_pos[0])
|
|
|
|
+ obs['goal_direction'] = np.arctan(
|
|
|
|
+ slope) if self.type == 'angle' else slope
|
|
return obs
|
|
return obs
|
|
|
|
|
|
-class SymbolicObsWrapper(gym.core.ObservationWrapper):
|
|
|
|
|
|
+
|
|
|
|
+class SymbolicObsWrapper(gym.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Fully observable grid with a symbolic state representation.
|
|
Fully observable grid with a symbolic state representation.
|
|
The symbol is a triple of (X, Y, IDX), where X and Y are
|
|
The symbol is a triple of (X, Y, IDX), where X and Y are
|
|
@@ -454,7 +480,8 @@ class SymbolicObsWrapper(gym.core.ObservationWrapper):
|
|
shape=(self.env.width, self.env.height, 3), # number of cells
|
|
shape=(self.env.width, self.env.height, 3), # number of cells
|
|
dtype="uint8",
|
|
dtype="uint8",
|
|
)
|
|
)
|
|
- self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
|
|
|
|
|
|
+ self.observation_space = spaces.Dict(
|
|
|
|
+ {**self.observation_space, 'image': new_image_space})
|
|
|
|
|
|
def observation(self, obs):
|
|
def observation(self, obs):
|
|
objects = np.array(
|
|
objects = np.array(
|