|
@@ -367,8 +367,8 @@ class Grid:
|
|
|
"""
|
|
|
|
|
|
def __init__(self, width, height):
|
|
|
- assert width >= 4
|
|
|
- assert height >= 4
|
|
|
+ assert width >= 3
|
|
|
+ assert height >= 3
|
|
|
|
|
|
self.width = width
|
|
|
self.height = height
|
|
@@ -664,11 +664,19 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- grid_size=16,
|
|
|
+ grid_size=None,
|
|
|
+ width=None,
|
|
|
+ height=None,
|
|
|
max_steps=100,
|
|
|
see_through_walls=False,
|
|
|
seed=1337
|
|
|
):
|
|
|
+ # Can't set both grid_size and width/height
|
|
|
+ if grid_size:
|
|
|
+ assert width == None and height == None
|
|
|
+ width = grid_size
|
|
|
+ height = grid_size
|
|
|
+
|
|
|
# Action enumeration for this environment
|
|
|
self.actions = MiniGridEnv.Actions
|
|
|
|
|
@@ -697,7 +705,8 @@ class MiniGridEnv(gym.Env):
|
|
|
self.obs_render = None
|
|
|
|
|
|
# Environment configuration
|
|
|
- self.grid_size = grid_size
|
|
|
+ self.width = width
|
|
|
+ self.height = height
|
|
|
self.max_steps = max_steps
|
|
|
self.see_through_walls = see_through_walls
|
|
|
|
|
@@ -715,7 +724,7 @@ class MiniGridEnv(gym.Env):
|
|
|
# Generate a new random grid at the start of each episode
|
|
|
# To keep the same grid for each episode, call env.seed() with
|
|
|
# the same seed before calling env.reset()
|
|
|
- self._gen_grid(self.grid_size, self.grid_size)
|
|
|
+ self._gen_grid(self.width, self.height)
|
|
|
|
|
|
# These fields should be defined by _gen_grid
|
|
|
assert self.start_pos is not None
|
|
@@ -1259,8 +1268,8 @@ class MiniGridEnv(gym.Env):
|
|
|
if self.grid_render is None:
|
|
|
from gym_minigrid.rendering import Renderer
|
|
|
self.grid_render = Renderer(
|
|
|
- self.grid_size * CELL_PIXELS,
|
|
|
- self.grid_size * CELL_PIXELS,
|
|
|
+ self.width * CELL_PIXELS,
|
|
|
+ self.height * CELL_PIXELS,
|
|
|
True if mode == 'human' else False
|
|
|
)
|
|
|
|