瀏覽代碼

Merge pull request #194 from saleml/minigrid-no-warning

Compatibility with new gym API + New wrapper for out-of-the-box usage of RL libraries (e.g. stable_baselines3)
Mark Towers 2 年之前
父節點
當前提交
e2b77bd006

+ 1 - 1
.travis.yml

@@ -1,6 +1,6 @@
 language: python
 language: python
 python:
 python:
-  - "3.5"
+  - "3.9"
 
 
 # command to install dependencies
 # command to install dependencies
 install:
 install:

+ 2 - 2
README.md

@@ -10,8 +10,8 @@ laptop, which means you can run your experiments faster. A known-working RL
 implementation can be found [in this repository](https://github.com/lcswillems/torch-rl).
 implementation can be found [in this repository](https://github.com/lcswillems/torch-rl).
 
 
 Requirements:
 Requirements:
-- Python 3.5+
-- OpenAI Gym
+- Python 3.7+
+- OpenAI Gym 0.25+
 - NumPy
 - NumPy
 - Matplotlib (optional, only needed for display)
 - Matplotlib (optional, only needed for display)
 
 

+ 2 - 2
benchmark.py

@@ -17,7 +17,7 @@ parser.add_argument("--num_resets", default=200)
 parser.add_argument("--num_frames", default=5000)
 parser.add_argument("--num_frames", default=5000)
 args = parser.parse_args()
 args = parser.parse_args()
 
 
-env = gym.make(args.env_name)
+env = gym.make(args.env_name, render_mode='rgb_array')
 
 
 # Benchmark env.reset
 # Benchmark env.reset
 t0 = time.time()
 t0 = time.time()
@@ -30,7 +30,7 @@ reset_time = (1000 * dt) / args.num_resets
 # Benchmark rendering
 # Benchmark rendering
 t0 = time.time()
 t0 = time.time()
 for i in range(args.num_frames):
 for i in range(args.num_frames):
-    env.render('rgb_array')
+    env.render()
 t1 = time.time()
 t1 = time.time()
 dt = t1 - t0
 dt = t1 - t0
 frames_per_sec = args.num_frames / dt
 frames_per_sec = args.num_frames / dt

+ 2 - 2
gym_minigrid/envs/blockedunlockpickup.py

@@ -8,14 +8,14 @@ class BlockedUnlockPickup(RoomGrid):
     in another room
     in another room
     """
     """
 
 
-    def __init__(self, seed=None):
+    def __init__(self, **kwargs):
         room_size = 6
         room_size = 6
         super().__init__(
         super().__init__(
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             room_size=room_size,
             room_size=room_size,
             max_steps=16*room_size**2,
             max_steps=16*room_size**2,
-            seed=seed
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):

+ 18 - 18
gym_minigrid/envs/crossing.py

@@ -9,7 +9,7 @@ class CrossingEnv(MiniGridEnv):
     Environment with wall or lava obstacles, sparse reward.
     Environment with wall or lava obstacles, sparse reward.
     """
     """
 
 
-    def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, seed=None):
+    def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, **kwargs):
         self.num_crossings = num_crossings
         self.num_crossings = num_crossings
         self.obstacle_type = obstacle_type
         self.obstacle_type = obstacle_type
         super().__init__(
         super().__init__(
@@ -17,7 +17,7 @@ class CrossingEnv(MiniGridEnv):
             max_steps=4*size*size,
             max_steps=4*size*size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=False,
             see_through_walls=False,
-            seed=None
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -83,20 +83,20 @@ class CrossingEnv(MiniGridEnv):
         )
         )
 
 
 class LavaCrossingEnv(CrossingEnv):
 class LavaCrossingEnv(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=9, num_crossings=1)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, num_crossings=1, **kwargs)
 
 
 class LavaCrossingS9N2Env(CrossingEnv):
 class LavaCrossingS9N2Env(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=9, num_crossings=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, num_crossings=2, **kwargs)
 
 
 class LavaCrossingS9N3Env(CrossingEnv):
 class LavaCrossingS9N3Env(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=9, num_crossings=3)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, num_crossings=3, **kwargs)
 
 
 class LavaCrossingS11N5Env(CrossingEnv):
 class LavaCrossingS11N5Env(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=11, num_crossings=5)
+    def __init__(self, **kwargs):
+        super().__init__(size=11, num_crossings=5, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-LavaCrossingS9N1-v0',
     id='MiniGrid-LavaCrossingS9N1-v0',
@@ -119,20 +119,20 @@ register(
 )
 )
 
 
 class SimpleCrossingEnv(CrossingEnv):
 class SimpleCrossingEnv(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=9, num_crossings=1, obstacle_type=Wall)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, num_crossings=1, obstacle_type=Wall, **kwargs)
 
 
 class SimpleCrossingS9N2Env(CrossingEnv):
 class SimpleCrossingS9N2Env(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=9, num_crossings=2, obstacle_type=Wall)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, num_crossings=2, obstacle_type=Wall, **kwargs)
 
 
 class SimpleCrossingS9N3Env(CrossingEnv):
 class SimpleCrossingS9N3Env(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=9, num_crossings=3, obstacle_type=Wall)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, num_crossings=3, obstacle_type=Wall, **kwargs)
 
 
 class SimpleCrossingS11N5Env(CrossingEnv):
 class SimpleCrossingS11N5Env(CrossingEnv):
-    def __init__(self):
-        super().__init__(size=11, num_crossings=5, obstacle_type=Wall)
+    def __init__(self, **kwargs):
+        super().__init__(size=11, num_crossings=5, obstacle_type=Wall, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-SimpleCrossingS9N1-v0',
     id='MiniGrid-SimpleCrossingS9N1-v0',

+ 8 - 6
gym_minigrid/envs/distshift.py

@@ -12,7 +12,8 @@ class DistShiftEnv(MiniGridEnv):
         height=7,
         height=7,
         agent_start_pos=(1,1),
         agent_start_pos=(1,1),
         agent_start_dir=0,
         agent_start_dir=0,
-        strip2_row=2
+        strip2_row=2,
+        **kwargs
     ):
     ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
         self.agent_start_dir = agent_start_dir
@@ -24,7 +25,8 @@ class DistShiftEnv(MiniGridEnv):
             height=height,
             height=height,
             max_steps=4*width*height,
             max_steps=4*width*height,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -52,12 +54,12 @@ class DistShiftEnv(MiniGridEnv):
         self.mission = "get to the green goal square"
         self.mission = "get to the green goal square"
 
 
 class DistShift1(DistShiftEnv):
 class DistShift1(DistShiftEnv):
-    def __init__(self):
-        super().__init__(strip2_row=2)
+    def __init__(self, **kwargs):
+        super().__init__(strip2_row=2, **kwargs)
 
 
 class DistShift2(DistShiftEnv):
 class DistShift2(DistShiftEnv):
-    def __init__(self):
-        super().__init__(strip2_row=5)
+    def __init__(self, **kwargs):
+        super().__init__(strip2_row=5, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-DistShift1-v0',
     id='MiniGrid-DistShift1-v0',

+ 15 - 8
gym_minigrid/envs/doorkey.py

@@ -1,15 +1,18 @@
 from gym_minigrid.minigrid import *
 from gym_minigrid.minigrid import *
 from gym_minigrid.register import register
 from gym_minigrid.register import register
 
 
+
 class DoorKeyEnv(MiniGridEnv):
 class DoorKeyEnv(MiniGridEnv):
     """
     """
     Environment with a door and key, sparse reward
     Environment with a door and key, sparse reward
     """
     """
 
 
-    def __init__(self, size=8, max_steps=None):
+    def __init__(self, size=8, **kwargs):
+        if 'max_steps' not in kwargs:
+            kwargs['max_steps'] = 10 * size * size
         super().__init__(
         super().__init__(
             grid_size=size,
             grid_size=size,
-            max_steps=10*size*size if max_steps is None else max_steps
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -43,17 +46,21 @@ class DoorKeyEnv(MiniGridEnv):
 
 
         self.mission = "use the key to open the door and then get to the goal"
         self.mission = "use the key to open the door and then get to the goal"
 
 
+
 class DoorKeyEnv5x5(DoorKeyEnv):
 class DoorKeyEnv5x5(DoorKeyEnv):
-    def __init__(self, max_steps=None):
-        super().__init__(size=5, max_steps=max_steps)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, **kwargs)
+
 
 
 class DoorKeyEnv6x6(DoorKeyEnv):
 class DoorKeyEnv6x6(DoorKeyEnv):
-    def __init__(self, max_steps=None):
-        super().__init__(size=6, max_steps=max_steps)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, **kwargs)
+
 
 
 class DoorKeyEnv16x16(DoorKeyEnv):
 class DoorKeyEnv16x16(DoorKeyEnv):
-    def __init__(self, max_steps=None):
-        super().__init__(size=16, max_steps=max_steps)
+    def __init__(self, **kwargs):
+        super().__init__(size=16, **kwargs)
+
 
 
 register(
 register(
     id='MiniGrid-DoorKey-5x5-v0',
     id='MiniGrid-DoorKey-5x5-v0',

+ 13 - 11
gym_minigrid/envs/dynamicobstacles.py

@@ -12,7 +12,8 @@ class DynamicObstaclesEnv(MiniGridEnv):
             size=8,
             size=8,
             agent_start_pos=(1, 1),
             agent_start_pos=(1, 1),
             agent_start_dir=0,
             agent_start_dir=0,
-            n_obstacles=4
+            n_obstacles=4,
+            **kwargs
     ):
     ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
         self.agent_start_dir = agent_start_dir
@@ -27,6 +28,7 @@ class DynamicObstaclesEnv(MiniGridEnv):
             max_steps=4 * size * size,
             max_steps=4 * size * size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=True,
             see_through_walls=True,
+            **kwargs
         )
         )
         # Only 3 actions are permitted: left, right, forward
         # Only 3 actions are permitted: left, right, forward
         self.action_space = spaces.Discrete(self.actions.forward + 1)
         self.action_space = spaces.Discrete(self.actions.forward + 1)
