瀏覽代碼

no imports from gym.core

saleml 2 年之前
父節點
當前提交
d20e9133b3
共有 1 個文件被更改,包括 69 次插入和 42 次删除
  1. 69 42
      gym_minigrid/wrappers.py

+ 69 - 42
gym_minigrid/wrappers.py

@@ -7,7 +7,8 @@ import gym
 from gym import error, spaces, utils
 from gym import error, spaces, utils
 from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal
 from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal
 
 
-class ReseedWrapper(gym.core.Wrapper):
+
+class ReseedWrapper(gym.Wrapper):
     """
     """
     Wrapper to always regenerate an environment with the same set of seeds.
     Wrapper to always regenerate an environment with the same set of seeds.
     This can be used to force an environment to always keep the same
     This can be used to force an environment to always keep the same
@@ -28,7 +29,8 @@ class ReseedWrapper(gym.core.Wrapper):
         obs, reward, done, info = self.env.step(action)
         obs, reward, done, info = self.env.step(action)
         return obs, reward, done, info
         return obs, reward, done, info
 
 
-class ActionBonus(gym.core.Wrapper):
+
+class ActionBonus(gym.Wrapper):
     """
     """
     Wrapper which adds an exploration bonus.
     Wrapper which adds an exploration bonus.
     This is a reward to encourage exploration of less
     This is a reward to encourage exploration of less
@@ -62,7 +64,8 @@ class ActionBonus(gym.core.Wrapper):
     def reset(self, **kwargs):
     def reset(self, **kwargs):
         return self.env.reset(**kwargs)
         return self.env.reset(**kwargs)
 
 
-class StateBonus(gym.core.Wrapper):
+
+class StateBonus(gym.Wrapper):
     """
     """
     Adds an exploration bonus based on which positions
     Adds an exploration bonus based on which positions
     are visited on the grid.
     are visited on the grid.
@@ -97,7 +100,8 @@ class StateBonus(gym.core.Wrapper):
     def reset(self, **kwargs):
     def reset(self, **kwargs):
         return self.env.reset(**kwargs)
         return self.env.reset(**kwargs)
 
 
-class ImgObsWrapper(gym.core.ObservationWrapper):
+
+class ImgObsWrapper(gym.ObservationWrapper):
     """
     """
     Use the image as the only observation output, no language/mission.
     Use the image as the only observation output, no language/mission.
     """
     """
@@ -109,7 +113,8 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
     def observation(self, obs):
     def observation(self, obs):
         return obs['image']
         return obs['image']
 
 
-class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
+
+class OneHotPartialObsWrapper(gym.ObservationWrapper):
     """
     """
     Wrapper to get a one-hot encoding of a partially observable
     Wrapper to get a one-hot encoding of a partially observable
     agent view as observation.
     agent view as observation.
@@ -131,11 +136,13 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
             shape=(obs_shape[0], obs_shape[1], num_bits),
             shape=(obs_shape[0], obs_shape[1], num_bits),
             dtype='uint8'
             dtype='uint8'
         )
         )
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
 
     def observation(self, obs):
     def observation(self, obs):
         img = obs['image']
         img = obs['image']
-        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')
+        out = np.zeros(
+            self.observation_space.spaces['image'].shape, dtype='uint8')
 
 
         for i in range(img.shape[0]):
         for i in range(img.shape[0]):
             for j in range(img.shape[1]):
             for j in range(img.shape[1]):
@@ -152,7 +159,8 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
             'image': out
             'image': out
         }
         }
 
 
-class RGBImgObsWrapper(gym.core.ObservationWrapper):
+
+class RGBImgObsWrapper(gym.ObservationWrapper):
     """
     """
     Wrapper to use fully observable RGB image as observation,
     Wrapper to use fully observable RGB image as observation,
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -171,7 +179,8 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
             dtype='uint8'
             dtype='uint8'
         )
         )
 
 
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
 
     def observation(self, obs):
     def observation(self, obs):
         env = self.unwrapped
         env = self.unwrapped
@@ -188,7 +197,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
         }
         }
 
 
 
 
-class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
+class RGBImgPartialObsWrapper(gym.ObservationWrapper):
     """
     """
     Wrapper to use partially observable RGB image as observation.
     Wrapper to use partially observable RGB image as observation.
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -207,7 +216,8 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
             dtype='uint8'
             dtype='uint8'
         )
         )
 
 
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
 
     def observation(self, obs):
     def observation(self, obs):
         env = self.unwrapped
         env = self.unwrapped
@@ -222,7 +232,8 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
             'image': rgb_img_partial
             'image': rgb_img_partial
         }
         }
 
 
