Bläddra i källkod

move EmptyEnvWithExtraObs to empty.py so that it can be called the right way with gym.make

saleml 2 år sedan
förälder
incheckning
bec1633024
2 ändrade filer med 28 tillägg och 22 borttagningar
  1. 28 0
      gym_minigrid/envs/empty.py
  2. 0 22
      run_tests.py

+ 28 - 0
gym_minigrid/envs/empty.py

@@ -63,6 +63,29 @@ class EmptyEnv16x16(EmptyEnv):
     def __init__(self, **kwargs):
         super().__init__(size=16, **kwargs)
 
+class EmptyEnvWithExtraObs(EmptyEnv5x5):
+    """
+    Custom environment with an extra observation
+    """
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.observation_space['size'] = spaces.Box(
+            low=0,
+            high=1000,  #gym does not like np.iinfo(np.uint).max,  
+            shape=(2,),
+            dtype=np.uint
+        )
+
+    def reset(self, **kwargs):
+        obs = super().reset(**kwargs)
+        obs['size'] = np.array([self.width, self.height], dtype=np.uint)
+        return obs
+
+    def step(self, action):
+        obs, reward, done, info = super().step(action)
+        obs['size'] = np.array([self.width, self.height], dtype=np.uint)
+        return obs, reward, done, info
+
 register(
     id='MiniGrid-Empty-5x5-v0',
     entry_point='gym_minigrid.envs:EmptyEnv5x5'
@@ -92,3 +115,8 @@ register(
     id='MiniGrid-Empty-16x16-v0',
     entry_point='gym_minigrid.envs:EmptyEnv16x16'
 )
+
+register(
+    id='MiniGrid-EmptyWithExtraObs-v0',
+    entry_point='gym_minigrid.envs:EmptyEnvWithExtraObs',
+)

+ 0 - 22
run_tests.py

@@ -144,28 +144,6 @@ for env_idx, env_name in enumerate(env_list):
 ##############################################################################
 
 print('testing extra observations')
-class EmptyEnvWithExtraObs(gym_minigrid.envs.EmptyEnv5x5):
-    """
-    Custom environment with an extra observation
-    """
-    def __init__(self) -> None:
-        super().__init__()
-        self.observation_space['size'] = spaces.Box(
-            low=0,
-            high=np.iinfo(np.uint).max,
-            shape=(2,),
-            dtype=np.uint
-        )
-
-    def reset(self, **kwargs):
-        obs = super().reset(**kwargs)
-        obs['size'] = np.array([self.width, self.height])
-        return obs
-
-    def step(self, action):
-        obs, reward, done, info = super().step(action)
-        obs['size'] = np.array([self.width, self.height])
-        return obs, reward, done, info
 
 wrappers = [
     OneHotPartialObsWrapper,