@@ -89,24 +91,24 @@ class DynamicObstaclesEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class DynamicObstaclesEnv5x5(DynamicObstaclesEnv):
 class DynamicObstaclesEnv5x5(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=5, n_obstacles=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, n_obstacles=2, **kwargs)
 
 
 class DynamicObstaclesRandomEnv5x5(DynamicObstaclesEnv):
 class DynamicObstaclesRandomEnv5x5(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=5, agent_start_pos=None, n_obstacles=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, agent_start_pos=None, n_obstacles=2, **kwargs)
 
 
 class DynamicObstaclesEnv6x6(DynamicObstaclesEnv):
 class DynamicObstaclesEnv6x6(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=6, n_obstacles=3)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, n_obstacles=3, **kwargs)
 
 
 class DynamicObstaclesRandomEnv6x6(DynamicObstaclesEnv):
 class DynamicObstaclesRandomEnv6x6(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=6, agent_start_pos=None, n_obstacles=3)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, agent_start_pos=None, n_obstacles=3, **kwargs)
 
 
 class DynamicObstaclesEnv16x16(DynamicObstaclesEnv):
 class DynamicObstaclesEnv16x16(DynamicObstaclesEnv):
-    def __init__(self):
-        super().__init__(size=16, n_obstacles=8)
+    def __init__(self, **kwargs):
+        super().__init__(size=16, n_obstacles=8, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-Dynamic-Obstacles-5x5-v0',
     id='MiniGrid-Dynamic-Obstacles-5x5-v0',

+ 15 - 6
gym_minigrid/envs/empty.py

@@ -1,6 +1,7 @@
 from gym_minigrid.minigrid import *
 from gym_minigrid.minigrid import *
 from gym_minigrid.register import register
 from gym_minigrid.register import register
 
 
+
 class EmptyEnv(MiniGridEnv):
 class EmptyEnv(MiniGridEnv):
     """
     """
     Empty grid environment, no obstacles, sparse reward
     Empty grid environment, no obstacles, sparse reward
@@ -9,8 +10,9 @@ class EmptyEnv(MiniGridEnv):
     def __init__(
     def __init__(
         self,
         self,
         size=8,
         size=8,
-        agent_start_pos=(1,1),
+        agent_start_pos=(1, 1),
         agent_start_dir=0,
         agent_start_dir=0,
+        **kwargs
     ):
     ):
         self.agent_start_pos = agent_start_pos
         self.agent_start_pos = agent_start_pos
         self.agent_start_dir = agent_start_dir
         self.agent_start_dir = agent_start_dir
@@ -19,7 +21,8 @@ class EmptyEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=4*size*size,
             max_steps=4*size*size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -41,26 +44,32 @@ class EmptyEnv(MiniGridEnv):
 
 
         self.mission = "get to the green goal square"
         self.mission = "get to the green goal square"
 
 
+
 class EmptyEnv5x5(EmptyEnv):
 class EmptyEnv5x5(EmptyEnv):
     def __init__(self, **kwargs):
     def __init__(self, **kwargs):
         super().__init__(size=5, **kwargs)
         super().__init__(size=5, **kwargs)
 
 
+
 class EmptyRandomEnv5x5(EmptyEnv):
 class EmptyRandomEnv5x5(EmptyEnv):
-    def __init__(self):
-        super().__init__(size=5, agent_start_pos=None)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, agent_start_pos=None, **kwargs)
+
 
 
 class EmptyEnv6x6(EmptyEnv):
 class EmptyEnv6x6(EmptyEnv):
     def __init__(self, **kwargs):
     def __init__(self, **kwargs):
         super().__init__(size=6, **kwargs)
         super().__init__(size=6, **kwargs)
 
 
+
 class EmptyRandomEnv6x6(EmptyEnv):
 class EmptyRandomEnv6x6(EmptyEnv):
-    def __init__(self):
-        super().__init__(size=6, agent_start_pos=None)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, agent_start_pos=None, **kwargs)
+
 
 
 class EmptyEnv16x16(EmptyEnv):
 class EmptyEnv16x16(EmptyEnv):
     def __init__(self, **kwargs):
     def __init__(self, **kwargs):
         super().__init__(size=16, **kwargs)
         super().__init__(size=16, **kwargs)
 
 
+
 register(
 register(
     id='MiniGrid-Empty-5x5-v0',
     id='MiniGrid-Empty-5x5-v0',
     entry_point='gym_minigrid.envs:EmptyEnv5x5'
     entry_point='gym_minigrid.envs:EmptyEnv5x5'

+ 8 - 6
gym_minigrid/envs/fetch.py

@@ -10,7 +10,8 @@ class FetchEnv(MiniGridEnv):
     def __init__(
     def __init__(
         self,
         self,
         size=8,
         size=8,
-        numObjs=3
+        numObjs=3,
+        **kwargs
     ):
     ):
         self.numObjs = numObjs
         self.numObjs = numObjs
 
 
@@ -18,7 +19,8 @@ class FetchEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=5*size**2,
             max_steps=5*size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -86,12 +88,12 @@ class FetchEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class FetchEnv5x5N2(FetchEnv):
 class FetchEnv5x5N2(FetchEnv):
-    def __init__(self):
-        super().__init__(size=5, numObjs=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, numObjs=2, **kwargs)
 
 
 class FetchEnv6x6N2(FetchEnv):
 class FetchEnv6x6N2(FetchEnv):
-    def __init__(self):
-        super().__init__(size=6, numObjs=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, numObjs=2, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-Fetch-5x5-N2-v0',
     id='MiniGrid-Fetch-5x5-N2-v0',

+ 3 - 3
gym_minigrid/envs/fourrooms.py

@@ -11,10 +11,10 @@ class FourRoomsEnv(MiniGridEnv):
     Can specify agent and goal position; if not, they are set at random.
     Can specify agent and goal position; if not, they are set at random.
     """
     """
 
 
-    def __init__(self, agent_pos=None, goal_pos=None):
+    def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
         self._agent_default_pos = agent_pos
         self._agent_default_pos = agent_pos
         self._goal_default_pos = goal_pos
         self._goal_default_pos = goal_pos
-        super().__init__(grid_size=19, max_steps=100)
+        super().__init__(grid_size=19, max_steps=100, **kwargs)
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
         # Create the grid
         # Create the grid
@@ -66,7 +66,7 @@ class FourRoomsEnv(MiniGridEnv):
         else:
         else:
             self.place_obj(Goal())
             self.place_obj(Goal())
 
 
-        self.mission = 'Reach the goal'
+        self.mission = 'reach the goal'
 
 
     def step(self, action):
     def step(self, action):
         obs, reward, done, info = MiniGridEnv.step(self, action)
         obs, reward, done, info = MiniGridEnv.step(self, action)

+ 8 - 6
gym_minigrid/envs/gotodoor.py

@@ -9,7 +9,8 @@ class GoToDoorEnv(MiniGridEnv):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        size=5
+        size=5,
+        **kwargs
     ):
     ):
         assert size >= 5
         assert size >= 5
 
 
@@ -17,7 +18,8 @@ class GoToDoorEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=5*size**2,
             max_steps=5*size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -81,12 +83,12 @@ class GoToDoorEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class GoToDoor8x8Env(GoToDoorEnv):
 class GoToDoor8x8Env(GoToDoorEnv):
-    def __init__(self):
-        super().__init__(size=8)
+    def __init__(self, **kwargs):
+        super().__init__(size=8, **kwargs)
 
 
 class GoToDoor6x6Env(GoToDoorEnv):
 class GoToDoor6x6Env(GoToDoorEnv):
-    def __init__(self):
-        super().__init__(size=6)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-GoToDoor-5x5-v0',
     id='MiniGrid-GoToDoor-5x5-v0',

+ 6 - 4
gym_minigrid/envs/gotoobject.py

@@ -10,7 +10,8 @@ class GoToObjectEnv(MiniGridEnv):
     def __init__(
     def __init__(
         self,
         self,
         size=6,
         size=6,
-        numObjs=2
+        numObjs=2,
+        **kwargs
     ):
     ):
         self.numObjs = numObjs
         self.numObjs = numObjs
 
 
@@ -18,7 +19,8 @@ class GoToObjectEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=5*size**2,
             max_steps=5*size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -84,8 +86,8 @@ class GoToObjectEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class GotoEnv8x8N2(GoToObjectEnv):
 class GotoEnv8x8N2(GoToObjectEnv):
-    def __init__(self):
-        super().__init__(size=8, numObjs=2)
+    def __init__(self, **kwargs):
+        super().__init__(size=8, numObjs=2, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-GoToObject-6x6-N2-v0',
     id='MiniGrid-GoToObject-6x6-N2-v0',

+ 14 - 14
gym_minigrid/envs/keycorridor.py

@@ -12,7 +12,7 @@ class KeyCorridor(RoomGrid):
         num_rows=3,
         num_rows=3,
         obj_type="ball",
         obj_type="ball",
         room_size=6,
         room_size=6,
-        seed=None
+        **kwargs
     ):
     ):
         self.obj_type = obj_type
         self.obj_type = obj_type
 
 
@@ -20,7 +20,7 @@ class KeyCorridor(RoomGrid):
             room_size=room_size,
             room_size=room_size,
             num_rows=num_rows,
             num_rows=num_rows,
             max_steps=30*room_size**2,
             max_steps=30*room_size**2,
-            seed=seed,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -59,51 +59,51 @@ class KeyCorridor(RoomGrid):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class KeyCorridorS3R1(KeyCorridor):
 class KeyCorridorS3R1(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             room_size=3,
             room_size=3,
             num_rows=1,
             num_rows=1,
-            seed=seed
+            **kwargs
         )
         )
 
 
 class KeyCorridorS3R2(KeyCorridor):
 class KeyCorridorS3R2(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             room_size=3,
             room_size=3,
             num_rows=2,
             num_rows=2,
-            seed=seed
+            **kwargs
         )
         )
 
 
 class KeyCorridorS3R3(KeyCorridor):
 class KeyCorridorS3R3(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             room_size=3,
             room_size=3,
             num_rows=3,
             num_rows=3,
-            seed=seed
+            **kwargs
         )
         )
 
 
 class KeyCorridorS4R3(KeyCorridor):
 class KeyCorridorS4R3(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self,  **kwargs):
         super().__init__(
         super().__init__(
             room_size=4,
             room_size=4,
             num_rows=3,
             num_rows=3,
-            seed=seed
+            **kwargs
         )
         )
 
 
 class KeyCorridorS5R3(KeyCorridor):
 class KeyCorridorS5R3(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             room_size=5,
             room_size=5,
             num_rows=3,
             num_rows=3,
-            seed=seed
+            **kwargs
         )
         )
 
 
 class KeyCorridorS6R3(KeyCorridor):
 class KeyCorridorS6R3(KeyCorridor):
