Преглед на файлове

no imports from gym.core

saleml преди 2 години
родител
ревизия
d20e9133b3
променени са 1 файла, в които са добавени 69 реда и са изтрити 42 реда
  1. 69 42
      gym_minigrid/wrappers.py

+ 69 - 42
gym_minigrid/wrappers.py

@@ -7,7 +7,8 @@ import gym
 from gym import error, spaces, utils
 from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal
 
-class ReseedWrapper(gym.core.Wrapper):
+
+class ReseedWrapper(gym.Wrapper):
     """
     Wrapper to always regenerate an environment with the same set of seeds.
     This can be used to force an environment to always keep the same
@@ -28,7 +29,8 @@ class ReseedWrapper(gym.core.Wrapper):
         obs, reward, done, info = self.env.step(action)
         return obs, reward, done, info
 
-class ActionBonus(gym.core.Wrapper):
+
+class ActionBonus(gym.Wrapper):
     """
     Wrapper which adds an exploration bonus.
     This is a reward to encourage exploration of less
@@ -62,7 +64,8 @@ class ActionBonus(gym.core.Wrapper):
     def reset(self, **kwargs):
         return self.env.reset(**kwargs)
 
-class StateBonus(gym.core.Wrapper):
+
+class StateBonus(gym.Wrapper):
     """
     Adds an exploration bonus based on which positions
     are visited on the grid.
@@ -97,7 +100,8 @@ class StateBonus(gym.core.Wrapper):
     def reset(self, **kwargs):
         return self.env.reset(**kwargs)
 
-class ImgObsWrapper(gym.core.ObservationWrapper):
+
+class ImgObsWrapper(gym.ObservationWrapper):
     """
     Use the image as the only observation output, no language/mission.
     """
@@ -109,7 +113,8 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
     def observation(self, obs):
         return obs['image']
 
-class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
+
+class OneHotPartialObsWrapper(gym.ObservationWrapper):
     """
     Wrapper to get a one-hot encoding of a partially observable
     agent view as observation.
@@ -131,11 +136,13 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
             shape=(obs_shape[0], obs_shape[1], num_bits),
             dtype='uint8'
         )
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
     def observation(self, obs):
         img = obs['image']
-        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')
+        out = np.zeros(
+            self.observation_space.spaces['image'].shape, dtype='uint8')
 
         for i in range(img.shape[0]):
             for j in range(img.shape[1]):
@@ -152,7 +159,8 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
             'image': out
         }
 
-class RGBImgObsWrapper(gym.core.ObservationWrapper):
+
+class RGBImgObsWrapper(gym.ObservationWrapper):
     """
     Wrapper to use fully observable RGB image as observation,
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -171,7 +179,8 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
             dtype='uint8'
         )
 
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
     def observation(self, obs):
         env = self.unwrapped
@@ -188,7 +197,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
         }
 
 
-class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
+class RGBImgPartialObsWrapper(gym.ObservationWrapper):
     """
     Wrapper to use partially observable RGB image as observation.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -207,7 +216,8 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
             dtype='uint8'
         )
 
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
     def observation(self, obs):
         env = self.unwrapped
@@ -222,7 +232,8 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
             'image': rgb_img_partial
         }
 
-class FullyObsWrapper(gym.core.ObservationWrapper):
+
+class FullyObsWrapper(gym.ObservationWrapper):
     """
     Fully observable gridworld using a compact grid encoding
     """
@@ -237,7 +248,8 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
             dtype='uint8'
         )
 
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
     def observation(self, obs):
         env = self.unwrapped
@@ -253,7 +265,8 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
             'image': full_grid
         }
 
-class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
+
+class DictObservationSpaceWrapper(gym.ObservationWrapper):
     """
     Use a Dict Observation Space encoding images, missions, and directions
     """
@@ -268,7 +281,7 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
 
         if word_dict is None:
             word_dict = DictObservationSpaceWrapper.get_minigrid_words()
-            
+
         self.max_words_in_mission = max_words_in_mission
         self.word_dict = word_dict
 
@@ -282,24 +295,23 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
             'image': image_observation_space,
             'direction': spaces.Discrete(4),
             'mission': spaces.MultiDiscrete([len(self.word_dict.keys())]
-             * max_words_in_mission)
+                                            * max_words_in_mission)
         })
 
     @staticmethod
     def get_minigrid_words():
         colors = ['red', 'green', 'blue', 'yellow', 'purple', 'grey']
         objects = ['unseen', 'empty', 'wall', 'floor', 'box', 'key', 'ball',
-        'door', 'goal', 'agent', 'lava']
+                   'door', 'goal', 'agent', 'lava']
 
         verbs = ['pick', 'avoid', 'get', 'find', 'put',
-                'use', 'open', 'go', 'fetch',
-                'reach', 'unlock', 'traverse']
+                 'use', 'open', 'go', 'fetch',
+                 'reach', 'unlock', 'traverse']
 
         extra_words = ['up', 'the', 'a', 'at', ',', 'square',
-                    'and', 'then', 'to', 'of', 'rooms', 'near',
-                    'opening', 'must', 'you', 'matching', 'end',
-                    'hallway', 'object', 'from', 'room']
-
+                       'and', 'then', 'to', 'of', 'rooms', 'near',
+                       'opening', 'must', 'you', 'matching', 'end',
+                       'hallway', 'object', 'from', 'room']
 
         all_words = colors + objects + verbs + extra_words
         assert len(all_words) == len(set(all_words))
