Browse Source

Include all observations in wrapper (#181) (#182)

Howard Huang 2 years ago
parent
commit
29f8a09735
2 changed files with 60 additions and 4 deletions
  1. 4 4
      gym_minigrid/wrappers.py
  2. 56 0
      run_tests.py

+ 4 - 4
gym_minigrid/wrappers.py

@@ -148,7 +148,7 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
                 out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1
 
         return {
-            'mission': obs['mission'],
+            **obs,
             'image': out
         }
 
@@ -180,7 +180,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
         )
 
         return {
-            'mission': obs['mission'],
+            **obs,
             'image': rgb_img
         }
 
@@ -213,7 +213,7 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
         )
 
         return {
-            'mission': obs['mission'],
+            **obs,
             'image': rgb_img_partial
         }
 
@@ -242,7 +242,7 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
         ])
 
         return {
-            'mission': obs['mission'],
+            **obs,
             'image': full_grid
         }
 

+ 56 - 0
run_tests.py

@@ -3,6 +3,7 @@
 import random
 import numpy as np
 import gym
+import gym_minigrid
 from gym_minigrid.register import env_list
 from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
 
@@ -133,6 +134,61 @@ for env_idx, env_name in enumerate(env_list):
 
 ##############################################################################
 
+print('testing extra observations')
+class EmptyEnvWithExtraObs(gym_minigrid.envs.EmptyEnv5x5):
+    """
+    Custom environment with an extra observation
+    """
+    def __init__(self) -> None:
+        super().__init__()
+        self.observation_space['size'] = spaces.Box(
+            low=0,
+            high=np.iinfo(np.uint).max,
+            shape=(2,),
+            dtype=np.uint
+        )
+
+    def reset(self):
+        obs = super().reset()
+        obs['size'] = np.array([self.width, self.height])
+        return obs
+
+    def step(self, action):
+        obs, reward, done, info = super().step(action)
+        obs['size'] = np.array([self.width, self.height])
+        return obs, reward, done, info
+
+wrappers = [
+    OneHotPartialObsWrapper,
+    RGBImgObsWrapper,
+    RGBImgPartialObsWrapper,
+    FullyObsWrapper,
+]
+for wrapper in wrappers:
+    env1 = wrapper(EmptyEnvWithExtraObs())
+    env2 = wrapper(gym.make('MiniGrid-Empty-5x5-v0'))
+
+    env1.seed(0)
+    env2.seed(0)
+
+    obs1 = env1.reset()
+    obs2 = env2.reset()
+    assert 'size' in obs1
+    assert obs1['size'].shape == (2,)
+    assert (obs1['size'] == [5,5]).all()
+    for key in obs2:
+        assert np.array_equal(obs1[key], obs2[key])
+
+    obs1, reward1, done1, _ = env1.step(0)
+    obs2, reward2, done2, _ = env2.step(0)
+    assert 'size' in obs1
+    assert obs1['size'].shape == (2,)
+    assert (obs1['size'] == [5,5]).all()
+    for key in obs2:
+        assert np.array_equal(obs1[key], obs2[key])
+
+##############################################################################
+
 print('testing agent_sees method')
 env = gym.make('MiniGrid-DoorKey-6x6-v0')
 goal_pos = (env.grid.width - 2, env.grid.height - 2)