-    def __init__(self, seed=None):
+    def __init__(self,  **kwargs):
         super().__init__(
         super().__init__(
             room_size=6,
             room_size=6,
             num_rows=3,
             num_rows=3,
-            seed=seed
+            **kwargs
         )
         )
 
 
 register(
 register(

+ 8 - 8
gym_minigrid/envs/lavagap.py

@@ -7,14 +7,14 @@ class LavaGapEnv(MiniGridEnv):
     This environment is similar to LavaCrossing but simpler in structure.
     This environment is similar to LavaCrossing but simpler in structure.
     """
     """
 
 
-    def __init__(self, size, obstacle_type=Lava, seed=None):
+    def __init__(self, size, obstacle_type=Lava, **kwargs):
         self.obstacle_type = obstacle_type
         self.obstacle_type = obstacle_type
         super().__init__(
         super().__init__(
             grid_size=size,
             grid_size=size,
             max_steps=4*size*size,
             max_steps=4*size*size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
             see_through_walls=False,
             see_through_walls=False,
-            seed=None
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -53,16 +53,16 @@ class LavaGapEnv(MiniGridEnv):
         )
         )
 
 
 class LavaGapS5Env(LavaGapEnv):
 class LavaGapS5Env(LavaGapEnv):
-    def __init__(self):
-        super().__init__(size=5)
+    def __init__(self, **kwargs):
+        super().__init__(size=5, **kwargs)
 
 
 class LavaGapS6Env(LavaGapEnv):
 class LavaGapS6Env(LavaGapEnv):
-    def __init__(self):
-        super().__init__(size=6)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, **kwargs)
 
 
 class LavaGapS7Env(LavaGapEnv):
 class LavaGapS7Env(LavaGapEnv):
-    def __init__(self):
-        super().__init__(size=7)
+    def __init__(self, **kwargs):
+        super().__init__(size=7, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-LavaGapS5-v0',
     id='MiniGrid-LavaGapS5-v0',

+ 3 - 2
gym_minigrid/envs/lockedroom.py

@@ -30,9 +30,10 @@ class LockedRoom(MiniGridEnv):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        size=19
+        size=19, 
+        **kwargs
     ):
     ):
-        super().__init__(grid_size=size, max_steps=10*size)
+        super().__init__(grid_size=size, max_steps=10*size, **kwargs)
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
         # Create the grid
         # Create the grid

+ 16 - 16
gym_minigrid/envs/memory.py

@@ -13,17 +13,17 @@ class MemoryEnv(MiniGridEnv):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        seed,
         size=8,
         size=8,
-        random_length=False,
+        random_length=False, 
+        **kwargs
     ):
     ):
         self.random_length = random_length
         self.random_length = random_length
         super().__init__(
         super().__init__(
-            seed=seed,
             grid_size=size,
             grid_size=size,
             max_steps=5*size**2,
             max_steps=5*size**2,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=False,
+            see_through_walls=False, 
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -100,8 +100,8 @@ class MemoryEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class MemoryS17Random(MemoryEnv):
 class MemoryS17Random(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=17, random_length=True)
+    def __init__(self, **kwargs):
+        super().__init__(size=17, random_length=True, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS17Random-v0',
     id='MiniGrid-MemoryS17Random-v0',
@@ -109,8 +109,8 @@ register(
 )
 )
 
 
 class MemoryS13Random(MemoryEnv):
 class MemoryS13Random(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=13, random_length=True)
+    def __init__(self, **kwargs):
+        super().__init__(size=13, random_length=True, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS13Random-v0',
     id='MiniGrid-MemoryS13Random-v0',
@@ -118,8 +118,8 @@ register(
 )
 )
 
 
 class MemoryS13(MemoryEnv):
 class MemoryS13(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=13)
+    def __init__(self, **kwargs):
+        super().__init__(size=13, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS13-v0',
     id='MiniGrid-MemoryS13-v0',
@@ -127,8 +127,8 @@ register(
 )
 )
 
 
 class MemoryS11(MemoryEnv):
 class MemoryS11(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=11)
+    def __init__(self, **kwargs):
+        super().__init__(size=11, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS11-v0',
     id='MiniGrid-MemoryS11-v0',
@@ -136,8 +136,8 @@ register(
 )
 )
 
 
 class MemoryS9(MemoryEnv):
 class MemoryS9(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=9)
+    def __init__(self, **kwargs):
+        super().__init__(size=9, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS9-v0',
     id='MiniGrid-MemoryS9-v0',
@@ -145,8 +145,8 @@ register(
 )
 )
 
 
 class MemoryS7(MemoryEnv):
 class MemoryS7(MemoryEnv):
-    def __init__(self, seed=None):
-        super().__init__(seed=seed, size=7)
+    def __init__(self, **kwargs):
+        super().__init__(size=7, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-MemoryS7-v0',
     id='MiniGrid-MemoryS7-v0',

+ 13 - 8
gym_minigrid/envs/multiroom.py

@@ -21,7 +21,8 @@ class MultiRoomEnv(MiniGridEnv):
     def __init__(self,
     def __init__(self,
         minNumRooms,
         minNumRooms,
         maxNumRooms,
         maxNumRooms,
-        maxRoomSize=10
+        maxRoomSize=10,
+        **kwargs
     ):
     ):
         assert minNumRooms > 0
         assert minNumRooms > 0
         assert maxNumRooms >= minNumRooms
         assert maxNumRooms >= minNumRooms
@@ -35,7 +36,8 @@ class MultiRoomEnv(MiniGridEnv):
 
 
         super(MultiRoomEnv, self).__init__(
         super(MultiRoomEnv, self).__init__(
             grid_size=25,
             grid_size=25,
-            max_steps=self.maxNumRooms * 20
+            max_steps=self.maxNumRooms * 20,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -237,26 +239,29 @@ class MultiRoomEnv(MiniGridEnv):
         return True
         return True
 
 
 class MultiRoomEnvN2S4(MultiRoomEnv):
 class MultiRoomEnvN2S4(MultiRoomEnv):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             minNumRooms=2,
             minNumRooms=2,
             maxNumRooms=2,
             maxNumRooms=2,
-            maxRoomSize=4
+            maxRoomSize=4,
+            **kwargs
         )
         )
 
 
 class MultiRoomEnvN4S5(MultiRoomEnv):
 class MultiRoomEnvN4S5(MultiRoomEnv):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             minNumRooms=4,
             minNumRooms=4,
             maxNumRooms=4,
             maxNumRooms=4,
-            maxRoomSize=5
+            maxRoomSize=5,
+            **kwargs
         )
         )
 
 
 class MultiRoomEnvN6(MultiRoomEnv):
 class MultiRoomEnvN6(MultiRoomEnv):
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
         super().__init__(
             minNumRooms=6,
             minNumRooms=6,
-            maxNumRooms=6
+            maxNumRooms=6,
+            **kwargs
         )
         )
 
 
 register(
 register(

+ 21 - 21
gym_minigrid/envs/obstructedmaze.py

@@ -12,7 +12,7 @@ class ObstructedMazeEnv(RoomGrid):
         num_rows,
         num_rows,
         num_cols,
         num_cols,
         num_rooms_visited,
         num_rooms_visited,
-        seed=None
+        **kwargs
     ):
     ):
         room_size = 6
         room_size = 6
         max_steps = 4*num_rooms_visited*room_size**2
         max_steps = 4*num_rooms_visited*room_size**2
@@ -21,8 +21,8 @@ class ObstructedMazeEnv(RoomGrid):
             room_size=room_size,
             room_size=room_size,
             num_rows=num_rows,
             num_rows=num_rows,
             num_cols=num_cols,
             num_cols=num_cols,
-            max_steps=max_steps,
-            seed=seed
+            max_steps=max_steps,        
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -79,7 +79,7 @@ class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
     rooms. Doors are obstructed by a ball and keys are hidden in boxes.
     rooms. Doors are obstructed by a ball and keys are hidden in boxes.
     """
     """
 
 
-    def __init__(self, key_in_box=True, blocked=True, seed=None):
+    def __init__(self, key_in_box=True, blocked=True, **kwargs):
         self.key_in_box = key_in_box
         self.key_in_box = key_in_box
         self.blocked = blocked
         self.blocked = blocked
 
 
@@ -87,7 +87,7 @@ class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             num_rooms_visited=2,
             num_rooms_visited=2,
-            seed=seed
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -102,12 +102,12 @@ class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
         self.place_agent(0, 0)
         self.place_agent(0, 0)
 
 
 class ObstructedMaze_1Dl(ObstructedMaze_1Dlhb):
 class ObstructedMaze_1Dl(ObstructedMaze_1Dlhb):
-    def __init__(self, seed=None):
-        super().__init__(False, False, seed)
+    def __init__(self, **kwargs):
+        super().__init__(False, False, **kwargs)
 
 
 class ObstructedMaze_1Dlh(ObstructedMaze_1Dlhb):
 class ObstructedMaze_1Dlh(ObstructedMaze_1Dlhb):
-    def __init__(self, seed=None):
-        super().__init__(True, False, seed)
+    def __init__(self, **kwargs):
+        super().__init__(True, False, **kwargs)
 
 
 class ObstructedMaze_Full(ObstructedMazeEnv):
 class ObstructedMaze_Full(ObstructedMazeEnv):
     """
     """
@@ -117,7 +117,7 @@ class ObstructedMaze_Full(ObstructedMazeEnv):
     """
     """
 
 
     def __init__(self, agent_room=(1, 1), key_in_box=True, blocked=True,
     def __init__(self, agent_room=(1, 1), key_in_box=True, blocked=True,
-                 num_quarters=4, num_rooms_visited=25, seed=None):
+                 num_quarters=4, num_rooms_visited=25, **kwargs):
         self.agent_room = agent_room
         self.agent_room = agent_room
         self.key_in_box = key_in_box
         self.key_in_box = key_in_box
         self.blocked = blocked
         self.blocked = blocked
@@ -127,7 +127,7 @@ class ObstructedMaze_Full(ObstructedMazeEnv):
             num_rows=3,
             num_rows=3,
             num_cols=3,
             num_cols=3,
             num_rooms_visited=num_rooms_visited,
             num_rooms_visited=num_rooms_visited,