-class FullyObsWrapper(gym.core.ObservationWrapper):
+
+class FullyObsWrapper(gym.ObservationWrapper):
     """
     """
     Fully observable gridworld using a compact grid encoding
     Fully observable gridworld using a compact grid encoding
     """
     """
@@ -237,7 +248,8 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
             dtype='uint8'
             dtype='uint8'
         )
         )
 
 
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
 
     def observation(self, obs):
     def observation(self, obs):
         env = self.unwrapped
         env = self.unwrapped
@@ -253,7 +265,8 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
             'image': full_grid
             'image': full_grid
         }
         }
 
 
-class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
+
+class DictObservationSpaceWrapper(gym.ObservationWrapper):
     """
     """
     Use a Dict Observation Space encoding images, missions, and directions
     Use a Dict Observation Space encoding images, missions, and directions
     """
     """
@@ -268,7 +281,7 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
 
 
         if word_dict is None:
         if word_dict is None:
             word_dict = DictObservationSpaceWrapper.get_minigrid_words()
             word_dict = DictObservationSpaceWrapper.get_minigrid_words()
-            
+
         self.max_words_in_mission = max_words_in_mission
         self.max_words_in_mission = max_words_in_mission
         self.word_dict = word_dict
         self.word_dict = word_dict
 
 
@@ -282,24 +295,23 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
             'image': image_observation_space,
             'image': image_observation_space,
             'direction': spaces.Discrete(4),
             'direction': spaces.Discrete(4),
             'mission': spaces.MultiDiscrete([len(self.word_dict.keys())]
             'mission': spaces.MultiDiscrete([len(self.word_dict.keys())]
-             * max_words_in_mission)
+                                            * max_words_in_mission)
         })
         })
 
 
     @staticmethod
     @staticmethod
     def get_minigrid_words():
     def get_minigrid_words():
         colors = ['red', 'green', 'blue', 'yellow', 'purple', 'grey']
         colors = ['red', 'green', 'blue', 'yellow', 'purple', 'grey']
         objects = ['unseen', 'empty', 'wall', 'floor', 'box', 'key', 'ball',
         objects = ['unseen', 'empty', 'wall', 'floor', 'box', 'key', 'ball',
-        'door', 'goal', 'agent', 'lava']
+                   'door', 'goal', 'agent', 'lava']
 
 
         verbs = ['pick', 'avoid', 'get', 'find', 'put',
         verbs = ['pick', 'avoid', 'get', 'find', 'put',
-                'use', 'open', 'go', 'fetch',
-                'reach', 'unlock', 'traverse']
+                 'use', 'open', 'go', 'fetch',
+                 'reach', 'unlock', 'traverse']
 
 
         extra_words = ['up', 'the', 'a', 'at', ',', 'square',
         extra_words = ['up', 'the', 'a', 'at', ',', 'square',
-                    'and', 'then', 'to', 'of', 'rooms', 'near',
-                    'opening', 'must', 'you', 'matching', 'end',
-                    'hallway', 'object', 'from', 'room']
-
+                       'and', 'then', 'to', 'of', 'rooms', 'near',
+                       'opening', 'must', 'you', 'matching', 'end',
+                       'hallway', 'object', 'from', 'room']
 
 
         all_words = colors + objects + verbs + extra_words
         all_words = colors + objects + verbs + extra_words
         assert len(all_words) == len(set(all_words))
         assert len(all_words) == len(set(all_words))
@@ -310,22 +322,25 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
         Convert a string to a list of indices.
         Convert a string to a list of indices.
         """
         """
         indices = []
         indices = []
-        string = string.replace(',', ' , ')  # adding space before and after commas
+        # adding space before and after commas
+        string = string.replace(',', ' , ')
         for word in string.split():
         for word in string.split():
             if word in self.word_dict.keys():
             if word in self.word_dict.keys():
                 indices.append(self.word_dict[word] + offset)
                 indices.append(self.word_dict[word] + offset)
             else:
             else:
                 raise ValueError('Unknown word: {}'.format(word))
                 raise ValueError('Unknown word: {}'.format(word))
         return indices
         return indices
-        
+
     def observation(self, obs):
     def observation(self, obs):
         obs['mission'] = self.string_to_indices(obs['mission'])
         obs['mission'] = self.string_to_indices(obs['mission'])
         assert len(obs['mission']) < self.max_words_in_mission
         assert len(obs['mission']) < self.max_words_in_mission
-        obs['mission'] += [0] * (self.max_words_in_mission - len(obs['mission']))
+        obs['mission'] += [0] * \
+            (self.max_words_in_mission - len(obs['mission']))
 
 
         return obs
         return obs
