|
@@ -246,6 +246,78 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
'image': full_grid
|
|
'image': full_grid
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
|
|
|
|
+ """
|
|
|
|
+ Use a Dict Obsevation Space encoding images, missions, and directions
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ def __init__(self, env, max_words_in_mission=50, word_dict=None):
|
|
|
|
+ """
|
|
|
|
+ max_words_in_mission is the length of the array to represent a mission, value 0 for missing words
|
|
|
|
+ word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
|
|
|
|
+ if None, use the Minigrid language
|
|
|
|
+ """
|
|
|
|
+ super().__init__(env)
|
|
|
|
+
|
|
|
|
+ if word_dict is None:
|
|
|
|
+ word_dict = DictObservationSpaceWrapper.get_minigrid_words()
|
|
|
|
+
|
|
|
|
+ self.max_words_in_mission = max_words_in_mission
|
|
|
|
+ self.word_dict = word_dict
|
|
|
|
+
|
|
|
|
+ image_observation_space = spaces.Box(
|
|
|
|
+ low=0,
|
|
|
|
+ high=255,
|
|
|
|
+ shape=(self.agent_view_size, self.agent_view_size, 3),
|
|
|
|
+ dtype='uint8'
|
|
|
|
+ )
|
|
|
|
+ self.observation_space = spaces.Dict({
|
|
|
|
+ 'image': image_observation_space,
|
|
|
|
+ 'direction': spaces.Discrete(4),
|
|
|
|
+ 'mission': spaces.MultiDiscrete([len(self.word_dict.keys())]
|
|
|
|
+ * max_words_in_mission)
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ @staticmethod
|
|
|
|
+ def get_minigrid_words():
|
|
|
|
+ colors = ['red', 'green', 'blue', 'yellow', 'purple', 'grey']
|
|
|
|
+ objects = ['unseen', 'empty', 'wall', 'floor', 'box', 'key', 'ball',
|
|
|
|
+ 'door', 'goal', 'agent', 'lava']
|
|
|
|
+
|
|
|
|
+ verbs = ['pick', 'avoid', 'get', 'find', 'put',
|
|
|
|
+ 'use', 'open', 'go', 'fetch',
|
|
|
|
+ 'reach', 'unlock', 'traverse']
|
|
|
|
+
|
|
|
|
+ extra_words = ['up', 'the', 'a', 'at', ',', 'square',
|
|
|
|
+ 'and', 'then', 'to', 'of', 'rooms', 'near',
|
|
|
|
+ 'opening', 'must', 'you', 'matching', 'end',
|
|
|
|
+ 'hallway', 'object', 'from', 'room']
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ all_words = colors + objects + verbs + extra_words
|
|
|
|
+ assert len(all_words) == len(set(all_words))
|
|
|
|
+ return {word: i for i, word in enumerate(all_words)}
|
|
|
|
+
|
|
|
|
+ def string_to_indices(self, string, offset=1):
|
|
|
|
+ """
|
|
|
|
+ Convert a string to a list of indices.
|
|
|
|
+ """
|
|
|
|
+ indices = []
|
|
|
|
+ string = string.replace(',', ' , ') # adding space before and after commas
|
|
|
|
+ for word in string.split():
|
|
|
|
+ if word in self.word_dict.keys():
|
|
|
|
+ indices.append(self.word_dict[word] + offset)
|
|
|
|
+ else:
|
|
|
|
+ raise ValueError('Unknown word: {}'.format(word))
|
|
|
|
+ return indices
|
|
|
|
+
|
|
|
|
+ def observation(self, obs):
|
|
|
|
+ obs['mission'] = self.string_to_indices(obs['mission'])
|
|
|
|
+ assert len(obs['mission']) < self.max_words_in_mission
|
|
|
|
+ obs['mission'] += [0] * (self.max_words_in_mission - len(obs['mission']))
|
|
|
|
+
|
|
|
|
+ return obs
|
|
|
|
+
|
|
class FlatObsWrapper(gym.core.ObservationWrapper):
|
|
class FlatObsWrapper(gym.core.ObservationWrapper):
|
|
"""
|
|
"""
|
|
Encode mission strings using a one-hot scheme,
|
|
Encode mission strings using a one-hot scheme,
|