-            seed=seed
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -158,25 +158,25 @@ class ObstructedMaze_Full(ObstructedMazeEnv):
         self.place_agent(*self.agent_room)
         self.place_agent(*self.agent_room)
 
 
 class ObstructedMaze_2Dl(ObstructedMaze_Full):
 class ObstructedMaze_2Dl(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((2, 1), False, False, 1, 4, seed)
+    def __init__(self, **kwargs):
+        super().__init__((2, 1), False, False, 1, 4, **kwargs)
 
 
 class ObstructedMaze_2Dlh(ObstructedMaze_Full):
 class ObstructedMaze_2Dlh(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((2, 1), True, False, 1, 4, seed)
+    def __init__(self, **kwargs):
+        super().__init__((2, 1), True, False, 1, 4, **kwargs)
 
 
 
 
 class ObstructedMaze_2Dlhb(ObstructedMaze_Full):
 class ObstructedMaze_2Dlhb(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((2, 1), True, True, 1, 4, seed)
+    def __init__(self, **kwargs):
+        super().__init__((2, 1), True, True, 1, 4, **kwargs)
 
 
 class ObstructedMaze_1Q(ObstructedMaze_Full):
 class ObstructedMaze_1Q(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((1, 1), True, True, 1, 5, seed)
+    def __init__(self, **kwargs):
+        super().__init__((1, 1), True, True, 1, 5, **kwargs)
 
 
 class ObstructedMaze_2Q(ObstructedMaze_Full):
 class ObstructedMaze_2Q(ObstructedMaze_Full):
-    def __init__(self, seed=None):
-        super().__init__((1, 1), True, True, 2, 11, seed)
+    def __init__(self, **kwargs):
+        super().__init__((1, 1), True, True, 2, 11, **kwargs)
 
 
 register(
 register(
     id="MiniGrid-ObstructedMaze-1Dl-v0",
     id="MiniGrid-ObstructedMaze-1Dl-v0",

+ 2 - 2
gym_minigrid/envs/playground_v0.py

@@ -7,8 +7,8 @@ class PlaygroundV0(MiniGridEnv):
     This environment has no specific goals or rewards.
     This environment has no specific goals or rewards.
     """
     """
 
 
-    def __init__(self):
-        super().__init__(grid_size=19, max_steps=100)
+    def __init__(self, **kwargs):
+        super().__init__(grid_size=19, max_steps=100, **kwargs)
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
         # Create the grid
         # Create the grid

+ 6 - 4
gym_minigrid/envs/putnear.py

@@ -10,7 +10,8 @@ class PutNearEnv(MiniGridEnv):
     def __init__(
     def __init__(
         self,
         self,
         size=6,
         size=6,
-        numObjs=2
+        numObjs=2, 
+        **kwargs
     ):
     ):
         self.numObjs = numObjs
         self.numObjs = numObjs
 
 
@@ -18,7 +19,8 @@ class PutNearEnv(MiniGridEnv):
             grid_size=size,
             grid_size=size,
             max_steps=5*size,
             max_steps=5*size,
             # Set this to True for maximum speed
             # Set this to True for maximum speed
-            see_through_walls=True
+            see_through_walls=True, 
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -112,8 +114,8 @@ class PutNearEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class PutNear8x8N3(PutNearEnv):
 class PutNear8x8N3(PutNearEnv):
-    def __init__(self):
-        super().__init__(size=8, numObjs=3)
+    def __init__(self, **kwargs):
+        super().__init__(size=8, numObjs=3, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-PutNear-6x6-N2-v0',
     id='MiniGrid-PutNear-6x6-N2-v0',

+ 5 - 4
gym_minigrid/envs/redbluedoors.py

@@ -8,13 +8,14 @@ class RedBlueDoorEnv(MiniGridEnv):
     obtain a reward.
     obtain a reward.
     """
     """
 
 
-    def __init__(self, size=8):
+    def __init__(self, size=8, **kwargs):
         self.size = size
         self.size = size
 
 
         super().__init__(
         super().__init__(
             width=2*size,
             width=2*size,
             height=size,
             height=size,
-            max_steps=20*size*size
+            max_steps=20*size*size,
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):
@@ -66,8 +67,8 @@ class RedBlueDoorEnv(MiniGridEnv):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 class RedBlueDoorEnv6x6(RedBlueDoorEnv):
 class RedBlueDoorEnv6x6(RedBlueDoorEnv):
-    def __init__(self):
-        super().__init__(size=6)
+    def __init__(self, **kwargs):
+        super().__init__(size=6, **kwargs)
 
 
 register(
 register(
     id='MiniGrid-RedBlueDoors-6x6-v0',
     id='MiniGrid-RedBlueDoors-6x6-v0',

+ 2 - 2
gym_minigrid/envs/unlock.py

@@ -7,14 +7,14 @@ class Unlock(RoomGrid):
     Unlock a door
     Unlock a door
     """
     """
 
 
-    def __init__(self, seed=None):
+    def __init__(self, **kwargs):
         room_size = 6
         room_size = 6
         super().__init__(
         super().__init__(
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             room_size=room_size,
             room_size=room_size,
             max_steps=8*room_size**2,
             max_steps=8*room_size**2,
-            seed=seed
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):

+ 2 - 2
gym_minigrid/envs/unlockpickup.py

@@ -7,14 +7,14 @@ class UnlockPickup(RoomGrid):
     Unlock a door, then pick up a box in another room
     Unlock a door, then pick up a box in another room
     """
     """
 
 
-    def __init__(self, seed=None):
+    def __init__(self, **kwargs):
         room_size = 6
         room_size = 6
         super().__init__(
         super().__init__(
             num_rows=1,
             num_rows=1,
             num_cols=2,
             num_cols=2,
             room_size=room_size,
             room_size=room_size,
             max_steps=8*room_size**2,
             max_steps=8*room_size**2,
-            seed=seed
+            **kwargs
         )
         )
 
 
     def _gen_grid(self, width, height):
     def _gen_grid(self, width, height):

+ 116 - 85
gym_minigrid/minigrid.py

@@ -1,10 +1,10 @@
 import math
 import math
 import hashlib
 import hashlib
+import string
 import gym
 import gym
 from enum import IntEnum
 from enum import IntEnum
 import numpy as np
 import numpy as np
 from gym import error, spaces, utils
 from gym import error, spaces, utils
-from gym.utils import seeding
 from .rendering import *
 from .rendering import *
 
 
 # Size in pixels of a tile in the full-scale human view
 # Size in pixels of a tile in the full-scale human view
@@ -12,48 +12,48 @@ TILE_PIXELS = 32
 
 
 # Map of color names to RGB values
 # Map of color names to RGB values
 COLORS = {
 COLORS = {
-    'red'   : np.array([255, 0, 0]),
-    'green' : np.array([0, 255, 0]),
-    'blue'  : np.array([0, 0, 255]),
+    'red': np.array([255, 0, 0]),
+    'green': np.array([0, 255, 0]),
+    'blue': np.array([0, 0, 255]),
     'purple': np.array([112, 39, 195]),
     'purple': np.array([112, 39, 195]),
     'yellow': np.array([255, 255, 0]),
     'yellow': np.array([255, 255, 0]),
-    'grey'  : np.array([100, 100, 100])
+    'grey': np.array([100, 100, 100])
 }
 }
 
 
 COLOR_NAMES = sorted(list(COLORS.keys()))
 COLOR_NAMES = sorted(list(COLORS.keys()))
 
 
 # Used to map colors to integers
 # Used to map colors to integers
 COLOR_TO_IDX = {
 COLOR_TO_IDX = {
-    'red'   : 0,
-    'green' : 1,
-    'blue'  : 2,
+    'red': 0,
+    'green': 1,
+    'blue': 2,
     'purple': 3,
     'purple': 3,
     'yellow': 4,
     'yellow': 4,
-    'grey'  : 5
+    'grey': 5
 }
 }
 
 
 IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
 IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
 
 
 # Map of object type to integers
 # Map of object type to integers
 OBJECT_TO_IDX = {
 OBJECT_TO_IDX = {
-    'unseen'        : 0,
-    'empty'         : 1,
-    'wall'          : 2,
-    'floor'         : 3,
-    'door'          : 4,
-    'key'           : 5,
-    'ball'          : 6,
-    'box'           : 7,
-    'goal'          : 8,
-    'lava'          : 9,
-    'agent'         : 10,
+    'unseen': 0,
+    'empty': 1,
+    'wall': 2,
+    'floor': 3,
+    'door': 4,
+    'key': 5,
+    'ball': 6,
+    'box': 7,
+    'goal': 8,
+    'lava': 9,
+    'agent': 10,
 }
 }
 
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
 
 
 # Map of state names to integers
 # Map of state names to integers
 STATE_TO_IDX = {
 STATE_TO_IDX = {
-    'open'  : 0,
+    'open': 0,
     'closed': 1,
     'closed': 1,
     'locked': 2,
     'locked': 2,
 }
 }
@@ -70,6 +70,7 @@ DIR_TO_VEC = [
     np.array((0, -1)),
     np.array((0, -1)),
 ]
 ]
 
 
+
 class WorldObj:
 class WorldObj:
     """
     """
     Base class for grid world objects
     Base class for grid world objects
@@ -151,6 +152,7 @@ class WorldObj:
         """Draw this object with the given renderer"""
         """Draw this object with the given renderer"""
         raise NotImplementedError
         raise NotImplementedError
 
 
+
 class Goal(WorldObj):
 class Goal(WorldObj):
     def __init__(self):
     def __init__(self):
         super().__init__('goal', 'green')
         super().__init__('goal', 'green')
@@ -161,6 +163,7 @@ class Goal(WorldObj):
     def render(self, img):
     def render(self, img):
         fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
         fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
 
 
+
 class Floor(WorldObj):
 class Floor(WorldObj):
     """
     """
     Colored floor tile the agent can walk over
     Colored floor tile the agent can walk over
@@ -195,10 +198,15 @@ class Lava(WorldObj):
         for i in range(3):
         for i in range(3):
             ylo = 0.3 + 0.2 * i
             ylo = 0.3 + 0.2 * i
             yhi = 0.4 + 0.2 * i
             yhi = 0.4 + 0.2 * i
-            fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
-            fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
-            fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
-            fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
+            fill_coords(img, point_in_line(
+                0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
+            fill_coords(img, point_in_line(
+                0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
+            fill_coords(img, point_in_line(
+                0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
+            fill_coords(img, point_in_line(
+                0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
+
 
 
 class Wall(WorldObj):
 class Wall(WorldObj):
     def __init__(self, color='grey'):
     def __init__(self, color='grey'):
@@ -210,6 +218,7 @@ class Wall(WorldObj):
     def render(self, img):
     def render(self, img):
         fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
         fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
 
 
+
 class Door(WorldObj):
 class Door(WorldObj):
     def __init__(self, color, is_open=False, is_locked=False):
     def __init__(self, color, is_open=False, is_locked=False):
         super().__init__('door', color)
         super().__init__('door', color)
@@ -253,25 +262,27 @@ class Door(WorldObj):
 
 
         if self.is_open:
         if self.is_open:
             fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
             fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
-            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
+            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
             return
             return
 
 
         # Door frame and door
         # Door frame and door
         if self.is_locked:
         if self.is_locked:
             fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
             fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
-            fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
+            fill_coords(img, point_in_rect(
+                0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
 
 
             # Draw key slot
             # Draw key slot
             fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
             fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
         else:
         else:
             fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
             fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
-            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
+            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
             fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
             fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
-            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))
+            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
 
 
             # Draw door handle
             # Draw door handle
             fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
             fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
 
 
+
 class Key(WorldObj):
 class Key(WorldObj):
     def __init__(self, color='blue'):
     def __init__(self, color='blue'):
         super(Key, self).__init__('key', color)
         super(Key, self).__init__('key', color)
@@ -291,7 +302,8 @@ class Key(WorldObj):
 
 
         # Ring
         # Ring
         fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
         fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
-        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
+        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
+
 
 
 class Ball(WorldObj):
 class Ball(WorldObj):
     def __init__(self, color='blue'):
     def __init__(self, color='blue'):
@@ -303,6 +315,7 @@ class Ball(WorldObj):
     def render(self, img):
     def render(self, img):
         fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
         fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
 
 
+
 class Box(WorldObj):
 class Box(WorldObj):
     def __init__(self, color, contains=None):
     def __init__(self, color, contains=None):
         super(Box, self).__init__('box', color)
         super(Box, self).__init__('box', color)
@@ -316,7 +329,7 @@ class Box(WorldObj):
 
 
         # Outline
         # Outline
         fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
         fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
-        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))
+        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
 
 
         # Horizontal slit
         # Horizontal slit
         fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
         fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
@@ -326,6 +339,7 @@ class Box(WorldObj):
         env.grid.set(*pos, self.contains)
         env.grid.set(*pos, self.contains)
         return True
         return True
 
 
+
 class Grid:
 class Grid:
     """
     """
     Represent a grid and operations on it
     Represent a grid and operations on it
@@ -359,7 +373,7 @@ class Grid:
         return False
         return False
 
 
     def __eq__(self, other):
     def __eq__(self, other):
-        grid1  = self.encode()
+        grid1 = self.encode()
         grid2 = other.encode()
         grid2 = other.encode()
         return np.array_equal(grid2, grid1)
         return np.array_equal(grid2, grid1)
 
 
@@ -454,7 +468,8 @@ class Grid:
         if key in cls.tile_cache:
         if key in cls.tile_cache:
             return cls.tile_cache[key]
             return cls.tile_cache[key]
 
 
-        img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
+        img = np.zeros(shape=(tile_size * subdivs,
+                       tile_size * subdivs, 3), dtype=np.uint8)
 
 
         # Draw the grid lines (top and left edges)
         # Draw the grid lines (top and left edges)
         fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
         fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
@@ -472,7 +487,8 @@ class Grid:
             )
             )
 
 
             # Rotate the agent based on its direction
             # Rotate the agent based on its direction
-            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
+            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5,
+                               theta=0.5*math.pi*agent_dir)
             fill_coords(img, tri_fn, (255, 0, 0))
             fill_coords(img, tri_fn, (255, 0, 0))
 
 
         # Highlight the cell if needed
         # Highlight the cell if needed
@@ -501,7 +517,8 @@ class Grid:
         """
         """
 
 
         if highlight_mask is None:
         if highlight_mask is None:
-            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
+            highlight_mask = np.zeros(
+                shape=(self.width, self.height), dtype=bool)
 
 
         # Compute the total grid size
         # Compute the total grid size
         width_px = self.width * tile_size
         width_px = self.width * tile_size
@@ -615,14 +632,18 @@ class Grid:
 
 
         return mask
         return mask
 
 
+
 class MiniGridEnv(gym.Env):
 class MiniGridEnv(gym.Env):
     """
     """
     2D grid world game environment
     2D grid world game environment
     """
     """
 
 
     metadata = {
     metadata = {
+        # Deprecated: use 'render_modes' instead
         'render.modes': ['human', 'rgb_array'],
         'render.modes': ['human', 'rgb_array'],
-        'video.frames_per_second' : 10
+        'video.frames_per_second': 10,  # Deprecated: use 'render_fps' instead
+        'render_modes': ['human', 'rgb_array'],
+        'render_fps': 10
     }
     }
 
 
     # Enumeration of possible actions
     # Enumeration of possible actions
@@ -649,8 +670,9 @@ class MiniGridEnv(gym.Env):
         height=None,
         height=None,
         max_steps=100,
         max_steps=100,
         see_through_walls=False,
         see_through_walls=False,
-        seed=1337,
-        agent_view_size=7
+        agent_view_size=7,
+        render_mode=None,
+        **kwargs
     ):
     ):
         # Can't set both grid_size and width/height
         # Can't set both grid_size and width/height
         if grid_size:
         if grid_size:
@@ -678,9 +700,16 @@ class MiniGridEnv(gym.Env):
             dtype='uint8'
             dtype='uint8'
         )
         )
         self.observation_space = spaces.Dict({
         self.observation_space = spaces.Dict({
-            'image': self.observation_space
+            'image': self.observation_space,
+            'direction': spaces.Discrete(4),
+            'mission': spaces.Text(max_length=200,
+                                   charset=string.ascii_letters + string.digits + ' .,!-'
+                                   )
         })
         })
 
 
+        # render mode
+        self.render_mode = render_mode
+
         # Range of possible rewards
         # Range of possible rewards
         self.reward_range = (0, 1)
         self.reward_range = (0, 1)
 
 
@@ -697,20 +726,16 @@ class MiniGridEnv(gym.Env):
         self.agent_pos = None
         self.agent_pos = None
         self.agent_dir = None
         self.agent_dir = None
 
 
-        # Initialize the RNG
-        self.seed(seed=seed)
-
         # Initialize the state
         # Initialize the state
         self.reset()
         self.reset()
 
 
-    def reset(self):
+    def reset(self, *, seed=None, return_info=False, options=None):
+        super().reset(seed=seed)
         # Current position and direction of the agent
         # Current position and direction of the agent
         self.agent_pos = None
         self.agent_pos = None
         self.agent_dir = None
         self.agent_dir = None
 
 
         # Generate a new random grid at the start of each episode
         # Generate a new random grid at the start of each episode
-        # To keep the same grid for each episode, call env.seed() with
-        # the same seed before calling env.reset()
         self._gen_grid(self.width, self.height)
         self._gen_grid(self.width, self.height)
 
 
         # These fields should be defined by _gen_grid
         # These fields should be defined by _gen_grid
@@ -731,18 +756,14 @@ class MiniGridEnv(gym.Env):
         obs = self.gen_obs()
         obs = self.gen_obs()
         return obs
         return obs
 
 
-    def seed(self, seed=1337):
-        # Seed the random number generator
-        self.np_random, _ = seeding.np_random(seed)
-        return [seed]
-
     def hash(self, size=16):
     def hash(self, size=16):
         """Compute a hash that uniquely identifies the current state of the environment.
         """Compute a hash that uniquely identifies the current state of the environment.
         :param size: Size of the hashing
         :param size: Size of the hashing
         """
         """
         sample_hash = hashlib.sha256()
         sample_hash = hashlib.sha256()
 
 
-        to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
+        to_encode = [self.grid.encode().tolist(), self.agent_pos,
+                     self.agent_dir]
         for item in to_encode:
         for item in to_encode:
             sample_hash.update(str(item).encode('utf8'))
             sample_hash.update(str(item).encode('utf8'))
 
 
@@ -761,14 +782,14 @@ class MiniGridEnv(gym.Env):
 
 
         # Map of object types to short string
         # Map of object types to short string
         OBJECT_TO_STR = {
         OBJECT_TO_STR = {
-            'wall'          : 'W',
-            'floor'         : 'F',
-            'door'          : 'D',
-            'key'           : 'K',
-            'ball'          : 'A',
-            'box'           : 'B',
-            'goal'          : 'G',
-            'lava'          : 'V',
+            'wall': 'W',
+            'floor': 'F',
+            'door': 'D',
+            'key': 'K',
+            'ball': 'A',
+            'box': 'B',
+            'goal': 'G',
+            'lava': 'V',
         }
         }
 
 
         # Short string for opened door
         # Short string for opened door
@@ -828,7 +849,7 @@ class MiniGridEnv(gym.Env):
         Generate random integer in [low,high[
         Generate random integer in [low,high[
         """
         """
 
 
-        return self.np_random.randint(low, high)
+        return self.np_random.integers(low, high)
 
 
     def _rand_float(self, low, high):
     def _rand_float(self, low, high):
         """
         """
@@ -842,7 +863,7 @@ class MiniGridEnv(gym.Env):
         Generate random boolean value
         Generate random boolean value
         """
         """
 
 
-        return (self.np_random.randint(0, 2) == 0)
+        return (self.np_random.integers(0, 2) == 0)
 
 
     def _rand_elem(self, iterable):
     def _rand_elem(self, iterable):
         """
         """
@@ -883,17 +904,17 @@ class MiniGridEnv(gym.Env):
         """
         """
 
 
         return (
         return (
-            self.np_random.randint(xLow, xHigh),
-            self.np_random.randint(yLow, yHigh)
+            self.np_random.integers(xLow, xHigh),
+            self.np_random.integers(yLow, yHigh)
         )
         )
 
 
     def place_obj(self,
     def place_obj(self,
-        obj,
-        top=None,
-        size=None,
-        reject_fn=None,
-        max_tries=math.inf
-    ):
+                  obj,
+                  top=None,
+                  size=None,
+                  reject_fn=None,
+                  max_tries=math.inf
+                  ):
         """
         """
         Place an object at an empty position in the grid
         Place an object at an empty position in the grid
 
 
@@ -1030,33 +1051,36 @@ class MiniGridEnv(gym.Env):
 
 
         return vx, vy
         return vx, vy
 
 
-    def get_view_exts(self):
+    def get_view_exts(self, agent_view_size=None):
         """
         """
         Get the extents of the square set of tiles visible to the agent
         Get the extents of the square set of tiles visible to the agent
         Note: the bottom extent indices are not included in the set
         Note: the bottom extent indices are not included in the set
+        if agent_view_size is None, use self.agent_view_size
         """
         """
 
 
+        agent_view_size = agent_view_size or self.agent_view_size
+
         # Facing right
         # Facing right
         if self.agent_dir == 0:
         if self.agent_dir == 0:
             topX = self.agent_pos[0]
             topX = self.agent_pos[0]
-            topY = self.agent_pos[1] - self.agent_view_size // 2
+            topY = self.agent_pos[1] - agent_view_size // 2
         # Facing down
         # Facing down
         elif self.agent_dir == 1:
         elif self.agent_dir == 1:
-            topX = self.agent_pos[0] - self.agent_view_size // 2
+            topX = self.agent_pos[0] - agent_view_size // 2
             topY = self.agent_pos[1]
             topY = self.agent_pos[1]
         # Facing left
         # Facing left
         elif self.agent_dir == 2:
         elif self.agent_dir == 2:
-            topX = self.agent_pos[0] - self.agent_view_size + 1
-            topY = self.agent_pos[1] - self.agent_view_size // 2
+            topX = self.agent_pos[0] - agent_view_size + 1
+            topY = self.agent_pos[1] - agent_view_size // 2
         # Facing up
         # Facing up
         elif self.agent_dir == 3:
         elif self.agent_dir == 3:
-            topX = self.agent_pos[0] - self.agent_view_size // 2
-            topY = self.agent_pos[1] - self.agent_view_size + 1
+            topX = self.agent_pos[0] - agent_view_size // 2
+            topY = self.agent_pos[1] - agent_view_size + 1
         else:
         else:
             assert False, "invalid agent direction"
             assert False, "invalid agent direction"
 
 
-        botX = topX + self.agent_view_size
-        botY = topY + self.agent_view_size
+        botX = topX + agent_view_size
+        botY = topY + agent_view_size
 
 
         return (topX, topY, botX, botY)
         return (topX, topY, botX, botY)
 
 
@@ -1162,16 +1186,19 @@ class MiniGridEnv(gym.Env):
 
 
         return obs, reward, done, {}
         return obs, reward, done, {}
 
 
-    def gen_obs_grid(self):
+    def gen_obs_grid(self, agent_view_size=None):
         """
         """
         Generate the sub-grid observed by the agent.
         Generate the sub-grid observed by the agent.
         This method also outputs a visibility mask telling us which grid
         This method also outputs a visibility mask telling us which grid
         cells the agent can actually see.
         cells the agent can actually see.
+        if agent_view_size is None, self.agent_view_size is used
         """
         """
 
 
-        topX, topY, botX, botY = self.get_view_exts()
+        topX, topY, botX, botY = self.get_view_exts(agent_view_size)
 
 
-        grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
+        agent_view_size = agent_view_size or self.agent_view_size
+
+        grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
 
 
         for i in range(self.agent_dir + 1):
         for i in range(self.agent_dir + 1):
             grid = grid.rotate_left()
             grid = grid.rotate_left()
@@ -1179,7 +1206,8 @@ class MiniGridEnv(gym.Env):
         # Process occluders and visibility
         # Process occluders and visibility
         # Note that this incurs some performance cost
         # Note that this incurs some performance cost
         if not self.see_through_walls:
         if not self.see_through_walls:
-            vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
+            vis_mask = grid.process_vis(agent_pos=(
+                agent_view_size // 2, agent_view_size - 1))
         else:
         else:
             vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
             vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
 
 
@@ -1204,7 +1232,8 @@ class MiniGridEnv(gym.Env):
         # Encode the partially observable view into a numpy array
         # Encode the partially observable view into a numpy array
         image = grid.encode(vis_mask)
         image = grid.encode(vis_mask)
 
 
-        assert hasattr(self, 'mission'), "environments must define a textual mission string"
+        assert hasattr(
+            self, 'mission'), "environments must define a textual mission string"
 
 
         # Observations are dictionaries containing:
         # Observations are dictionaries containing:
         # - an image (partially observable view of the environment)
         # - an image (partially observable view of the environment)
@@ -1239,7 +1268,8 @@ class MiniGridEnv(gym.Env):
         """
         """
         Render the whole-grid human view
         Render the whole-grid human view
         """
         """
-
+        if self.render_mode is not None:
+            mode = self.render_mode
         if close:
         if close:
             if self.window:
             if self.window:
                 self.window.close()
                 self.window.close()
@@ -1257,7 +1287,8 @@ class MiniGridEnv(gym.Env):
         # of the agent's view area
         # of the agent's view area
         f_vec = self.dir_vec
         f_vec = self.dir_vec
         r_vec = self.right_vec
         r_vec = self.right_vec
-        top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
+        top_left = self.agent_pos + f_vec * \
+            (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
 
 
         # Mask of which cells to highlight
         # Mask of which cells to highlight
         highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
         highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

+ 4 - 4
gym_minigrid/roomgrid.py

@@ -72,8 +72,8 @@ class RoomGrid(MiniGridEnv):
         num_rows=3,
         num_rows=3,
         num_cols=3,
         num_cols=3,
         max_steps=100,
         max_steps=100,
-        seed=0,
-        agent_view_size=7
+        agent_view_size=7,
+        **kwargs
     ):
     ):
         assert room_size > 0
         assert room_size > 0
         assert room_size >= 3
         assert room_size >= 3
@@ -94,8 +94,8 @@ class RoomGrid(MiniGridEnv):
             height=height,
             height=height,
             max_steps=max_steps,
             max_steps=max_steps,
             see_through_walls=False,
             see_through_walls=False,
-            seed=seed,
-            agent_view_size=agent_view_size
+            agent_view_size=agent_view_size,
+            **kwargs
         )
         )
 
 
     def room_from_pos(self, x, y):
     def room_from_pos(self, x, y):

+ 1 - 1
gym_minigrid/window.py

@@ -23,7 +23,7 @@ class Window:
         self.fig, self.ax = plt.subplots()
         self.fig, self.ax = plt.subplots()
 
 
         # Show the env name in the window title
         # Show the env name in the window title
-        self.fig.canvas.set_window_title(title)
+        self.fig.canvas.manager.set_window_title(title)
 
 
         # Turn off x/y axis numbering/ticks
         # Turn off x/y axis numbering/ticks
         self.ax.xaxis.set_ticks_position('none')
         self.ax.xaxis.set_ticks_position('none')

+ 153 - 40
gym_minigrid/wrappers.py

@@ -7,7 +7,8 @@ import gym
 from gym import error, spaces, utils
 from gym import error, spaces, utils
 from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal
 from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal
 
 
-class ReseedWrapper(gym.core.Wrapper):
+
+class ReseedWrapper(gym.Wrapper):
     """
     """
     Wrapper to always regenerate an environment with the same set of seeds.
     Wrapper to always regenerate an environment with the same set of seeds.
     This can be used to force an environment to always keep the same
     This can be used to force an environment to always keep the same
@@ -22,14 +23,14 @@ class ReseedWrapper(gym.core.Wrapper):
     def reset(self, **kwargs):
     def reset(self, **kwargs):
         seed = self.seeds[self.seed_idx]
         seed = self.seeds[self.seed_idx]
         self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
         self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
-        self.env.seed(seed)
-        return self.env.reset(**kwargs)
+        return self.env.reset(seed=seed, **kwargs)
 
 
     def step(self, action):
     def step(self, action):
         obs, reward, done, info = self.env.step(action)
         obs, reward, done, info = self.env.step(action)
         return obs, reward, done, info
         return obs, reward, done, info
 
 
-class ActionBonus(gym.core.Wrapper):
+
+class ActionBonus(gym.Wrapper):
     """
     """
     Wrapper which adds an exploration bonus.
     Wrapper which adds an exploration bonus.
     This is a reward to encourage exploration of less
     This is a reward to encourage exploration of less
@@ -63,7 +64,8 @@ class ActionBonus(gym.core.Wrapper):
     def reset(self, **kwargs):
     def reset(self, **kwargs):
         return self.env.reset(**kwargs)
         return self.env.reset(**kwargs)
 
 
-class StateBonus(gym.core.Wrapper):
+
+class StateBonus(gym.Wrapper):
     """
     """
     Adds an exploration bonus based on which positions
     Adds an exploration bonus based on which positions
     are visited on the grid.
     are visited on the grid.
@@ -98,7 +100,8 @@ class StateBonus(gym.core.Wrapper):
     def reset(self, **kwargs):
     def reset(self, **kwargs):
         return self.env.reset(**kwargs)
         return self.env.reset(**kwargs)
 
 
-class ImgObsWrapper(gym.core.ObservationWrapper):
+
+class ImgObsWrapper(gym.ObservationWrapper):
     """
     """
     Use the image as the only observation output, no language/mission.
     Use the image as the only observation output, no language/mission.
     """
     """
@@ -110,7 +113,8 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
     def observation(self, obs):
     def observation(self, obs):
         return obs['image']
         return obs['image']
 
 
-class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
+
+class OneHotPartialObsWrapper(gym.ObservationWrapper):
     """
     """
     Wrapper to get a one-hot encoding of a partially observable
     Wrapper to get a one-hot encoding of a partially observable
     agent view as observation.
     agent view as observation.
@@ -126,16 +130,19 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
         # Number of bits per cell
         # Number of bits per cell
         num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
         num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
 
 
-        self.observation_space.spaces["image"] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             low=0,
             high=255,
             high=255,
             shape=(obs_shape[0], obs_shape[1], num_bits),
             shape=(obs_shape[0], obs_shape[1], num_bits),
             dtype='uint8'
             dtype='uint8'
         )
         )
+        self.observation_space = spaces.Dict(
+            {**self.observation_space.spaces, 'image': new_image_space})
 
 
     def observation(self, obs):
     def observation(self, obs):
         img = obs['image']
         img = obs['image']
-        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')
+        out = np.zeros(
+            self.observation_space.spaces['image'].shape, dtype='uint8')
 
 
         for i in range(img.shape[0]):
         for i in range(img.shape[0]):
             for j in range(img.shape[1]):
             for j in range(img.shape[1]):
@@ -152,10 +159,12 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
             'image': out
             'image': out
         }
         }
 
 
-class RGBImgObsWrapper(gym.core.ObservationWrapper):
+
+class RGBImgObsWrapper(gym.ObservationWrapper):
     """
     """
     Wrapper to use a fully observable RGB image as the observation.
     Wrapper to use a fully observable RGB image as the observation.
     This can be used to have the agent solve the gridworld in pixel space.
     This can be used to have the agent solve the gridworld in pixel space.
+    To use it, make the unwrapped environment with render_mode='rgb_array'.
     """
     """
 
 
     def __init__(self, env, tile_size=8):
     def __init__(self, env, tile_size=8):
@@ -163,18 +172,21 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
 
 
         self.tile_size = tile_size
         self.tile_size = tile_size
 
 
-        self.observation_space.spaces['image'] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             low=0,
             high=255,
             high=255,
             shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
             shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
             dtype='uint8'
             dtype='uint8'
         )
         )
 
 
+        self.observation_space = spaces.Dict(
+            {**self.observation_space.spaces, 'image': new_image_space})
+
     def observation(self, obs):
     def observation(self, obs):
         env = self.unwrapped
         env = self.unwrapped
+        assert env.render_mode == 'rgb_array', env.render_mode
 
 
         rgb_img = env.render(
         rgb_img = env.render(
-            mode='rgb_array',
             highlight=False,
             highlight=False,
             tile_size=self.tile_size
             tile_size=self.tile_size
         )
         )
@@ -185,7 +197,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
         }
         }
 
 
 
 
-class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
+class RGBImgPartialObsWrapper(gym.ObservationWrapper):
     """
     """
     Wrapper to use partially observable RGB image as observation.
     Wrapper to use partially observable RGB image as observation.
     This can be used to have the agent solve the gridworld in pixel space.
     This can be used to have the agent solve the gridworld in pixel space.
@@ -197,13 +209,16 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
         self.tile_size = tile_size
         self.tile_size = tile_size
 
 
         obs_shape = env.observation_space.spaces['image'].shape
         obs_shape = env.observation_space.spaces['image'].shape
-        self.observation_space.spaces['image'] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             low=0,
             high=255,
             high=255,
             shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
             shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
             dtype='uint8'
             dtype='uint8'
         )
         )
 
 
+        self.observation_space = spaces.Dict(
+            {**self.observation_space.spaces, 'image': new_image_space})
+
     def observation(self, obs):
     def observation(self, obs):
         env = self.unwrapped
         env = self.unwrapped
 
 
@@ -217,7 +232,8 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
             'image': rgb_img_partial
             'image': rgb_img_partial
         }
         }
 
 
