Pārlūkot izejas kodu

revert definition of `EmptyEnvWithExtraObs` to `run_tests.py`

saleml 2 gadi atpakaļ
vecāks
revīzija
6335899c78
2 mainītis faili ar 38 papildinājumiem un 32 dzēšanām
  1. 7 28
      gym_minigrid/envs/empty.py
  2. 31 4
      run_tests.py

+ 7 - 28
gym_minigrid/envs/empty.py

@@ -1,6 +1,7 @@
 from gym_minigrid.minigrid import *
 from gym_minigrid.register import register
 
+
 class EmptyEnv(MiniGridEnv):
     """
     Empty grid environment, no obstacles, sparse reward
@@ -9,7 +10,7 @@ class EmptyEnv(MiniGridEnv):
     def __init__(
         self,
         size=8,
-        agent_start_pos=(1,1),
+        agent_start_pos=(1, 1),
         agent_start_dir=0,
         **kwargs
     ):
@@ -43,48 +44,31 @@ class EmptyEnv(MiniGridEnv):
 
         self.mission = "get to the green goal square"
 
+
 class EmptyEnv5x5(EmptyEnv):
     def __init__(self, **kwargs):
         super().__init__(size=5, **kwargs)
 
+
 class EmptyRandomEnv5x5(EmptyEnv):
     def __init__(self, **kwargs):
         super().__init__(size=5, agent_start_pos=None, **kwargs)
 
+
 class EmptyEnv6x6(EmptyEnv):
     def __init__(self, **kwargs):
         super().__init__(size=6, **kwargs)
 
+
 class EmptyRandomEnv6x6(EmptyEnv):
     def __init__(self, **kwargs):
         super().__init__(size=6, agent_start_pos=None, **kwargs)
 
+
 class EmptyEnv16x16(EmptyEnv):
     def __init__(self, **kwargs):
         super().__init__(size=16, **kwargs)
 
-class EmptyEnvWithExtraObs(EmptyEnv5x5):
-    """
-    Custom environment with an extra observation
-    """
-    def __init__(self, **kwargs) -> None:
-        super().__init__(**kwargs)
-        self.observation_space['size'] = spaces.Box(
-            low=0,
-            high=1000,  #gym does not like np.iinfo(np.uint).max,  
-            shape=(2,),
-            dtype=np.uint
-        )
-
-    def reset(self, **kwargs):
-        obs = super().reset(**kwargs)
-        obs['size'] = np.array([self.width, self.height], dtype=np.uint)
-        return obs
-
-    def step(self, action):
-        obs, reward, done, info = super().step(action)
-        obs['size'] = np.array([self.width, self.height], dtype=np.uint)
-        return obs, reward, done, info
 
 register(
     id='MiniGrid-Empty-5x5-v0',
@@ -115,8 +99,3 @@ register(
     id='MiniGrid-Empty-16x16-v0',
     entry_point='gym_minigrid.envs:EmptyEnv16x16'
 )
-
-register(
-    id='MiniGrid-EmptyWithExtraObs-v0',
-    entry_point='gym_minigrid.envs:EmptyEnvWithExtraObs',
-)

+ 31 - 4
run_tests.py

@@ -119,7 +119,8 @@ for env_idx, env_name in enumerate(env_list):
     env.reset()
     mission = env.mission
     obs, _, _, _ = env.step(0)
-    assert env.string_to_indices(mission) == [value for value in obs['mission'] if value != 0]
+    assert env.string_to_indices(mission) == [
+        value for value in obs['mission'] if value != 0]
     env.close()
 
     # Test the wrappers return proper observation spaces.
@@ -146,6 +147,32 @@ for env_idx, env_name in enumerate(env_list):
 
 print('testing extra observations')
 
+
+class EmptyEnvWithExtraObs(gym_minigrid.envs.EmptyEnv5x5):
+    """
+    Custom environment with an extra observation
+    """
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.observation_space['size'] = spaces.Box(
+            low=0,
+            high=1000,  # gym does not like np.iinfo(np.uint).max,
+            shape=(2,),
+            dtype=np.uint
+        )
+
+    def reset(self, **kwargs):
+        obs = super().reset(**kwargs)
+        obs['size'] = np.array([self.width, self.height], dtype=np.uint)
+        return obs
+
+    def step(self, action):
+        obs, reward, done, info = super().step(action)
+        obs['size'] = np.array([self.width, self.height], dtype=np.uint)
+        return obs, reward, done, info
+
+
 wrappers = [
     OneHotPartialObsWrapper,
     RGBImgObsWrapper,
@@ -153,14 +180,14 @@ wrappers = [
     FullyObsWrapper,
 ]
 for wrapper in wrappers:
-    env1 = wrapper(gym.make('MiniGrid-EmptyWithExtraObs-v0', render_mode='rgb_array'))
+    env1 = wrapper(EmptyEnvWithExtraObs(render_mode='rgb_array'))
     env2 = wrapper(gym.make('MiniGrid-Empty-5x5-v0', render_mode='rgb_array'))
 
     obs1 = env1.reset(seed=0)
     obs2 = env2.reset(seed=0)
     assert 'size' in obs1
     assert obs1['size'].shape == (2,)
-    assert (obs1['size'] == [5,5]).all()
+    assert (obs1['size'] == [5, 5]).all()
     for key in obs2:
         assert np.array_equal(obs1[key], obs2[key])
 
@@ -168,7 +195,7 @@ for wrapper in wrappers:
     obs2, reward2, done2, _ = env2.step(0)
     assert 'size' in obs1
     assert obs1['size'].shape == (2,)
-    assert (obs1['size'] == [5,5]).all()
+    assert (obs1['size'] == [5, 5]).all()
     for key in obs2:
         assert np.array_equal(obs1[key], obs2[key])