浏览代码

Environments can now have different width & height

No longer restricted to square grids.
Maxime Chevalier-Boisvert 6 年之前
父节点
当前提交
ffe4e553e0
共有 7 个文件被更改,包括 26 次插入16 次删除
  1. 1 1
      gym_minigrid/envs/multiroom.py
  2. 2 1
      gym_minigrid/envs/redbluedoors.py
  3. 16 7
      gym_minigrid/minigrid.py
  4. 2 2
      gym_minigrid/roomgrid.py
  5. 2 2
      gym_minigrid/wrappers.py
  6. 2 2
      run_tests.py
  7. 1 1
      setup.py

+ 1 - 1
gym_minigrid/envs/multiroom.py

@@ -158,7 +158,7 @@ class MultiRoomEnv(MiniGridEnv):
         # If the room is out of the grid, can't place a room here
         if topX < 0 or topY < 0:
             return False
-        if topX + sizeX > self.grid_size or topY + sizeY >= self.grid_size:
+        if topX + sizeX > self.width or topY + sizeY >= self.height:
             return False
 
         # If the room intersects with previous rooms, can't place it here

+ 2 - 1
gym_minigrid/envs/redbluedoors.py

@@ -10,7 +10,8 @@ class RedBlueDoorEnv(MiniGridEnv):
         self.size = size
 
         super().__init__(
-            grid_size=2*size,
+            width=2*size,
+            height=size,
             max_steps=10*size*size
         )
 

+ 16 - 7
gym_minigrid/minigrid.py

@@ -367,8 +367,8 @@ class Grid:
     """
 
     def __init__(self, width, height):
-        assert width >= 4
-        assert height >= 4
+        assert width >= 3
+        assert height >= 3
 
         self.width = width
         self.height = height
@@ -664,11 +664,19 @@ class MiniGridEnv(gym.Env):
 
     def __init__(
         self,
-        grid_size=16,
+        grid_size=None,
+        width=None,
+        height=None,
         max_steps=100,
         see_through_walls=False,
         seed=1337
     ):
+        # Can't set both grid_size and width/height
+        if grid_size:
+            assert width == None and height == None
+            width = grid_size
+            height = grid_size
+
         # Action enumeration for this environment
         self.actions = MiniGridEnv.Actions
 
@@ -697,7 +705,8 @@ class MiniGridEnv(gym.Env):
         self.obs_render = None
 
         # Environment configuration
-        self.grid_size = grid_size
+        self.width = width
+        self.height = height
         self.max_steps = max_steps
         self.see_through_walls = see_through_walls
 
@@ -715,7 +724,7 @@ class MiniGridEnv(gym.Env):
         # Generate a new random grid at the start of each episode
         # To keep the same grid for each episode, call env.seed() with
         # the same seed before calling env.reset()
-        self._gen_grid(self.grid_size, self.grid_size)
+        self._gen_grid(self.width, self.height)
 
         # These fields should be defined by _gen_grid
         assert self.start_pos is not None
@@ -1259,8 +1268,8 @@ class MiniGridEnv(gym.Env):
         if self.grid_render is None:
             from gym_minigrid.rendering import Renderer
             self.grid_render = Renderer(
-                self.grid_size * CELL_PIXELS,
-                self.grid_size * CELL_PIXELS,
+                self.width * CELL_PIXELS,
+                self.height * CELL_PIXELS,
                 True if mode == 'human' else False
             )
 

+ 2 - 2
gym_minigrid/roomgrid.py

@@ -84,13 +84,13 @@ class RoomGrid(MiniGridEnv):
 
         height = (room_size - 1) * num_rows + 1
         width = (room_size - 1) * num_cols + 1
-        grid_size = max(width, height)
 
         # By default, this environment has no mission
         self.mission = ''
 
         super().__init__(
-            grid_size=grid_size,
+            width=width,
+            height=height,
             max_steps=max_steps,
             see_through_walls=False,
             seed=seed

+ 2 - 2
gym_minigrid/wrappers.py

@@ -98,8 +98,8 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
         self.__dict__.update(vars(env))  # hack to pass values to super wrapper
         self.observation_space = spaces.Box(
             low=0,
-            high=self.env.grid_size,
-            shape=(self.env.grid_size, self.env.grid_size, 3),  # number of cells
+            high=255,
+            shape=(self.env.width, self.env.height, 3),  # number of cells
             dtype='uint8'
         )
 

+ 2 - 2
run_tests.py

@@ -45,8 +45,8 @@ for envName in env_list:
         obs, reward, done, info = env.step(action)
 
         # Validate the agent position
-        assert env.agent_pos[0] < env.grid_size
-        assert env.agent_pos[1] < env.grid_size
+        assert env.agent_pos[0] < env.width
+        assert env.agent_pos[1] < env.height
 
         # Test observation encode/decode roundtrip
         img = obs['image']

+ 1 - 1
setup.py

@@ -7,7 +7,7 @@ setup(
     packages=['gym_minigrid', 'gym_minigrid.envs'],
     install_requires=[
         'gym>=0.9.6',
-        'numpy>=1.10.0',
+        'numpy==1.15.4', # FIXME: temporary because of bug in numpy 1.16
         'pyqt5>=5.10.1'
     ]
 )