-class FullyObsWrapper(gym.core.ObservationWrapper):
+
+class FullyObsWrapper(gym.ObservationWrapper):
     """
     """
     Fully observable gridworld using a compact grid encoding
     Fully observable gridworld using a compact grid encoding
     """
     """
@@ -225,13 +241,16 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
     def __init__(self, env):
     def __init__(self, env):
         super().__init__(env)
         super().__init__(env)
 
 
-        self.observation_space.spaces["image"] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             low=0,
             high=255,
             high=255,
             shape=(self.env.width, self.env.height, 3),  # number of cells
             shape=(self.env.width, self.env.height, 3),  # number of cells
             dtype='uint8'
             dtype='uint8'
         )
         )
 
 
+        self.observation_space = spaces.Dict(
+            {**self.observation_space.spaces, 'image': new_image_space})
+
     def observation(self, obs):
     def observation(self, obs):
         env = self.unwrapped
         env = self.unwrapped
         full_grid = env.grid.encode()
         full_grid = env.grid.encode()
@@ -246,7 +265,83 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
             'image': full_grid
             'image': full_grid
         }
         }
 
 
-class FlatObsWrapper(gym.core.ObservationWrapper):
+
+class DictObservationSpaceWrapper(gym.ObservationWrapper):
+    """
+    Transforms the observation space (that has a textual component) to a fully numerical observation space,
+    where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
+    """
+
+    def __init__(self, env, max_words_in_mission=50, word_dict=None):
+        """
+        max_words_in_mission is the length of the array to represent a mission, value 0 for missing words
+        word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
+                  if None, use the Minigrid language
+        """
+        super().__init__(env)
+
+        if word_dict is None:
+            word_dict = self.get_minigrid_words()
+
+        self.max_words_in_mission = max_words_in_mission
+        self.word_dict = word_dict
+
+        image_observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(self.agent_view_size, self.agent_view_size, 3),
+            dtype='uint8'
+        )
+        self.observation_space = spaces.Dict({
+            'image': image_observation_space,
+            'direction': spaces.Discrete(4),
+            'mission': spaces.MultiDiscrete([len(self.word_dict.keys())]
+                                            * max_words_in_mission)
+        })
+
+    @staticmethod
+    def get_minigrid_words():
+        colors = ['red', 'green', 'blue', 'yellow', 'purple', 'grey']
+        objects = ['unseen', 'empty', 'wall', 'floor', 'box', 'key', 'ball',
+                   'door', 'goal', 'agent', 'lava']
+
+        verbs = ['pick', 'avoid', 'get', 'find', 'put',
+                 'use', 'open', 'go', 'fetch',
+                 'reach', 'unlock', 'traverse']
+
+        extra_words = ['up', 'the', 'a', 'at', ',', 'square',
+                       'and', 'then', 'to', 'of', 'rooms', 'near',
+                       'opening', 'must', 'you', 'matching', 'end',
+                       'hallway', 'object', 'from', 'room']
+
+        all_words = colors + objects + verbs + extra_words
+        assert len(all_words) == len(set(all_words))
+        return {word: i for i, word in enumerate(all_words)}
+
+    def string_to_indices(self, string, offset=1):
+        """
+        Convert a string to a list of indices.
+        """
+        indices = []
+        # adding space before and after commas
+        string = string.replace(',', ' , ')
+        for word in string.split():
+            if word in self.word_dict.keys():
+                indices.append(self.word_dict[word] + offset)
+            else:
+                raise ValueError('Unknown word: {}'.format(word))
+        return indices
+
+    def observation(self, obs):
+        obs['mission'] = self.string_to_indices(obs['mission'])
+        assert len(obs['mission']) < self.max_words_in_mission
+        obs['mission'] += [0] * \
+            (self.max_words_in_mission - len(obs['mission']))
+
+        return obs
+
+
+class FlatObsWrapper(gym.ObservationWrapper):
     """
     """
     Encode mission strings using a one-hot scheme,
     Encode mission strings using a one-hot scheme,
     and combine these with observed images into one flat array
     and combine these with observed images into one flat array