@@ -310,22 +322,25 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
         Convert a string to a list of indices.
         """
         indices = []
-        string = string.replace(',', ' , ')  # adding space before and after commas
+        # adding space before and after commas
+        string = string.replace(',', ' , ')
         for word in string.split():
             if word in self.word_dict.keys():
                 indices.append(self.word_dict[word] + offset)
             else:
                 raise ValueError('Unknown word: {}'.format(word))
         return indices
-        
+
     def observation(self, obs):
         obs['mission'] = self.string_to_indices(obs['mission'])
         assert len(obs['mission']) < self.max_words_in_mission
-        obs['mission'] += [0] * (self.max_words_in_mission - len(obs['mission']))
+        obs['mission'] += [0] * \
+            (self.max_words_in_mission - len(obs['mission']))
 
         return obs
-        
-class FlatObsWrapper(gym.core.ObservationWrapper):
+
+
+class FlatObsWrapper(gym.ObservationWrapper):
     """
     Encode mission strings using a one-hot scheme,
     and combine these with observed images into one flat array
@@ -356,10 +371,12 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
 
         # Cache the last-encoded mission string
         if mission != self.cachedStr:
-            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
+            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(
+                len(mission))
             mission = mission.lower()
 
-            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
+            strArray = np.zeros(
+                shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
 
             for idx, ch in enumerate(mission):
                 if ch >= 'a' and ch <= 'z':
@@ -376,7 +393,8 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
 
         return obs
 
-class ViewSizeWrapper(gym.core.Wrapper):
+
+class ViewSizeWrapper(gym.Wrapper):
     """
     Wrapper to customize the agent field of view size.
     This cannot be used with fully observable wrappers.
@@ -399,7 +417,8 @@ class ViewSizeWrapper(gym.core.Wrapper):
         )
 
         # Override the environment's observation space
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
     def observation(self, obs):
         env = self.unwrapped
@@ -409,18 +428,19 @@ class ViewSizeWrapper(gym.core.Wrapper):
         # Encode the partially observable view into a numpy array
         image = grid.encode(vis_mask)
 
-
         return {
             **obs,
             'image': image
         }
 
-class DirectionObsWrapper(gym.core.ObservationWrapper):
+
+class DirectionObsWrapper(gym.ObservationWrapper):
     """
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y1) / (x2 - x1)
     type = {slope , angle}
     """
-    def __init__(self, env,type='slope'):
+
+    def __init__(self, env, type='slope'):
         super().__init__(env)
         self.goal_position = None
         self.type = type
@@ -428,17 +448,23 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
     def reset(self):
         obs = self.env.reset()
         if not self.goal_position:
-            self.goal_position = [x for x,y in enumerate(self.grid.grid) if isinstance(y,(Goal) ) ]
-            if len(self.goal_position) >= 1: # in case there are multiple goals , needs to be handled for other env types
-                self.goal_position = (int(self.goal_position[0]/self.height) , self.goal_position[0]%self.width)
+            self.goal_position = [x for x, y in enumerate(
+                self.grid.grid) if isinstance(y, (Goal))]
+            # in case there are multiple goals , needs to be handled for other env types
+            if len(self.goal_position) >= 1:
+                self.goal_position = (
+                    int(self.goal_position[0]/self.height), self.goal_position[0] % self.width)
         return obs
 
     def observation(self, obs):
-        slope = np.divide( self.goal_position[1] - self.agent_pos[1] ,  self.goal_position[0] - self.agent_pos[0])
-        obs['goal_direction'] = np.arctan( slope ) if self.type == 'angle' else slope
+        slope = np.divide(
+            self.goal_position[1] - self.agent_pos[1],  self.goal_position[0] - self.agent_pos[0])
+        obs['goal_direction'] = np.arctan(
+            slope) if self.type == 'angle' else slope
         return obs
 
-class SymbolicObsWrapper(gym.core.ObservationWrapper):
+
+class SymbolicObsWrapper(gym.ObservationWrapper):
     """
     Fully observable grid with a symbolic state representation.
     The symbol is a triple of (X, Y, IDX), where X and Y are
@@ -454,7 +480,8 @@ class SymbolicObsWrapper(gym.core.ObservationWrapper):
             shape=(self.env.width, self.env.height, 3),  # number of cells
             dtype="uint8",
         )
-        self.observation_space = spaces.Dict({**self.observation_space, 'image':new_image_space})
+        self.observation_space = spaces.Dict(
+            {**self.observation_space, 'image': new_image_space})
 
     def observation(self, obs):
         objects = np.array(