Browse Source

Update FullyObsWrapper (#27)

* Classical env and wrappers (#6, #13, #22)

* Add Classical-v0 4 rooms env #

* Add image wrapper

* Add full state wrapper

* Updated according to #24

* Changed name to FourRooms

* Fix obs space in ObsWrapper

* Add test for FullObsWrapper

* revert

* Updated according to #24

* Changed name to FourRooms

* Fix obs space in ObsWrapper

* Add test for FullObsWrapper

* Removed doors

* Removed test env #24

* Revert minigrid

* Efficient full obs wrapper

* Update wrappers.py

* Fix as in #27

* Accepted changes in #27

* Merged
d3sm0 6 years ago
parent
commit
02c8ee98dd
1 changed files with 7 additions and 7 deletions
  1. 7 7
      gym_minigrid/wrappers.py

+ 7 - 7
gym_minigrid/wrappers.py

@@ -83,7 +83,7 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
     def __init__(self, env):
         super().__init__(env)
         self.__dict__.update(vars(env))  # hack to pass values to super wrapper
-        self.observation_space = env.observation_space['image']
+        self.observation_space = env.observation_space.spaces['image']
 
     def observation(self, obs):
         return obs['image']
@@ -91,7 +91,7 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
 
 class FullyObsWrapper(gym.core.ObservationWrapper):
     """
-    Fully observable gridworld
+    Fully observable gridworld using a compact grid encoding
     """
 
     def __init__(self, env):
@@ -99,15 +99,15 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
         self.__dict__.update(vars(env))  # hack to pass values to super wrapper
         self.observation_space = spaces.Box(
             low=0,
-            high=255,
-            shape=(self.env.grid_size * 32, self.env.grid_size * 32, 3),  # number of cells
+            high=self.env.grid_size,
+            shape=(self.env.grid_size, self.env.grid_size, 3),  # number of cells
             dtype='uint8'
         )
 
     def observation(self, obs):
-        if self.env.grid_render is None:
-            return np.zeros(shape=self.observation_space.shape)  # dark screen as init state?
-        return self.env.grid_render.getArray()
+        full_grid = self.env.grid.encode()
+        full_grid[self.env.agent_pos[0]][self.env.agent_pos[1]] = np.array([255, self.env.agent_dir, 0])
+        return full_grid
 
 
 class FlatObsWrapper(gym.core.ObservationWrapper):