浏览代码

New wrapper: Use a MultiDiscrete gym space to represent the textual observation. This allows for smooth usage with RL libraries that require observations to be convertible to torch tensors

saleml 2 年之前
父节点
当前提交
56cc2baa25
共有 2 个文件被更改,包括 81 次插入0 次删除
  1. 72 0
      gym_minigrid/wrappers.py
  2. 9 0
      run_tests.py

+ 72 - 0
gym_minigrid/wrappers.py

@@ -246,6 +246,78 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
             'image': full_grid
             'image': full_grid
         }
         }
 
 
+class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
+    """
+    Use a Dict Obsevation Space encoding images, missions, and directions
+    """
+
+    def __init__(self, env, max_words_in_mission=50, word_dict=None):
+        """
+        max_words_in_mission is the length of the array to represent a mission, value 0 for missing words
+        word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
+                  if None, use the Minigrid language
+        """
+        super().__init__(env)
+
+        if word_dict is None:
+            word_dict = DictObservationSpaceWrapper.get_minigrid_words()
+            
+        self.max_words_in_mission = max_words_in_mission
+        self.word_dict = word_dict
+
+        image_observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(self.agent_view_size, self.agent_view_size, 3),
+            dtype='uint8'
+        )
+        self.observation_space = spaces.Dict({
+            'image': image_observation_space,
+            'direction': spaces.Discrete(4),
+            'mission': spaces.MultiDiscrete([len(self.word_dict.keys())]
+             * max_words_in_mission)
+        })
+
+    @staticmethod
+    def get_minigrid_words():
+        colors = ['red', 'green', 'blue', 'yellow', 'purple', 'grey']
+        objects = ['unseen', 'empty', 'wall', 'floor', 'box', 'key', 'ball',
+        'door', 'goal', 'agent', 'lava']
+
+        verbs = ['pick', 'avoid', 'get', 'find', 'put',
+                'use', 'open', 'go', 'fetch',
+                'reach', 'unlock', 'traverse']
+
+        extra_words = ['up', 'the', 'a', 'at', ',', 'square',
+                    'and', 'then', 'to', 'of', 'rooms', 'near',
+                    'opening', 'must', 'you', 'matching', 'end',
+                    'hallway', 'object', 'from', 'room']
+
+
+        all_words = colors + objects + verbs + extra_words
+        assert len(all_words) == len(set(all_words))
+        return {word: i for i, word in enumerate(all_words)}
+
+    def string_to_indices(self, string, offset=1):
+        """
+        Convert a string to a list of indices.
+        """
+        indices = []
+        string = string.replace(',', ' , ')  # adding space before and after commas
+        for word in string.split():
+            if word in self.word_dict.keys():
+                indices.append(self.word_dict[word] + offset)
+            else:
+                raise ValueError('Unknown word: {}'.format(word))
+        return indices
+        
+    def observation(self, obs):
+        obs['mission'] = self.string_to_indices(obs['mission'])
+        assert len(obs['mission']) < self.max_words_in_mission
+        obs['mission'] += [0] * (self.max_words_in_mission - len(obs['mission']))
+
+        return obs
+        
 class FlatObsWrapper(gym.core.ObservationWrapper):
 class FlatObsWrapper(gym.core.ObservationWrapper):
     """
     """
     Encode mission strings using a one-hot scheme,
     Encode mission strings using a one-hot scheme,

+ 9 - 0
run_tests.py

@@ -112,6 +112,15 @@ for env_idx, env_name in enumerate(env_list):
     env.step(0)
     env.step(0)
     env.close()
     env.close()
 
 
+    # Test the DictObservationSpaceWrapper
+    env = gym.make(env_name)
+    env = DictObservationSpaceWrapper(env)
+    env.reset()
+    mission = env.mission
+    obs, _, _, _ = env.step(0)
+    assert env.string_to_indices(mission) == [value for value in obs['mission'] if value != 0]
+    env.close()
+
     # Test the wrappers return proper observation spaces.
     # Test the wrappers return proper observation spaces.
     wrappers = [
     wrappers = [
         RGBImgObsWrapper,
         RGBImgObsWrapper,