Преглед изворни кода

New file structure for v2.0.0 (#248)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
Rodrigo de Lazcano пре 2 година
родитељ
комит
dcf15dbf01

+ 2 - 2
README.md

@@ -3,7 +3,7 @@
 [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://pre-commit.com/) 
 [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://pre-commit.com/) 
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 
 
-There are other gridworld Gym environments out there, but this one is
+There are other gridworld Gymnasium environments out there, but this one is
 designed to be particularly simple, lightweight and fast. The code has very few
 designed to be particularly simple, lightweight and fast. The code has very few
 dependencies, making it less likely to break or fail to install. It loads no
 dependencies, making it less likely to break or fail to install. It loads no
 external sprites/textures, and it can run at up to 5000 FPS on a Core i7
 external sprites/textures, and it can run at up to 5000 FPS on a Core i7
@@ -12,7 +12,7 @@ implementation can be found [in this repository](https://github.com/lcswillems/t
 
 
 Requirements:
 Requirements:
 - Python 3.7 to 3.10
 - Python 3.7 to 3.10
-- OpenAI Gym v0.26
+- Gymnasium v0.26
 - NumPy 1.18+
 - NumPy 1.18+
 - Matplotlib (optional, only needed for display) - 3.0+
 - Matplotlib (optional, only needed for display) - 3.0+
 
 

+ 8 - 6
minigrid/__init__.py

@@ -1,6 +1,8 @@
 from gymnasium.envs.registration import register
 from gymnasium.envs.registration import register
 
 
-from minigrid import minigrid, roomgrid, wrappers
+from minigrid import minigrid, wrappers
+from minigrid.core import roomgrid
+from minigrid.core.world_object import Wall
 
 
 
 
 def register_minigrid_envs():
 def register_minigrid_envs():
@@ -44,25 +46,25 @@ def register_minigrid_envs():
     register(
     register(
         id="MiniGrid-SimpleCrossingS9N1-v0",
         id="MiniGrid-SimpleCrossingS9N1-v0",
         entry_point="minigrid.envs:CrossingEnv",
         entry_point="minigrid.envs:CrossingEnv",
-        kwargs={"size": 9, "num_crossings": 1, "obstacle_type": minigrid.Wall},
+        kwargs={"size": 9, "num_crossings": 1, "obstacle_type": Wall},
     )
     )
 
 
     register(
     register(
         id="MiniGrid-SimpleCrossingS9N2-v0",
         id="MiniGrid-SimpleCrossingS9N2-v0",
         entry_point="minigrid.envs:CrossingEnv",
         entry_point="minigrid.envs:CrossingEnv",
-        kwargs={"size": 9, "num_crossings": 2, "obstacle_type": minigrid.Wall},
+        kwargs={"size": 9, "num_crossings": 2, "obstacle_type": Wall},
     )
     )
 
 
     register(
     register(
         id="MiniGrid-SimpleCrossingS9N3-v0",
         id="MiniGrid-SimpleCrossingS9N3-v0",
         entry_point="minigrid.envs:CrossingEnv",
         entry_point="minigrid.envs:CrossingEnv",
-        kwargs={"size": 9, "num_crossings": 3, "obstacle_type": minigrid.Wall},
+        kwargs={"size": 9, "num_crossings": 3, "obstacle_type": Wall},
     )
     )
 
 
     register(
     register(
         id="MiniGrid-SimpleCrossingS11N5-v0",
         id="MiniGrid-SimpleCrossingS11N5-v0",
-        entry_point="minigrid.envs:CrossingEnv",
-        kwargs={"size": 11, "num_crossings": 5, "obstacle_type": minigrid.Wall},
+        entry_point="gym_minigrid.envs:CrossingEnv",
+        kwargs={"size": 11, "num_crossings": 5, "obstacle_type": Wall},
     )
     )
 
 
     # DistShift
     # DistShift

+ 0 - 0
minigrid/core/__init__.py


+ 18 - 0
minigrid/core/actions.py

@@ -0,0 +1,18 @@
+# Enumeration of possible actions
+from enum import IntEnum
+
+
+class Actions(IntEnum):
+    # Turn left, turn right, move forward
+    left = 0
+    right = 1
+    forward = 2
+    # Pick up an object
+    pickup = 3
+    # Drop an object
+    drop = 4
+    # Toggle/activate an object
+    toggle = 5
+
+    # Done completing task
+    done = 6

+ 56 - 0
minigrid/core/constants.py

@@ -0,0 +1,56 @@
+import numpy as np
+
+TILE_PIXELS = 32
+
+# Map of color names to RGB values
+COLORS = {
+    "red": np.array([255, 0, 0]),
+    "green": np.array([0, 255, 0]),
+    "blue": np.array([0, 0, 255]),
+    "purple": np.array([112, 39, 195]),
+    "yellow": np.array([255, 255, 0]),
+    "grey": np.array([100, 100, 100]),
+}
+
+COLOR_NAMES = sorted(list(COLORS.keys()))
+
+# Used to map colors to integers
+COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
+
+IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
+
+# Map of object type to integers
+OBJECT_TO_IDX = {
+    "unseen": 0,
+    "empty": 1,
+    "wall": 2,
+    "floor": 3,
+    "door": 4,
+    "key": 5,
+    "ball": 6,
+    "box": 7,
+    "goal": 8,
+    "lava": 9,
+    "agent": 10,
+}
+
+IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
+
+# Map of state names to integers
+STATE_TO_IDX = {
+    "open": 0,
+    "closed": 1,
+    "locked": 2,
+}
+
+# Map of agent direction indices to vectors
+DIR_TO_VEC = [
+    # Pointing right (positive X)
+    np.array((1, 0)),
+    # Down (positive Y)
+    np.array((0, 1)),
+    # Pointing left (negative X)
+    np.array((-1, 0)),
+    # Up (negative Y)
+    np.array((0, -1)),
+]

+ 295 - 0
minigrid/core/grid.py

@@ -0,0 +1,295 @@
+import math
+
+import numpy as np
+
+from minigrid.core.constants import OBJECT_TO_IDX, TILE_PIXELS
+from minigrid.core.world_object import Wall, WorldObj
+from minigrid.utils.rendering import (
+    downsample,
+    fill_coords,
+    highlight_img,
+    point_in_rect,
+    point_in_triangle,
+    rotate_fn,
+)
+
+
+class Grid:
+    """
+    Represent a grid and operations on it
+    """
+
+    # Static cache of pre-renderer tiles
+    tile_cache = {}
+
+    def __init__(self, width, height):
+        assert width >= 3
+        assert height >= 3
+
+        self.width = width
+        self.height = height
+
+        self.grid = [None] * width * height
+
+    def __contains__(self, key):
+        if isinstance(key, WorldObj):
+            for e in self.grid:
+                if e is key:
+                    return True
+        elif isinstance(key, tuple):
+            for e in self.grid:
+                if e is None:
+                    continue
+                if (e.color, e.type) == key:
+                    return True
+                if key[0] is None and key[1] == e.type:
+                    return True
+        return False
+
+    def __eq__(self, other):
+        grid1 = self.encode()
+        grid2 = other.encode()
+        return np.array_equal(grid2, grid1)
+
+    def __ne__(self, other):
+        return not self == other
+
+    def copy(self):
+        from copy import deepcopy
+
+        return deepcopy(self)
+
+    def set(self, i, j, v):
+        assert i >= 0 and i < self.width
+        assert j >= 0 and j < self.height
+        self.grid[j * self.width + i] = v
+
+    def get(self, i, j):
+        assert i >= 0 and i < self.width
+        assert j >= 0 and j < self.height
+        return self.grid[j * self.width + i]
+
+    def horz_wall(self, x, y, length=None, obj_type=Wall):
+        if length is None:
+            length = self.width - x
+        for i in range(0, length):
+            self.set(x + i, y, obj_type())
+
+    def vert_wall(self, x, y, length=None, obj_type=Wall):
+        if length is None:
+            length = self.height - y
+        for j in range(0, length):
+            self.set(x, y + j, obj_type())
+
+    def wall_rect(self, x, y, w, h):
+        self.horz_wall(x, y, w)
+        self.horz_wall(x, y + h - 1, w)
+        self.vert_wall(x, y, h)
+        self.vert_wall(x + w - 1, y, h)
+
+    def rotate_left(self):
+        """
+        Rotate the grid to the left (counter-clockwise)
+        """
+
+        grid = Grid(self.height, self.width)
+
+        for i in range(self.width):
+            for j in range(self.height):
+                v = self.get(i, j)
+                grid.set(j, grid.height - 1 - i, v)
+
+        return grid
+
+    def slice(self, topX, topY, width, height):
+        """
+        Get a subset of the grid
+        """
+
+        grid = Grid(width, height)
+
+        for j in range(0, height):
+            for i in range(0, width):
+                x = topX + i
+                y = topY + j
+
+                if x >= 0 and x < self.width and y >= 0 and y < self.height:
+                    v = self.get(x, y)
+                else:
+                    v = Wall()
+
+                grid.set(i, j, v)
+
+        return grid
+
+    @classmethod
+    def render_tile(
+        cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
+    ):
+        """
+        Render a tile and cache the result
+        """
+
+        # Hash map lookup key for the cache
+        key = (agent_dir, highlight, tile_size)
+        key = obj.encode() + key if obj else key
+
+        if key in cls.tile_cache:
+            return cls.tile_cache[key]
+
+        img = np.zeros(
+            shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
+        )
+
+        # Draw the grid lines (top and left edges)
+        fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
+        fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
+
+        if obj is not None:
+            obj.render(img)
+
+        # Overlay the agent on top
+        if agent_dir is not None:
+            tri_fn = point_in_triangle(
+                (0.12, 0.19),
+                (0.87, 0.50),
+                (0.12, 0.81),
+            )
+
+            # Rotate the agent based on its direction
+            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
+            fill_coords(img, tri_fn, (255, 0, 0))
+
+        # Highlight the cell if needed
+        if highlight:
+            highlight_img(img)
+
+        # Downsample the image to perform supersampling/anti-aliasing
+        img = downsample(img, subdivs)
+
+        # Cache the rendered tile
+        cls.tile_cache[key] = img
+
+        return img
+
+    def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
+        """
+        Render this grid at a given scale
+        :param r: target renderer object
+        :param tile_size: tile size in pixels
+        """
+
+        if highlight_mask is None:
+            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
+
+        # Compute the total grid size
+        width_px = self.width * tile_size
+        height_px = self.height * tile_size
+
+        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
+
+        # Render the grid
+        for j in range(0, self.height):
+            for i in range(0, self.width):
+                cell = self.get(i, j)
+
+                agent_here = np.array_equal(agent_pos, (i, j))
+                tile_img = Grid.render_tile(
+                    cell,
+                    agent_dir=agent_dir if agent_here else None,
+                    highlight=highlight_mask[i, j],
+                    tile_size=tile_size,
+                )
+
+                ymin = j * tile_size
+                ymax = (j + 1) * tile_size
+                xmin = i * tile_size
+                xmax = (i + 1) * tile_size
+                img[ymin:ymax, xmin:xmax, :] = tile_img
+
+        return img
+
+    def encode(self, vis_mask=None):
+        """
+        Produce a compact numpy encoding of the grid
+        """
+
+        if vis_mask is None:
+            vis_mask = np.ones((self.width, self.height), dtype=bool)
+
+        array = np.zeros((self.width, self.height, 3), dtype="uint8")
+
+        for i in range(self.width):
+            for j in range(self.height):
+                if vis_mask[i, j]:
+                    v = self.get(i, j)
+
+                    if v is None:
+                        array[i, j, 0] = OBJECT_TO_IDX["empty"]
+                        array[i, j, 1] = 0
+                        array[i, j, 2] = 0
+
+                    else:
+                        array[i, j, :] = v.encode()
+
+        return array
+
+    @staticmethod
+    def decode(array):
+        """
+        Decode an array grid encoding back into a grid
+        """
+
+        width, height, channels = array.shape
+        assert channels == 3
+
+        vis_mask = np.ones(shape=(width, height), dtype=bool)
+
+        grid = Grid(width, height)
+        for i in range(width):
+            for j in range(height):
+                type_idx, color_idx, state = array[i, j]
+                v = WorldObj.decode(type_idx, color_idx, state)
+                grid.set(i, j, v)
+                vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
+
+        return grid, vis_mask
+
+    def process_vis(self, agent_pos):
+        mask = np.zeros(shape=(self.width, self.height), dtype=bool)
+
+        mask[agent_pos[0], agent_pos[1]] = True
+
+        for j in reversed(range(0, self.height)):
+            for i in range(0, self.width - 1):
+                if not mask[i, j]:
+                    continue
+
+                cell = self.get(i, j)
+                if cell and not cell.see_behind():
+                    continue
+
+                mask[i + 1, j] = True
+                if j > 0:
+                    mask[i + 1, j - 1] = True
+                    mask[i, j - 1] = True
+
+            for i in reversed(range(1, self.width)):
+                if not mask[i, j]:
+                    continue
+
+                cell = self.get(i, j)
+                if cell and not cell.see_behind():
+                    continue
+
+                mask[i - 1, j] = True
+                if j > 0:
+                    mask[i - 1, j - 1] = True
+                    mask[i, j - 1] = True
+
+        for j in range(0, self.height):
+            for i in range(0, self.width):
+                if not mask[i, j]:
+                    self.set(i, j, None)
+
+        return mask

+ 195 - 0
minigrid/core/mission.py

@@ -0,0 +1,195 @@
+from typing import Any, Callable, Optional, Union
+
+from gymnasium import spaces
+from gymnasium.utils import seeding
+
+
+def check_if_no_duplicate(duplicate_list: list) -> bool:
+    """Check if given list contains any duplicates"""
+    return len(set(duplicate_list)) == len(duplicate_list)
+
+
+class MissionSpace(spaces.Space[str]):
+    r"""A space representing a mission for the Gym-Minigrid environments.
+    The space allows generating random mission strings constructed with an input placeholder list.
+    Example Usage::
+        >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
+                                                ordered_placeholders=[["green", "blue"]])
+        >>> observation_space.sample()
+            "Get the green ball."
+        >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
+                                                ordered_placeholders=None)
+        >>> observation_space.sample()
+            "Get the ball."
+    """
+
+    def __init__(
+        self,
+        mission_func: Callable[..., str],
+        ordered_placeholders: Optional["list[list[str]]"] = None,
+        seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
+    ):
+        r"""Constructor of :class:`MissionSpace` space.
+
+        Args:
+            mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
+            ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
+            seed: seed: The seed for sampling from the space.
+        """
+        # Check that the ordered placeholders and mission function are well defined.
+        if ordered_placeholders is not None:
+            assert (
+                len(ordered_placeholders) == mission_func.__code__.co_argcount
+            ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
+            for placeholder_list in ordered_placeholders:
+                assert check_if_no_duplicate(
+                    placeholder_list
+                ), "Make sure that the placeholders don't have any duplicate values."
+        else:
+            assert (
+                mission_func.__code__.co_argcount == 0
+            ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
+
+        self.ordered_placeholders = ordered_placeholders
+        self.mission_func = mission_func
+
+        super().__init__(dtype=str, seed=seed)
+
+        # Check that mission_func returns a string
+        sampled_mission = self.sample()
+        assert isinstance(
+            sampled_mission, str
+        ), f"mission_func must return type str not {type(sampled_mission)}"
+
+    def sample(self) -> str:
+        """Sample a random mission string."""
+        if self.ordered_placeholders is not None:
+            placeholders = []
+            for rand_var_list in self.ordered_placeholders:
+                idx = self.np_random.integers(0, len(rand_var_list))
+
+                placeholders.append(rand_var_list[idx])
+
+            return self.mission_func(*placeholders)
+        else:
+            return self.mission_func()
+
+    def contains(self, x: Any) -> bool:
+        """Return boolean specifying if x is a valid member of this space."""
+        # Store a list of all the placeholders from self.ordered_placeholders that appear in x
+        if self.ordered_placeholders is not None:
+            check_placeholder_list = []
+            for placeholder_list in self.ordered_placeholders:
+                for placeholder in placeholder_list:
+                    if placeholder in x:
+                        check_placeholder_list.append(placeholder)
+
+            # Remove duplicates from the list
+            check_placeholder_list = list(set(check_placeholder_list))
+
+            start_id_placeholder = []
+            end_id_placeholder = []
+            # Get the starting and ending id of the identified placeholders with possible duplicates
+            new_check_placeholder_list = []
+            for placeholder in check_placeholder_list:
+                new_start_id_placeholder = [
+                    i for i in range(len(x)) if x.startswith(placeholder, i)
+                ]
+                new_check_placeholder_list += [placeholder] * len(
+                    new_start_id_placeholder
+                )
+                end_id_placeholder += [
+                    start_id + len(placeholder) - 1
+                    for start_id in new_start_id_placeholder
+                ]
+                start_id_placeholder += new_start_id_placeholder
+
+            # Order by starting id the placeholders
+            ordered_placeholder_list = sorted(
+                zip(
+                    start_id_placeholder, end_id_placeholder, new_check_placeholder_list
+                )
+            )
+
+            # Check for repeated placeholders contained in each other
+            remove_placeholder_id = []
+            for i, placeholder_1 in enumerate(ordered_placeholder_list):
+                starting_id = i + 1
+                for j, placeholder_2 in enumerate(
+                    ordered_placeholder_list[starting_id:]
+                ):
+                    # Check if place holder ids overlap and keep the longest
+                    if max(placeholder_1[0], placeholder_2[0]) < min(
+                        placeholder_1[1], placeholder_2[1]
+                    ):
+                        remove_placeholder = min(
+                            placeholder_1[2], placeholder_2[2], key=len
+                        )
+                        if remove_placeholder == placeholder_1[2]:
+                            remove_placeholder_id.append(i)
+                        else:
+                            remove_placeholder_id.append(i + j + 1)
+            for id in remove_placeholder_id:
+                del ordered_placeholder_list[id]
+
+            final_placeholders = [
+                placeholder[2] for placeholder in ordered_placeholder_list
+            ]
+
+            # Check that the identified final placeholders are in the same order as the original placeholders.
+            for orered_placeholder, final_placeholder in zip(
+                self.ordered_placeholders, final_placeholders
+            ):
+                if final_placeholder in orered_placeholder:
+                    continue
+                else:
+                    return False
+            try:
+                mission_string_with_placeholders = self.mission_func(
+                    *final_placeholders
+                )
+            except Exception as e:
+                print(
+                    f"{x} is not contained in MissionSpace due to the following exception: {e}"
+                )
+                return False
+
+            return bool(mission_string_with_placeholders == x)
+
+        else:
+            return bool(self.mission_func() == x)
+
+    def __repr__(self) -> str:
+        """Gives a string representation of this space."""
+        return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
+
+    def __eq__(self, other) -> bool:
+        """Check whether ``other`` is equivalent to this instance."""
+        if isinstance(other, MissionSpace):
+
+            # Check that place holder lists are the same
+            if self.ordered_placeholders is not None:
+                # Check length
+                if (len(self.order_placeholder) == len(other.order_placeholder)) and (
+                    all(
+                        set(i) == set(j)
+                        for i, j in zip(self.order_placeholder, other.order_placeholder)
+                    )
+                ):
+                    # Check mission string is the same with dummy space placeholders
+                    test_placeholders = [""] * len(self.order_placeholder)
+                    mission = self.mission_func(*test_placeholders)
+                    other_mission = other.mission_func(*test_placeholders)
+                    return mission == other_mission
+            else:
+
+                # Check that other is also None
+                if other.ordered_placeholders is None:
+
+                    # Check mission string is the same
+                    mission = self.mission_func()
+                    other_mission = other.mission_func()
+                    return mission == other_mission
+
+        # If none of the statements above return then False
+        return False

+ 4 - 1
minigrid/roomgrid.py

@@ -1,6 +1,9 @@
 import numpy as np
 import numpy as np
 
 
-from minigrid.minigrid import COLOR_NAMES, Ball, Box, Door, Grid, Key, MiniGridEnv
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.grid import Grid
+from minigrid.core.world_object import Ball, Box, Door, Key
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 def reject_next_to(env, pos):
 def reject_next_to(env, pos):

+ 284 - 0
minigrid/core/world_object.py

@@ -0,0 +1,284 @@
+import numpy as np
+
+from minigrid.core.constants import (
+    COLOR_TO_IDX,
+    COLORS,
+    IDX_TO_COLOR,
+    IDX_TO_OBJECT,
+    OBJECT_TO_IDX,
+)
+from minigrid.utils.rendering import (
+    fill_coords,
+    point_in_circle,
+    point_in_line,
+    point_in_rect,
+)
+
+
+class WorldObj:
+    """
+    Base class for grid world objects
+    """
+
+    def __init__(self, type, color):
+        assert type in OBJECT_TO_IDX, type
+        assert color in COLOR_TO_IDX, color
+        self.type = type
+        self.color = color
+        self.contains = None
+
+        # Initial position of the object
+        self.init_pos = None
+
+        # Current position of the object
+        self.cur_pos = None
+
+    def can_overlap(self):
+        """Can the agent overlap with this?"""
+        return False
+
+    def can_pickup(self):
+        """Can the agent pick this up?"""
+        return False
+
+    def can_contain(self):
+        """Can this contain another object?"""
+        return False
+
+    def see_behind(self):
+        """Can the agent see behind this object?"""
+        return True
+
+    def toggle(self, env, pos):
+        """Method to trigger/toggle an action this object performs"""
+        return False
+
+    def encode(self):
+        """Encode the a description of this object as a 3-tuple of integers"""
+        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
+
+    @staticmethod
+    def decode(type_idx, color_idx, state):
+        """Create an object from a 3-tuple state description"""
+
+        obj_type = IDX_TO_OBJECT[type_idx]
+        color = IDX_TO_COLOR[color_idx]
+
+        if obj_type == "empty" or obj_type == "unseen":
+            return None
+
+        # State, 0: open, 1: closed, 2: locked
+        is_open = state == 0
+        is_locked = state == 2
+
+        if obj_type == "wall":
+            v = Wall(color)
+        elif obj_type == "floor":
+            v = Floor(color)
+        elif obj_type == "ball":
+            v = Ball(color)
+        elif obj_type == "key":
+            v = Key(color)
+        elif obj_type == "box":
+            v = Box(color)
+        elif obj_type == "door":
+            v = Door(color, is_open, is_locked)
+        elif obj_type == "goal":
+            v = Goal()
+        elif obj_type == "lava":
+            v = Lava()
+        else:
+            assert False, "unknown object type in decode '%s'" % obj_type
+
+        return v
+
+    def render(self, r):
+        """Draw this object with the given renderer"""
+        raise NotImplementedError
+
+
+class Goal(WorldObj):
+    def __init__(self):
+        super().__init__("goal", "green")
+
+    def can_overlap(self):
+        return True
+
+    def render(self, img):
+        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
+
+
+class Floor(WorldObj):
+    """
+    Colored floor tile the agent can walk over
+    """
+
+    def __init__(self, color="blue"):
+        super().__init__("floor", color)
+
+    def can_overlap(self):
+        return True
+
+    def render(self, img):
+        # Give the floor a pale color
+        color = COLORS[self.color] / 2
+        fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
+
+
+class Lava(WorldObj):
+    def __init__(self):
+        super().__init__("lava", "red")
+
+    def can_overlap(self):
+        return True
+
+    def render(self, img):
+        c = (255, 128, 0)
+
+        # Background color
+        fill_coords(img, point_in_rect(0, 1, 0, 1), c)
+
+        # Little waves
+        for i in range(3):
+            ylo = 0.3 + 0.2 * i
+            yhi = 0.4 + 0.2 * i
+            fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
+            fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
+            fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
+            fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
+
+
+class Wall(WorldObj):
+    def __init__(self, color="grey"):
+        super().__init__("wall", color)
+
+    def see_behind(self):
+        return False
+
+    def render(self, img):
+        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
+
+
+class Door(WorldObj):
+    def __init__(self, color, is_open=False, is_locked=False):
+        super().__init__("door", color)
+        self.is_open = is_open
+        self.is_locked = is_locked
+
+    def can_overlap(self):
+        """The agent can only walk over this cell when the door is open"""
+        return self.is_open
+
+    def see_behind(self):
+        return self.is_open
+
+    def toggle(self, env, pos):
+        # If the player has the right key to open the door
+        if self.is_locked:
+            if isinstance(env.carrying, Key) and env.carrying.color == self.color:
+                self.is_locked = False
+                self.is_open = True
+                return True
+            return False
+
+        self.is_open = not self.is_open
+        return True
+
+    def encode(self):
+        """Encode the a description of this object as a 3-tuple of integers"""
+
+        # State, 0: open, 1: closed, 2: locked
+        if self.is_open:
+            state = 0
+        elif self.is_locked:
+            state = 2
+        # if door is closed and unlocked
+        elif not self.is_open:
+            state = 1
+        else:
+            raise ValueError(
+                f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
+            )
+
+        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
+
+    def render(self, img):
+        c = COLORS[self.color]
+
+        if self.is_open:
+            fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
+            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
+            return
+
+        # Door frame and door
+        if self.is_locked:
+            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
+            fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
+
+            # Draw key slot
+            fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
+        else:
+            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
+            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
+            fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
+            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
+
+            # Draw door handle
+            fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
+
+
+class Key(WorldObj):
+    def __init__(self, color="blue"):
+        super().__init__("key", color)
+
+    def can_pickup(self):
+        return True
+
+    def render(self, img):
+        c = COLORS[self.color]
+
+        # Vertical quad
+        fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
+
+        # Teeth
+        fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
+        fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
+
+        # Ring
+        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
+        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
+
+
+class Ball(WorldObj):
+    def __init__(self, color="blue"):
+        super().__init__("ball", color)
+
+    def can_pickup(self):
+        return True
+
+    def render(self, img):
+        fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
+
+
+class Box(WorldObj):
+    def __init__(self, color, contains=None):
+        super().__init__("box", color)
+        self.contains = contains
+
+    def can_pickup(self):
+        return True
+
+    def render(self, img):
+        c = COLORS[self.color]
+
+        # Outline
+        fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
+        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
+
+        # Horizontal slit
+        fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
+
+    def toggle(self, env, pos):
+        # Replace the box by its contents
+        env.grid.set(pos[0], pos[1], self.contains)
+        return True

+ 4 - 2
minigrid/envs/blockedunlockpickup.py

@@ -1,5 +1,7 @@
-from minigrid.minigrid import COLOR_NAMES, Ball, MissionSpace
-from minigrid.roomgrid import RoomGrid
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.mission import MissionSpace
+from minigrid.core.roomgrid import RoomGrid
+from minigrid.core.world_object import Ball
 
 
 
 
 class BlockedUnlockPickupEnv(RoomGrid):
 class BlockedUnlockPickupEnv(RoomGrid):

+ 4 - 1
minigrid/envs/crossing.py

@@ -2,7 +2,10 @@ import itertools as itt
 
 
 import numpy as np
 import numpy as np
 
 
-from minigrid.minigrid import Goal, Grid, Lava, MiniGridEnv, MissionSpace
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Goal, Lava
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class CrossingEnv(MiniGridEnv):
 class CrossingEnv(MiniGridEnv):

+ 4 - 1
minigrid/envs/distshift.py

@@ -1,4 +1,7 @@
-from minigrid.minigrid import Goal, Grid, Lava, MiniGridEnv, MissionSpace
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Goal, Lava
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class DistShiftEnv(MiniGridEnv):
 class DistShiftEnv(MiniGridEnv):

+ 4 - 1
minigrid/envs/doorkey.py

@@ -1,4 +1,7 @@
-from minigrid.minigrid import Door, Goal, Grid, Key, MiniGridEnv, MissionSpace
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Door, Goal, Key
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class DoorKeyEnv(MiniGridEnv):
 class DoorKeyEnv(MiniGridEnv):

+ 4 - 1
minigrid/envs/dynamicobstacles.py

@@ -2,7 +2,10 @@ from operator import add
 
 
 from gymnasium.spaces import Discrete
 from gymnasium.spaces import Discrete
 
 
-from minigrid.minigrid import Ball, Goal, Grid, MiniGridEnv, MissionSpace
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Ball, Goal
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class DynamicObstaclesEnv(MiniGridEnv):
 class DynamicObstaclesEnv(MiniGridEnv):

+ 4 - 2
minigrid/envs/empty.py

@@ -1,8 +1,10 @@
-from minigrid.minigrid import Goal, Grid, MiniGridEnv, MissionSpace
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Goal
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class EmptyEnv(MiniGridEnv):
 class EmptyEnv(MiniGridEnv):
-
     """
     """
     ### Description
     ### Description
 
 

+ 5 - 1
minigrid/envs/fetch.py

@@ -1,4 +1,8 @@
-from minigrid.minigrid import COLOR_NAMES, Ball, Grid, Key, MiniGridEnv, MissionSpace
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Ball, Key
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class FetchEnv(MiniGridEnv):
 class FetchEnv(MiniGridEnv):

+ 4 - 1
minigrid/envs/fourrooms.py

@@ -1,4 +1,7 @@
-from minigrid.minigrid import Goal, Grid, MiniGridEnv, MissionSpace
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Goal
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class FourRoomsEnv(MiniGridEnv):
 class FourRoomsEnv(MiniGridEnv):

+ 5 - 2
minigrid/envs/gotodoor.py

@@ -1,8 +1,11 @@
-from minigrid.minigrid import COLOR_NAMES, Door, Grid, MiniGridEnv, MissionSpace
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Door
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class GoToDoorEnv(MiniGridEnv):
 class GoToDoorEnv(MiniGridEnv):
-
     """
     """
     ### Description
     ### Description
 
 

+ 5 - 9
minigrid/envs/gotoobject.py

@@ -1,12 +1,8 @@
-from minigrid.minigrid import (
-    COLOR_NAMES,
-    Ball,
-    Box,
-    Grid,
-    Key,
-    MiniGridEnv,
-    MissionSpace,
-)
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Ball, Box, Key
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class GoToObjectEnv(MiniGridEnv):
 class GoToObjectEnv(MiniGridEnv):

+ 3 - 2
minigrid/envs/keycorridor.py

@@ -1,5 +1,6 @@
-from minigrid.minigrid import COLOR_NAMES, MissionSpace
-from minigrid.roomgrid import RoomGrid
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.mission import MissionSpace
+from minigrid.core.roomgrid import RoomGrid
 
 
 
 
 class KeyCorridorEnv(RoomGrid):
 class KeyCorridorEnv(RoomGrid):

+ 4 - 1
minigrid/envs/lavagap.py

@@ -1,6 +1,9 @@
 import numpy as np
 import numpy as np
 
 
-from minigrid.minigrid import Goal, Grid, Lava, MiniGridEnv, MissionSpace
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Goal, Lava
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class LavaGapEnv(MiniGridEnv):
 class LavaGapEnv(MiniGridEnv):

+ 5 - 10
minigrid/envs/lockedroom.py

@@ -1,13 +1,8 @@
-from minigrid.minigrid import (
-    COLOR_NAMES,
-    Door,
-    Goal,
-    Grid,
-    Key,
-    MiniGridEnv,
-    MissionSpace,
-    Wall,
-)
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Door, Goal, Key, Wall
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class LockedRoom:
 class LockedRoom:

+ 7 - 3
minigrid/envs/memory.py

@@ -1,6 +1,10 @@
 import numpy as np
 import numpy as np
 
 
-from minigrid.minigrid import Ball, Grid, Key, MiniGridEnv, MissionSpace, Wall
+from minigrid.core.actions import Actions
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Ball, Key, Wall
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class MemoryEnv(MiniGridEnv):
 class MemoryEnv(MiniGridEnv):
@@ -137,8 +141,8 @@ class MemoryEnv(MiniGridEnv):
         self.mission = "go to the matching object at the end of the hallway"
         self.mission = "go to the matching object at the end of the hallway"
 
 
     def step(self, action):
     def step(self, action):
-        if action == self.Actions.pickup:
-            action = self.Actions.toggle
+        if action == Actions.pickup:
+            action = Actions.toggle
         obs, reward, terminated, truncated, info = super().step(action)
         obs, reward, terminated, truncated, info = super().step(action)
 
 
         if tuple(self.agent_pos) == self.success_pos:
         if tuple(self.agent_pos) == self.success_pos:

+ 5 - 9
minigrid/envs/multiroom.py

@@ -1,12 +1,8 @@
-from minigrid.minigrid import (
-    COLOR_NAMES,
-    Door,
-    Goal,
-    Grid,
-    MiniGridEnv,
-    MissionSpace,
-    Wall,
-)
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Door, Goal, Wall
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class MultiRoom:
 class MultiRoom:

+ 4 - 2
minigrid/envs/obstructedmaze.py

@@ -1,5 +1,7 @@
-from minigrid.minigrid import COLOR_NAMES, DIR_TO_VEC, Ball, Box, Key, MissionSpace
-from minigrid.roomgrid import RoomGrid
+from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC
+from minigrid.core.mission import MissionSpace
+from minigrid.core.roomgrid import RoomGrid
+from minigrid.core.world_object import Ball, Box, Key
 
 
 
 
 class ObstructedMazeEnv(RoomGrid):
 class ObstructedMazeEnv(RoomGrid):

+ 5 - 10
minigrid/envs/playground.py

@@ -1,13 +1,8 @@
-from minigrid.minigrid import (
-    COLOR_NAMES,
-    Ball,
-    Box,
-    Door,
-    Grid,
-    Key,
-    MiniGridEnv,
-    MissionSpace,
-)
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Ball, Box, Door, Key
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class PlaygroundEnv(MiniGridEnv):
 class PlaygroundEnv(MiniGridEnv):

+ 5 - 9
minigrid/envs/putnear.py

@@ -1,12 +1,8 @@
-from minigrid.minigrid import (
-    COLOR_NAMES,
-    Ball,
-    Box,
-    Grid,
-    Key,
-    MiniGridEnv,
-    MissionSpace,
-)
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Ball, Box, Key
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class PutNearEnv(MiniGridEnv):
 class PutNearEnv(MiniGridEnv):

+ 4 - 1
minigrid/envs/redbluedoors.py

@@ -1,4 +1,7 @@
-from minigrid.minigrid import Door, Grid, MiniGridEnv, MissionSpace
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.world_object import Door
+from minigrid.minigrid import MiniGridEnv
 
 
 
 
 class RedBlueDoorEnv(MiniGridEnv):
 class RedBlueDoorEnv(MiniGridEnv):

+ 2 - 2
minigrid/envs/unlock.py

@@ -1,5 +1,5 @@
-from minigrid.minigrid import MissionSpace
-from minigrid.roomgrid import RoomGrid
+from minigrid.core.mission import MissionSpace
+from minigrid.core.roomgrid import RoomGrid
 
 
 
 
 class UnlockEnv(RoomGrid):
 class UnlockEnv(RoomGrid):

+ 3 - 2
minigrid/envs/unlockpickup.py

@@ -1,5 +1,6 @@
-from minigrid.minigrid import COLOR_NAMES, MissionSpace
-from minigrid.roomgrid import RoomGrid
+from minigrid.core.constants import COLOR_NAMES
+from minigrid.core.mission import MissionSpace
+from minigrid.core.roomgrid import RoomGrid
 
 
 
 
 class UnlockPickupEnv(RoomGrid):
 class UnlockPickupEnv(RoomGrid):

+ 1 - 1
minigrid/examples/manual_control.py

@@ -2,7 +2,7 @@
 
 
 import gymnasium as gym
 import gymnasium as gym
 
 
-from minigrid.window import Window
+from minigrid.utils.window import Window
 from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 
 
 
 

+ 7 - 828
minigrid/minigrid.py

@@ -1,822 +1,17 @@
 import hashlib
 import hashlib
 import math
 import math
 from abc import abstractmethod
 from abc import abstractmethod
-from enum import IntEnum
-from typing import Any, Callable, Optional, Union
+from typing import Optional
 
 
 import gymnasium as gym
 import gymnasium as gym
 import numpy as np
 import numpy as np
 from gymnasium import spaces
 from gymnasium import spaces
-from gymnasium.utils import seeding
-
-# Size in pixels of a tile in the full-scale human view
-from minigrid.rendering import (
-    downsample,
-    fill_coords,
-    highlight_img,
-    point_in_circle,
-    point_in_line,
-    point_in_rect,
-    point_in_triangle,
-    rotate_fn,
-)
-from minigrid.window import Window
-
-TILE_PIXELS = 32
-
-# Map of color names to RGB values
-COLORS = {
-    "red": np.array([255, 0, 0]),
-    "green": np.array([0, 255, 0]),
-    "blue": np.array([0, 0, 255]),
-    "purple": np.array([112, 39, 195]),
-    "yellow": np.array([255, 255, 0]),
-    "grey": np.array([100, 100, 100]),
-}
-
-COLOR_NAMES = sorted(list(COLORS.keys()))
-
-# Used to map colors to integers
-COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
-
-IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
-
-# Map of object type to integers
-OBJECT_TO_IDX = {
-    "unseen": 0,
-    "empty": 1,
-    "wall": 2,
-    "floor": 3,
-    "door": 4,
-    "key": 5,
-    "ball": 6,
-    "box": 7,
-    "goal": 8,
-    "lava": 9,
-    "agent": 10,
-}
-
-IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
-
-# Map of state names to integers
-STATE_TO_IDX = {
-    "open": 0,
-    "closed": 1,
-    "locked": 2,
-}
-
-# Map of agent direction indices to vectors
-DIR_TO_VEC = [
-    # Pointing right (positive X)
-    np.array((1, 0)),
-    # Down (positive Y)
-    np.array((0, 1)),
-    # Pointing left (negative X)
-    np.array((-1, 0)),
-    # Up (negative Y)
-    np.array((0, -1)),
-]
-
-
-def check_if_no_duplicate(duplicate_list: list) -> bool:
-    """Check if given list contains any duplicates"""
-    return len(set(duplicate_list)) == len(duplicate_list)
-
-
-class MissionSpace(spaces.Space[str]):
-    r"""A space representing a mission for the Gym-Minigrid environments.
-    The space allows generating random mission strings constructed with an input placeholder list.
-    Example Usage::
-        >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
-                                                ordered_placeholders=[["green", "blue"]])
-        >>> observation_space.sample()
-            "Get the green ball."
-        >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
-                                                ordered_placeholders=None)
-        >>> observation_space.sample()
-            "Get the ball."
-    """
-
-    def __init__(
-        self,
-        mission_func: Callable[..., str],
-        ordered_placeholders: Optional["list[list[str]]"] = None,
-        seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
-    ):
-        r"""Constructor of :class:`MissionSpace` space.
-
-        Args:
-            mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
-            ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
-            seed: seed: The seed for sampling from the space.
-        """
-        # Check that the ordered placeholders and mission function are well defined.
-        if ordered_placeholders is not None:
-            assert (
-                len(ordered_placeholders) == mission_func.__code__.co_argcount
-            ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
-            for placeholder_list in ordered_placeholders:
-                assert check_if_no_duplicate(
-                    placeholder_list
-                ), "Make sure that the placeholders don't have any duplicate values."
-        else:
-            assert (
-                mission_func.__code__.co_argcount == 0
-            ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
-
-        self.ordered_placeholders = ordered_placeholders
-        self.mission_func = mission_func
-
-        super().__init__(dtype=str, seed=seed)
-
-        # Check that mission_func returns a string
-        sampled_mission = self.sample()
-        assert isinstance(
-            sampled_mission, str
-        ), f"mission_func must return type str not {type(sampled_mission)}"
-
-    def sample(self) -> str:
-        """Sample a random mission string."""
-        if self.ordered_placeholders is not None:
-            placeholders = []
-            for rand_var_list in self.ordered_placeholders:
-                idx = self.np_random.integers(0, len(rand_var_list))
-
-                placeholders.append(rand_var_list[idx])
-
-            return self.mission_func(*placeholders)
-        else:
-            return self.mission_func()
-
-    def contains(self, x: Any) -> bool:
-        """Return boolean specifying if x is a valid member of this space."""
-        # Store a list of all the placeholders from self.ordered_placeholders that appear in x
-        if self.ordered_placeholders is not None:
-            check_placeholder_list = []
-            for placeholder_list in self.ordered_placeholders:
-                for placeholder in placeholder_list:
-                    if placeholder in x:
-                        check_placeholder_list.append(placeholder)
-
-            # Remove duplicates from the list
-            check_placeholder_list = list(set(check_placeholder_list))
-
-            start_id_placeholder = []
-            end_id_placeholder = []
-            # Get the starting and ending id of the identified placeholders with possible duplicates
-            new_check_placeholder_list = []
-            for placeholder in check_placeholder_list:
-                new_start_id_placeholder = [
-                    i for i in range(len(x)) if x.startswith(placeholder, i)
-                ]
-                new_check_placeholder_list += [placeholder] * len(
-                    new_start_id_placeholder
-                )
-                end_id_placeholder += [
-                    start_id + len(placeholder) - 1
-                    for start_id in new_start_id_placeholder
-                ]
-                start_id_placeholder += new_start_id_placeholder
-
-            # Order by starting id the placeholders
-            ordered_placeholder_list = sorted(
-                zip(
-                    start_id_placeholder, end_id_placeholder, new_check_placeholder_list
-                )
-            )
-
-            # Check for repeated placeholders contained in each other
-            remove_placeholder_id = []
-            for i, placeholder_1 in enumerate(ordered_placeholder_list):
-                starting_id = i + 1
-                for j, placeholder_2 in enumerate(
-                    ordered_placeholder_list[starting_id:]
-                ):
-                    # Check if place holder ids overlap and keep the longest
-                    if max(placeholder_1[0], placeholder_2[0]) < min(
-                        placeholder_1[1], placeholder_2[1]
-                    ):
-                        remove_placeholder = min(
-                            placeholder_1[2], placeholder_2[2], key=len
-                        )
-                        if remove_placeholder == placeholder_1[2]:
-                            remove_placeholder_id.append(i)
-                        else:
-                            remove_placeholder_id.append(i + j + 1)
-            for id in remove_placeholder_id:
-                del ordered_placeholder_list[id]
-
-            final_placeholders = [
-                placeholder[2] for placeholder in ordered_placeholder_list
-            ]
-
-            # Check that the identified final placeholders are in the same order as the original placeholders.
-            for orered_placeholder, final_placeholder in zip(
-                self.ordered_placeholders, final_placeholders
-            ):
-                if final_placeholder in orered_placeholder:
-                    continue
-                else:
-                    return False
-            try:
-                mission_string_with_placeholders = self.mission_func(
-                    *final_placeholders
-                )
-            except Exception as e:
-                print(
-                    f"{x} is not contained in MissionSpace due to the following exception: {e}"
-                )
-                return False
-
-            return bool(mission_string_with_placeholders == x)
-
-        else:
-            return bool(self.mission_func() == x)
-
-    def __repr__(self) -> str:
-        """Gives a string representation of this space."""
-        return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
-
-    def __eq__(self, other) -> bool:
-        """Check whether ``other`` is equivalent to this instance."""
-        if isinstance(other, MissionSpace):
-
-            # Check that place holder lists are the same
-            if self.ordered_placeholders is not None:
-                # Check length
-                if (len(self.order_placeholder) == len(other.order_placeholder)) and (
-                    all(
-                        set(i) == set(j)
-                        for i, j in zip(self.order_placeholder, other.order_placeholder)
-                    )
-                ):
-                    # Check mission string is the same with dummy space placeholders
-                    test_placeholders = [""] * len(self.order_placeholder)
-                    mission = self.mission_func(*test_placeholders)
-                    other_mission = other.mission_func(*test_placeholders)
-                    return mission == other_mission
-            else:
-
-                # Check that other is also None
-                if other.ordered_placeholders is None:
-
-                    # Check mission string is the same
-                    mission = self.mission_func()
-                    other_mission = other.mission_func()
-                    return mission == other_mission
-
-        # If none of the statements above return then False
-        return False
-
-
-class WorldObj:
-    """
-    Base class for grid world objects
-    """
-
-    def __init__(self, type, color):
-        assert type in OBJECT_TO_IDX, type
-        assert color in COLOR_TO_IDX, color
-        self.type = type
-        self.color = color
-        self.contains = None
-
-        # Initial position of the object
-        self.init_pos = None
-
-        # Current position of the object
-        self.cur_pos = None
-
-    def can_overlap(self):
-        """Can the agent overlap with this?"""
-        return False
-
-    def can_pickup(self):
-        """Can the agent pick this up?"""
-        return False
-
-    def can_contain(self):
-        """Can this contain another object?"""
-        return False
-
-    def see_behind(self):
-        """Can the agent see behind this object?"""
-        return True
-
-    def toggle(self, env, pos):
-        """Method to trigger/toggle an action this object performs"""
-        return False
-
-    def encode(self):
-        """Encode the a description of this object as a 3-tuple of integers"""
-        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
-
-    @staticmethod
-    def decode(type_idx, color_idx, state):
-        """Create an object from a 3-tuple state description"""
-
-        obj_type = IDX_TO_OBJECT[type_idx]
-        color = IDX_TO_COLOR[color_idx]
-
-        if obj_type == "empty" or obj_type == "unseen":
-            return None
-
-        # State, 0: open, 1: closed, 2: locked
-        is_open = state == 0
-        is_locked = state == 2
-
-        if obj_type == "wall":
-            v = Wall(color)
-        elif obj_type == "floor":
-            v = Floor(color)
-        elif obj_type == "ball":
-            v = Ball(color)
-        elif obj_type == "key":
-            v = Key(color)
-        elif obj_type == "box":
-            v = Box(color)
-        elif obj_type == "door":
-            v = Door(color, is_open, is_locked)
-        elif obj_type == "goal":
-            v = Goal()
-        elif obj_type == "lava":
-            v = Lava()
-        else:
-            assert False, "unknown object type in decode '%s'" % obj_type
-
-        return v
-
-    def render(self, r):
-        """Draw this object with the given renderer"""
-        raise NotImplementedError
-
-
-class Goal(WorldObj):
-    def __init__(self):
-        super().__init__("goal", "green")
-
-    def can_overlap(self):
-        return True
-
-    def render(self, img):
-        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
-
-
-class Floor(WorldObj):
-    """
-    Colored floor tile the agent can walk over
-    """
-
-    def __init__(self, color="blue"):
-        super().__init__("floor", color)
-
-    def can_overlap(self):
-        return True
-
-    def render(self, img):
-        # Give the floor a pale color
-        color = COLORS[self.color] / 2
-        fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
-
-
-class Lava(WorldObj):
-    def __init__(self):
-        super().__init__("lava", "red")
-
-    def can_overlap(self):
-        return True
-
-    def render(self, img):
-        c = (255, 128, 0)
-
-        # Background color
-        fill_coords(img, point_in_rect(0, 1, 0, 1), c)
-
-        # Little waves
-        for i in range(3):
-            ylo = 0.3 + 0.2 * i
-            yhi = 0.4 + 0.2 * i
-            fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
-            fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
-            fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
-            fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
-
-
-class Wall(WorldObj):
-    def __init__(self, color="grey"):
-        super().__init__("wall", color)
-
-    def see_behind(self):
-        return False
-
-    def render(self, img):
-        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
-
-
-class Door(WorldObj):
-    def __init__(self, color, is_open=False, is_locked=False):
-        super().__init__("door", color)
-        self.is_open = is_open
-        self.is_locked = is_locked
-
-    def can_overlap(self):
-        """The agent can only walk over this cell when the door is open"""
-        return self.is_open
-
-    def see_behind(self):
-        return self.is_open
-
-    def toggle(self, env, pos):
-        # If the player has the right key to open the door
-        if self.is_locked:
-            if isinstance(env.carrying, Key) and env.carrying.color == self.color:
-                self.is_locked = False
-                self.is_open = True
-                return True
-            return False
-
-        self.is_open = not self.is_open
-        return True
-
-    def encode(self):
-        """Encode the a description of this object as a 3-tuple of integers"""
-
-        # State, 0: open, 1: closed, 2: locked
-        if self.is_open:
-            state = 0
-        elif self.is_locked:
-            state = 2
-        # if door is closed and unlocked
-        elif not self.is_open:
-            state = 1
-        else:
-            raise ValueError(
-                f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
-            )
-
-        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
-
-    def render(self, img):
-        c = COLORS[self.color]
-
-        if self.is_open:
-            fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
-            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
-            return
-
-        # Door frame and door
-        if self.is_locked:
-            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
-            fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
-
-            # Draw key slot
-            fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
-        else:
-            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
-            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
-            fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
-            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
-
-            # Draw door handle
-            fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
-
-
-class Key(WorldObj):
-    def __init__(self, color="blue"):
-        super().__init__("key", color)
-
-    def can_pickup(self):
-        return True
-
-    def render(self, img):
-        c = COLORS[self.color]
-
-        # Vertical quad
-        fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
-
-        # Teeth
-        fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
-        fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
-
-        # Ring
-        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
-        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
-
-
-class Ball(WorldObj):
-    def __init__(self, color="blue"):
-        super().__init__("ball", color)
-
-    def can_pickup(self):
-        return True
-
-    def render(self, img):
-        fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
-
-
-class Box(WorldObj):
-    def __init__(self, color, contains=None):
-        super().__init__("box", color)
-        self.contains = contains
 
 
-    def can_pickup(self):
-        return True
-
-    def render(self, img):
-        c = COLORS[self.color]
-
-        # Outline
-        fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
-        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
-
-        # Horizontal slit
-        fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
-
-    def toggle(self, env, pos):
-        # Replace the box by its contents
-        env.grid.set(pos[0], pos[1], self.contains)
-        return True
-
-
-class Grid:
-    """
-    Represent a grid and operations on it
-    """
-
-    # Static cache of pre-renderer tiles
-    tile_cache = {}
-
-    def __init__(self, width, height):
-        assert width >= 3
-        assert height >= 3
-
-        self.width = width
-        self.height = height
-
-        self.grid = [None] * width * height
-
-    def __contains__(self, key):
-        if isinstance(key, WorldObj):
-            for e in self.grid:
-                if e is key:
-                    return True
-        elif isinstance(key, tuple):
-            for e in self.grid:
-                if e is None:
-                    continue
-                if (e.color, e.type) == key:
-                    return True
-                if key[0] is None and key[1] == e.type:
-                    return True
-        return False
-
-    def __eq__(self, other):
-        grid1 = self.encode()
-        grid2 = other.encode()
-        return np.array_equal(grid2, grid1)
-
-    def __ne__(self, other):
-        return not self == other
-
-    def copy(self):
-        from copy import deepcopy
-
-        return deepcopy(self)
-
-    def set(self, i, j, v):
-        assert i >= 0 and i < self.width
-        assert j >= 0 and j < self.height
-        self.grid[j * self.width + i] = v
-
-    def get(self, i, j):
-        assert i >= 0 and i < self.width
-        assert j >= 0 and j < self.height
-        return self.grid[j * self.width + i]
-
-    def horz_wall(self, x, y, length=None, obj_type=Wall):
-        if length is None:
-            length = self.width - x
-        for i in range(0, length):
-            self.set(x + i, y, obj_type())
-
-    def vert_wall(self, x, y, length=None, obj_type=Wall):
-        if length is None:
-            length = self.height - y
-        for j in range(0, length):
-            self.set(x, y + j, obj_type())
-
-    def wall_rect(self, x, y, w, h):
-        self.horz_wall(x, y, w)
-        self.horz_wall(x, y + h - 1, w)
-        self.vert_wall(x, y, h)
-        self.vert_wall(x + w - 1, y, h)
-
-    def rotate_left(self):
-        """
-        Rotate the grid to the left (counter-clockwise)
-        """
-
-        grid = Grid(self.height, self.width)
-
-        for i in range(self.width):
-            for j in range(self.height):
-                v = self.get(i, j)
-                grid.set(j, grid.height - 1 - i, v)
-
-        return grid
-
-    def slice(self, topX, topY, width, height):
-        """
-        Get a subset of the grid
-        """
-
-        grid = Grid(width, height)
-
-        for j in range(0, height):
-            for i in range(0, width):
-                x = topX + i
-                y = topY + j
-
-                if x >= 0 and x < self.width and y >= 0 and y < self.height:
-                    v = self.get(x, y)
-                else:
-                    v = Wall()
-
-                grid.set(i, j, v)
-
-        return grid
-
-    @classmethod
-    def render_tile(
-        cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
-    ):
-        """
-        Render a tile and cache the result
-        """
-
-        # Hash map lookup key for the cache
-        key = (agent_dir, highlight, tile_size)
-        key = obj.encode() + key if obj else key
-
-        if key in cls.tile_cache:
-            return cls.tile_cache[key]
-
-        img = np.zeros(
-            shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
-        )
-
-        # Draw the grid lines (top and left edges)
-        fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
-        fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
-
-        if obj is not None:
-            obj.render(img)
-
-        # Overlay the agent on top
-        if agent_dir is not None:
-            tri_fn = point_in_triangle(
-                (0.12, 0.19),
-                (0.87, 0.50),
-                (0.12, 0.81),
-            )
-
-            # Rotate the agent based on its direction
-            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
-            fill_coords(img, tri_fn, (255, 0, 0))
-
-        # Highlight the cell if needed
-        if highlight:
-            highlight_img(img)
-
-        # Downsample the image to perform supersampling/anti-aliasing
-        img = downsample(img, subdivs)
-
-        # Cache the rendered tile
-        cls.tile_cache[key] = img
-
-        return img
-
-    def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
-        """
-        Render this grid at a given scale
-        :param r: target renderer object
-        :param tile_size: tile size in pixels
-        """
-
-        if highlight_mask is None:
-            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
-
-        # Compute the total grid size
-        width_px = self.width * tile_size
-        height_px = self.height * tile_size
-
-        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
-
-        # Render the grid
-        for j in range(0, self.height):
-            for i in range(0, self.width):
-                cell = self.get(i, j)
-
-                agent_here = np.array_equal(agent_pos, (i, j))
-                tile_img = Grid.render_tile(
-                    cell,
-                    agent_dir=agent_dir if agent_here else None,
-                    highlight=highlight_mask[i, j],
-                    tile_size=tile_size,
-                )
-
-                ymin = j * tile_size
-                ymax = (j + 1) * tile_size
-                xmin = i * tile_size
-                xmax = (i + 1) * tile_size
-                img[ymin:ymax, xmin:xmax, :] = tile_img
-
-        return img
-
-    def encode(self, vis_mask=None):
-        """
-        Produce a compact numpy encoding of the grid
-        """
-
-        if vis_mask is None:
-            vis_mask = np.ones((self.width, self.height), dtype=bool)
-
-        array = np.zeros((self.width, self.height, 3), dtype="uint8")
-
-        for i in range(self.width):
-            for j in range(self.height):
-                if vis_mask[i, j]:
-                    v = self.get(i, j)
-
-                    if v is None:
-                        array[i, j, 0] = OBJECT_TO_IDX["empty"]
-                        array[i, j, 1] = 0
-                        array[i, j, 2] = 0
-
-                    else:
-                        array[i, j, :] = v.encode()
-
-        return array
-
-    @staticmethod
-    def decode(array):
-        """
-        Decode an array grid encoding back into a grid
-        """
-
-        width, height, channels = array.shape
-        assert channels == 3
-
-        vis_mask = np.ones(shape=(width, height), dtype=bool)
-
-        grid = Grid(width, height)
-        for i in range(width):
-            for j in range(height):
-                type_idx, color_idx, state = array[i, j]
-                v = WorldObj.decode(type_idx, color_idx, state)
-                grid.set(i, j, v)
-                vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
-
-        return grid, vis_mask
-
-    def process_vis(self, agent_pos):
-        mask = np.zeros(shape=(self.width, self.height), dtype=bool)
-
-        mask[agent_pos[0], agent_pos[1]] = True
-
-        for j in reversed(range(0, self.height)):
-            for i in range(0, self.width - 1):
-                if not mask[i, j]:
-                    continue
-
-                cell = self.get(i, j)
-                if cell and not cell.see_behind():
-                    continue
-
-                mask[i + 1, j] = True
-                if j > 0:
-                    mask[i + 1, j - 1] = True
-                    mask[i, j - 1] = True
-
-            for i in reversed(range(1, self.width)):
-                if not mask[i, j]:
-                    continue
-
-                cell = self.get(i, j)
-                if cell and not cell.see_behind():
-                    continue
-
-                mask[i - 1, j] = True
-                if j > 0:
-                    mask[i - 1, j - 1] = True
-                    mask[i, j - 1] = True
-
-        for j in range(0, self.height):
-            for i in range(0, self.width):
-                if not mask[i, j]:
-                    self.set(i, j, None)
-
-        return mask
+from minigrid.core.actions import Actions
+from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.utils.window import Window
 
 
 
 
 class MiniGridEnv(gym.Env):
 class MiniGridEnv(gym.Env):
@@ -829,22 +24,6 @@ class MiniGridEnv(gym.Env):
         "render_fps": 10,
         "render_fps": 10,
     }
     }
 
 
-    # Enumeration of possible actions
-    class Actions(IntEnum):
-        # Turn left, turn right, move forward
-        left = 0
-        right = 1
-        forward = 2
-        # Pick up an object
-        pickup = 3
-        # Drop an object
-        drop = 4
-        # Toggle/activate an object
-        toggle = 5
-
-        # Done completing task
-        done = 6
-
     def __init__(
     def __init__(
         self,
         self,
         mission_space: MissionSpace,
         mission_space: MissionSpace,
@@ -869,7 +48,7 @@ class MiniGridEnv(gym.Env):
             height = grid_size
             height = grid_size
 
 
         # Action enumeration for this environment
         # Action enumeration for this environment
-        self.actions = MiniGridEnv.Actions
+        self.actions = Actions
 
 
         # Actions are discrete integer values
         # Actions are discrete integer values
         self.action_space = spaces.Discrete(len(self.actions))
         self.action_space = spaces.Discrete(len(self.actions))

+ 0 - 0
minigrid/utils/__init__.py


minigrid/rendering.py → minigrid/utils/rendering.py


minigrid/window.py → minigrid/utils/window.py


+ 2 - 1
minigrid/wrappers.py

@@ -7,7 +7,8 @@ import numpy as np
 from gymnasium import spaces
 from gymnasium import spaces
 from gymnasium.core import ObservationWrapper, Wrapper
 from gymnasium.core import ObservationWrapper, Wrapper
 
 
-from minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
+from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
+from minigrid.core.world_object import Goal
 
 
 
 
 class ReseedWrapper(Wrapper):
 class ReseedWrapper(Wrapper):

+ 1 - 0
pyproject.toml

@@ -31,6 +31,7 @@ reportUntypedFunctionDecorator = "none"
 reportMissingTypeStubs = false
 reportMissingTypeStubs = false
 reportUnboundVariable = "warning"
 reportUnboundVariable = "warning"
 reportGeneralTypeIssues ="none"
 reportGeneralTypeIssues ="none"
+reportPrivateImportUsage = "none"
 
 
 [tool.pytest.ini_options]
 [tool.pytest.ini_options]
 filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # TODO: to be removed when old step API is removed
 filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # TODO: to be removed when old step API is removed

+ 1 - 1
setup.py

@@ -26,7 +26,7 @@ setup(
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.10",
     ],
     ],
-    version="1.2.1",
+    version="2.0.0",
     keywords="memory, environment, agent, rl, gymnasium",
     keywords="memory, environment, agent, rl, gymnasium",
     url="https://github.com/Farama-Foundation/MiniGrid",
     url="https://github.com/Farama-Foundation/MiniGrid",
     description="Minimalistic gridworld reinforcement learning environments",
     description="Minimalistic gridworld reinforcement learning environments",

+ 2 - 1
tests/test_envs.py

@@ -6,7 +6,8 @@ import pytest
 from gymnasium.envs.registration import EnvSpec
 from gymnasium.envs.registration import EnvSpec
 from gymnasium.utils.env_checker import check_env
 from gymnasium.utils.env_checker import check_env
 
 
-from minigrid.minigrid import Grid, MissionSpace
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
 from tests.utils import all_testing_env_specs, assert_equals
 from tests.utils import all_testing_env_specs, assert_equals
 
 
 CHECK_ENV_IGNORE_WARNINGS = [
 CHECK_ENV_IGNORE_WARNINGS = [

+ 1 - 1
tests/test_scripts.py

@@ -3,7 +3,7 @@ import numpy as np
 
 
 from minigrid.examples.benchmark import benchmark
 from minigrid.examples.benchmark import benchmark
 from minigrid.examples.manual_control import key_handler, reset
 from minigrid.examples.manual_control import key_handler, reset
-from minigrid.window import Window
+from minigrid.utils.window import Window
 
 
 
 
 def test_benchmark():
 def test_benchmark():

+ 5 - 5
tests/test_wrappers.py

@@ -4,8 +4,8 @@ import gymnasium as gym
 import numpy as np
 import numpy as np
 import pytest
 import pytest
 
 
+from minigrid.core.actions import Actions
 from minigrid.envs import EmptyEnv
 from minigrid.envs import EmptyEnv
-from minigrid.minigrid import MiniGridEnv
 from minigrid.wrappers import (
 from minigrid.wrappers import (
     ActionBonus,
     ActionBonus,
     DictObservationSpaceWrapper,
     DictObservationSpaceWrapper,
@@ -79,9 +79,9 @@ def test_state_bonus_wrapper(env_id):
     env = gym.make(env_id)
     env = gym.make(env_id)
     wrapped_env = StateBonus(gym.make(env_id))
     wrapped_env = StateBonus(gym.make(env_id))
 
 
-    action_forward = MiniGridEnv.Actions.forward
-    action_left = MiniGridEnv.Actions.left
-    action_right = MiniGridEnv.Actions.right
+    action_forward = Actions.forward
+    action_left = Actions.left
+    action_right = Actions.right
 
 
     for _ in range(10):
     for _ in range(10):
         wrapped_env.reset()
         wrapped_env.reset()
@@ -109,7 +109,7 @@ def test_action_bonus_wrapper(env_id):
     env = gym.make(env_id)
     env = gym.make(env_id)
     wrapped_env = ActionBonus(gym.make(env_id))
     wrapped_env = ActionBonus(gym.make(env_id))
 
 
-    action = MiniGridEnv.Actions.forward
+    action = Actions.forward
 
 
     for _ in range(10):
     for _ in range(10):
         wrapped_env.reset()
         wrapped_env.reset()