-        
-class FlatObsWrapper(gym.core.ObservationWrapper):
+
+
+class FlatObsWrapper(gym.ObservationWrapper):
     """
     """
     Encode mission strings using a one-hot scheme,
     Encode mission strings using a one-hot scheme,
     and combine these with observed images into one flat array
     and combine these with observed images into one flat array
@@ -356,10 +371,12 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
 
 
         # Cache the last-encoded mission string
         # Cache the last-encoded mission string
         if mission != self.cachedStr:
         if mission != self.cachedStr:
-            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
+            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(
+                len(mission))
             mission = mission.lower()
             mission = mission.lower()
 
 
-            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
+            strArray = np.zeros(
+                shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
 
 
             for idx, ch in enumerate(mission):
             for idx, ch in enumerate(mission):
                 if ch >= 'a' and ch <= 'z':
                 if ch >= 'a' and ch <= 'z':
@@ -376,7 +393,8 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
 
 
         return obs
         return obs
 
 
-class ViewSizeWrapper(gym.core.Wrapper):
+
+class ViewSizeWrapper(gym.Wrapper):
     """
     """
     Wrapper to customize the agent field of view size.
     Wrapper to customize the agent field of view size.
     This cannot be used with fully observable wrappers.
     This cannot be used with fully observable wrappers.
@@ -399,7 +417,8 @@ class ViewSizeWrapper(gym.core.Wrapper):
         )
         )
 
 
         # Override the environment's observation spaceexit
         # Override the environment's observation space
         # Override the environment's observation space
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
 
     def observation(self, obs):
     def observation(self, obs):
         env = self.unwrapped
         env = self.unwrapped
@@ -409,18 +428,19 @@ class ViewSizeWrapper(gym.core.Wrapper):
         # Encode the partially observable view into a numpy array
         # Encode the partially observable view into a numpy array
         image = grid.encode(vis_mask)
         image = grid.encode(vis_mask)
 
 
-
         return {
         return {
             **obs,
             **obs,
             'image': image
             'image': image
         }
         }
 
 
-class DirectionObsWrapper(gym.core.ObservationWrapper):
+
+class DirectionObsWrapper(gym.ObservationWrapper):
     """
     """
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y1)/(x2 - x1)
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y1)/(x2 - x1)
     type = {slope , angle}
     type = {slope , angle}
     """
     """
-    def __init__(self, env,type='slope'):
+
+    def __init__(self, env, type='slope'):
         super().__init__(env)
         super().__init__(env)
         self.goal_position = None
         self.goal_position = None
         self.type = type
         self.type = type
@@ -428,17 +448,23 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
     def reset(self):
     def reset(self):
         obs = self.env.reset()
         obs = self.env.reset()
         if not self.goal_position:
         if not self.goal_position:
-            self.goal_position = [x for x,y in enumerate(self.grid.grid) if isinstance(y,(Goal) ) ]
-            if len(self.goal_position) >= 1: # in case there are multiple goals , needs to be handled for other env types
-                self.goal_position = (int(self.goal_position[0]/self.height) , self.goal_position[0]%self.width)
+            self.goal_position = [x for x, y in enumerate(
+                self.grid.grid) if isinstance(y, (Goal))]
+            # in case there are multiple goals , needs to be handled for other env types
+            if len(self.goal_position) >= 1:
+                self.goal_position = (
+                    int(self.goal_position[0]/self.height), self.goal_position[0] % self.width)
         return obs
         return obs
 
 
     def observation(self, obs):
     def observation(self, obs):
-        slope = np.divide( self.goal_position[1] - self.agent_pos[1] ,  self.goal_position[0] - self.agent_pos[0])
-        obs['goal_direction'] = np.arctan( slope ) if self.type == 'angle' else slope
+        slope = np.divide(
+            self.goal_position[1] - self.agent_pos[1],  self.goal_position[0] - self.agent_pos[0])
+        obs['goal_direction'] = np.arctan(
+            slope) if self.type == 'angle' else slope
         return obs
         return obs
 
 
-class SymbolicObsWrapper(gym.core.ObservationWrapper):
+
+class SymbolicObsWrapper(gym.ObservationWrapper):
     """
     """
     Fully observable grid with a symbolic state representation.
     Fully observable grid with a symbolic state representation.
     The symbol is a triple of (X, Y, IDX), where X and Y are
     The symbol is a triple of (X, Y, IDX), where X and Y are
@@ -454,7 +480,8 @@ class SymbolicObsWrapper(gym.core.ObservationWrapper):
             shape=(self.env.width, self.env.height, 3),  # number of cells
             shape=(self.env.width, self.env.height, 3),  # number of cells
             dtype="uint8",
             dtype="uint8",
         )
         )
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
 
     def observation(self, obs):
     def observation(self, obs):
         objects = np.array(
         objects = np.array(