@@ -277,10 +372,12 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
 
 
         # Cache the last-encoded mission string
         # Cache the last-encoded mission string
         if mission != self.cachedStr:
         if mission != self.cachedStr:
-            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
+            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(
+                len(mission))
             mission = mission.lower()
             mission = mission.lower()
 
 
-            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
+            strArray = np.zeros(
+                shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
 
 
             for idx, ch in enumerate(mission):
             for idx, ch in enumerate(mission):
                 if ch >= 'a' and ch <= 'z':
                 if ch >= 'a' and ch <= 'z':
@@ -297,7 +394,8 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
 
 
         return obs
         return obs
 
 
-class ViewSizeWrapper(gym.core.Wrapper):
+
+class ViewSizeWrapper(gym.Wrapper):
     """
     """
     Wrapper to customize the agent field of view size.
     Wrapper to customize the agent field of view size.
     This cannot be used with fully observable wrappers.
     This cannot be used with fully observable wrappers.
@@ -309,34 +407,41 @@ class ViewSizeWrapper(gym.core.Wrapper):
         assert agent_view_size % 2 == 1
         assert agent_view_size % 2 == 1
         assert agent_view_size >= 3
         assert agent_view_size >= 3
 
 
-        # Override default view size
-        env.unwrapped.agent_view_size = agent_view_size
+        self.agent_view_size = agent_view_size
 
 
         # Compute observation space with specified view size
         # Compute observation space with specified view size
-        observation_space = gym.spaces.Box(
+        new_image_space = gym.spaces.Box(
             low=0,
             low=0,
             high=255,
             high=255,
             shape=(agent_view_size, agent_view_size, 3),
             shape=(agent_view_size, agent_view_size, 3),
             dtype='uint8'
             dtype='uint8'
         )
         )
 
 
-        # Override the environment's observation space
-        self.observation_space = spaces.Dict({
-            'image': observation_space
-        })
+        # Override the environment's observation space
+        self.observation_space = spaces.Dict(
+            {**self.observation_space.spaces, 'image': new_image_space})
 
 
-    def reset(self, **kwargs):
-        return self.env.reset(**kwargs)
+    def observation(self, obs):
+        env = self.unwrapped
 
 
-    def step(self, action):
-        return self.env.step(action)
+        grid, vis_mask = env.gen_obs_grid(self.agent_view_size)
+
+        # Encode the partially observable view into a numpy array
+        image = grid.encode(vis_mask)
+
+        return {
+            **obs,
+            'image': image
+        }
 
 
-class DirectionObsWrapper(gym.core.ObservationWrapper):
+
+class DirectionObsWrapper(gym.ObservationWrapper):
     """
     """
     Provides the slope/angular direction to the goal with the observations, modeled as (y2 - y1) / (x2 - x1)
     Provides the slope/angular direction to the goal with the observations, modeled as (y2 - y1) / (x2 - x1)
     type = {slope , angle}
     type = {slope , angle}
     """
     """
-    def __init__(self, env,type='slope'):
+
+    def __init__(self, env, type='slope'):
         super().__init__(env)
         super().__init__(env)
         self.goal_position = None
         self.goal_position = None
         self.type = type
         self.type = type
@@ -344,17 +449,23 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
     def reset(self):
     def reset(self):
         obs = self.env.reset()
         obs = self.env.reset()
         if not self.goal_position:
         if not self.goal_position:
-            self.goal_position = [x for x,y in enumerate(self.grid.grid) if isinstance(y,(Goal) ) ]
-            if len(self.goal_position) >= 1: # in case there are multiple goals , needs to be handled for other env types
-                self.goal_position = (int(self.goal_position[0]/self.height) , self.goal_position[0]%self.width)
+            self.goal_position = [x for x, y in enumerate(
+                self.grid.grid) if isinstance(y, (Goal))]
+            # in case there are multiple goals , needs to be handled for other env types
+            if len(self.goal_position) >= 1:
+                self.goal_position = (
+                    int(self.goal_position[0]/self.height), self.goal_position[0] % self.width)
         return obs
         return obs
 
 
     def observation(self, obs):
     def observation(self, obs):
-        slope = np.divide( self.goal_position[1] - self.agent_pos[1] ,  self.goal_position[0] - self.agent_pos[0])
-        obs['goal_direction'] = np.arctan( slope ) if self.type == 'angle' else slope
+        slope = np.divide(
+            self.goal_position[1] - self.agent_pos[1],  self.goal_position[0] - self.agent_pos[0])
+        obs['goal_direction'] = np.arctan(
+            slope) if self.type == 'angle' else slope
         return obs
         return obs
 
 
-class SymbolicObsWrapper(gym.core.ObservationWrapper):
+
+class SymbolicObsWrapper(gym.ObservationWrapper):
     """
     """
     Fully observable grid with a symbolic state representation.
     Fully observable grid with a symbolic state representation.
     The symbol is a triple of (X, Y, IDX), where X and Y are
     The symbol is a triple of (X, Y, IDX), where X and Y are
@@ -364,12 +475,14 @@ class SymbolicObsWrapper(gym.core.ObservationWrapper):
     def __init__(self, env):
     def __init__(self, env):
         super().__init__(env)
         super().__init__(env)
 
 
-        self.observation_space.spaces["image"] = spaces.Box(
+        new_image_space = spaces.Box(
             low=0,
             low=0,
             high=max(OBJECT_TO_IDX.values()),
             high=max(OBJECT_TO_IDX.values()),
             shape=(self.env.width, self.env.height, 3),  # number of cells
             shape=(self.env.width, self.env.height, 3),  # number of cells
             dtype="uint8",
             dtype="uint8",
         )
         )
+        self.observation_space = spaces.Dict(
+            {**self.observation_space.spaces, 'image': new_image_space})
 
 
     def observation(self, obs):
     def observation(self, obs):
         objects = np.array(
         objects = np.array(

+ 9 - 6
manual_control.py

@@ -8,17 +8,17 @@ import gym_minigrid
 from gym_minigrid.wrappers import *
 from gym_minigrid.wrappers import *
 from gym_minigrid.window import Window
 from gym_minigrid.window import Window
 
 
+
 def redraw(img):
 def redraw(img):
     if not args.agent_view:
     if not args.agent_view:
-        img = env.render('rgb_array', tile_size=args.tile_size)
+        img = env.render(tile_size=args.tile_size)
 
 
     window.show_img(img)
     window.show_img(img)
 
 
-def reset():
-    if args.seed != -1:
-        env.seed(args.seed)
 
 
-    obs = env.reset()
+def reset():
+    seed = None if args.seed == -1 else args.seed
+    obs = env.reset(seed=seed)
 
 
     if hasattr(env, 'mission'):
     if hasattr(env, 'mission'):
         print('Mission: %s' % env.mission)
         print('Mission: %s' % env.mission)
@@ -26,6 +26,7 @@ def reset():
 
 
     redraw(obs)
     redraw(obs)
 
 
+
 def step(action):
 def step(action):
     obs, reward, done, info = env.step(action)
     obs, reward, done, info = env.step(action)
     print('step=%s, reward=%.2f' % (env.step_count, reward))
     print('step=%s, reward=%.2f' % (env.step_count, reward))
@@ -36,6 +37,7 @@ def step(action):
     else:
     else:
         redraw(obs)
         redraw(obs)
 
 
+
 def key_handler(event):
 def key_handler(event):
     print('pressed', event.key)
     print('pressed', event.key)
 
 
@@ -72,6 +74,7 @@ def key_handler(event):
         step(env.actions.done)
         step(env.actions.done)
         return
         return
 
 
+
 parser = argparse.ArgumentParser()
 parser = argparse.ArgumentParser()
 parser.add_argument(
 parser.add_argument(
     "--env",
     "--env",
@@ -99,7 +102,7 @@ parser.add_argument(
 
 
 args = parser.parse_args()
 args = parser.parse_args()
 
 
-env = gym.make(args.env)
+env = gym.make(args.env, render_mode='rgb_array')
 
 
 if args.agent_view:
 if args.agent_view:
     env = RGBImgPartialObsWrapper(env)
     env = RGBImgPartialObsWrapper(env)

+ 34 - 22
run_tests.py

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
 
 
+from pydoc import render_doc
 import random
 import random
 import numpy as np
 import numpy as np
 import gym
 import gym
@@ -21,17 +22,17 @@ for env_idx, env_name in enumerate(env_list):
     print('testing {} ({}/{})'.format(env_name, env_idx+1, len(env_list)))
     print('testing {} ({}/{})'.format(env_name, env_idx+1, len(env_list)))
 
 
     # Load the gym environment
     # Load the gym environment
-    env = gym.make(env_name)
+    env = gym.make(env_name, render_mode='rgb_array')
     env.max_steps = min(env.max_steps, 200)
     env.max_steps = min(env.max_steps, 200)
     env.reset()
     env.reset()
-    env.render('rgb_array')
+    env.render()
 
 
     # Verify that the same seed always produces the same environment
     # Verify that the same seed always produces the same environment
     for i in range(0, 5):
     for i in range(0, 5):
         seed = 1337 + i
         seed = 1337 + i
-        env.seed(seed)
+        _ = env.reset(seed=seed)
         grid1 = env.grid
         grid1 = env.grid
-        env.seed(seed)
+        _ = env.reset(seed=seed)
         grid2 = env.grid
         grid2 = env.grid
         assert grid1 == grid2
         assert grid1 == grid2
 
 
@@ -66,7 +67,7 @@ for env_idx, env_name in enumerate(env_list):
             num_episodes += 1
             num_episodes += 1
             env.reset()
             env.reset()
 
 
-        env.render('rgb_array')
+        env.render()
 
 
     # Test the close method
     # Test the close method
     env.close()
     env.close()
@@ -112,6 +113,16 @@ for env_idx, env_name in enumerate(env_list):
     env.step(0)
     env.step(0)
     env.close()
     env.close()
 
 
+    # Test the DictObservationSpaceWrapper
+    env = gym.make(env_name)
+    env = DictObservationSpaceWrapper(env)
+    env.reset()
+    mission = env.mission
+    obs, _, _, _ = env.step(0)
+    assert env.string_to_indices(mission) == [
+        value for value in obs['mission'] if value != 0]
+    env.close()
+
     # Test the wrappers return proper observation spaces.
     # Test the wrappers return proper observation spaces.
     wrappers = [
     wrappers = [
         RGBImgObsWrapper,
         RGBImgObsWrapper,
@@ -119,7 +130,7 @@ for env_idx, env_name in enumerate(env_list):
         OneHotPartialObsWrapper
         OneHotPartialObsWrapper
     ]
     ]
     for wrapper in wrappers:
     for wrapper in wrappers:
-        env = wrapper(gym.make(env_name))
+        env = wrapper(gym.make(env_name, render_mode='rgb_array'))
         obs_space, wrapper_name = env.observation_space, wrapper.__name__
         obs_space, wrapper_name = env.observation_space, wrapper.__name__
         assert isinstance(
         assert isinstance(
             obs_space, spaces.Dict
             obs_space, spaces.Dict
@@ -135,29 +146,33 @@ for env_idx, env_name in enumerate(env_list):
 ##############################################################################
 ##############################################################################
 
 
 print('testing extra observations')
 print('testing extra observations')
+
+
 class EmptyEnvWithExtraObs(gym_minigrid.envs.EmptyEnv5x5):
 class EmptyEnvWithExtraObs(gym_minigrid.envs.EmptyEnv5x5):
     """
     """
     Custom environment with an extra observation
     Custom environment with an extra observation
     """
     """
-    def __init__(self) -> None:
-        super().__init__()
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
         self.observation_space['size'] = spaces.Box(
         self.observation_space['size'] = spaces.Box(
             low=0,
             low=0,
-            high=np.iinfo(np.uint).max,
+            high=1000,  # gym does not like np.iinfo(np.uint).max,
             shape=(2,),
             shape=(2,),
             dtype=np.uint
             dtype=np.uint
         )
         )
 
 
-    def reset(self):
-        obs = super().reset()
-        obs['size'] = np.array([self.width, self.height])
+    def reset(self, **kwargs):
+        obs = super().reset(**kwargs)
+        obs['size'] = np.array([self.width, self.height], dtype=np.uint)
         return obs
         return obs
 
 
     def step(self, action):
     def step(self, action):
         obs, reward, done, info = super().step(action)
         obs, reward, done, info = super().step(action)
-        obs['size'] = np.array([self.width, self.height])
+        obs['size'] = np.array([self.width, self.height], dtype=np.uint)
         return obs, reward, done, info
         return obs, reward, done, info
 
 
+
 wrappers = [
 wrappers = [
     OneHotPartialObsWrapper,
     OneHotPartialObsWrapper,
     RGBImgObsWrapper,
     RGBImgObsWrapper,
@@ -165,17 +180,14 @@ wrappers = [
     FullyObsWrapper,
     FullyObsWrapper,
 ]
 ]
 for wrapper in wrappers:
 for wrapper in wrappers:
-    env1 = wrapper(EmptyEnvWithExtraObs())
-    env2 = wrapper(gym.make('MiniGrid-Empty-5x5-v0'))
-
-    env1.seed(0)
-    env2.seed(0)
+    env1 = wrapper(EmptyEnvWithExtraObs(render_mode='rgb_array'))
+    env2 = wrapper(gym.make('MiniGrid-Empty-5x5-v0', render_mode='rgb_array'))
 
 
-    obs1 = env1.reset()
-    obs2 = env2.reset()
+    obs1 = env1.reset(seed=0)
+    obs2 = env2.reset(seed=0)
     assert 'size' in obs1
     assert 'size' in obs1
     assert obs1['size'].shape == (2,)
     assert obs1['size'].shape == (2,)
-    assert (obs1['size'] == [5,5]).all()
+    assert (obs1['size'] == [5, 5]).all()
     for key in obs2:
     for key in obs2:
         assert np.array_equal(obs1[key], obs2[key])
         assert np.array_equal(obs1[key], obs2[key])
 
 
@@ -183,7 +195,7 @@ for wrapper in wrappers:
     obs2, reward2, done2, _ = env2.step(0)
     obs2, reward2, done2, _ = env2.step(0)
     assert 'size' in obs1
     assert 'size' in obs1
     assert obs1['size'].shape == (2,)
     assert obs1['size'].shape == (2,)
-    assert (obs1['size'] == [5,5]).all()
+    assert (obs1['size'] == [5, 5]).all()
     for key in obs2:
     for key in obs2:
         assert np.array_equal(obs1[key], obs2[key])
         assert np.array_equal(obs1[key], obs2[key])
 
 

+ 9 - 9
setup.py

@@ -24,15 +24,15 @@ setup(
     python_requires=">=3.7, <3.11",
     python_requires=">=3.7, <3.11",
     long_description_content_type="text/markdown",
     long_description_content_type="text/markdown",
     install_requires=[
     install_requires=[
-        'gym>=0.24.0',
-        "numpy>=1.18.0"
+        "numpy>=1.18.0",
+        'gym>=0.25.0'
     ],
     ],
     classifiers=[
     classifiers=[
-    "Development Status :: 5 - Production/Stable",
-    "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.7",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
-    "Programming Language :: Python :: 3.10",
-],
+        "Development Status :: 5 - Production/Stable",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+    ],
 )
 )