Bläddra i källkod

Rename minigrid.py to minigrid_env.py (#249)

Rodrigo de Lazcano 2 år sedan
förälder
incheckning
56f6a57454

+ 2 - 2
README.md

@@ -94,13 +94,13 @@ pip install -e .
 There is a UI application which allows you to manually control the agent with the arrow keys:
 There is a UI application which allows you to manually control the agent with the arrow keys:
 
 
 ```
 ```
-./minigrid/examples/manual_control.py
+./minigrid/manual_control.py
 ```
 ```
 
 
 The environment being run can be selected with the `--env` option, eg:
 The environment being run can be selected with the `--env` option, eg:
 
 
 ```
 ```
-./minigrid/examples/manual_control.py --env MiniGrid-Empty-8x8-v0
+./minigrid/manual_control.py --env MiniGrid-Empty-8x8-v0
 ```
 ```
 
 
 ## Reinforcement Learning
 ## Reinforcement Learning

+ 1 - 1
minigrid/__init__.py

@@ -1,6 +1,6 @@
 from gymnasium.envs.registration import register
 from gymnasium.envs.registration import register
 
 
-from minigrid import minigrid, wrappers
+from minigrid import minigrid_env, wrappers
 from minigrid.core import roomgrid
 from minigrid.core import roomgrid
 from minigrid.core.world_object import Wall
 from minigrid.core.world_object import Wall
 
 

minigrid/examples/benchmark.py → minigrid/benchmark.py


+ 1 - 1
minigrid/core/roomgrid.py

@@ -3,7 +3,7 @@ import numpy as np
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.world_object import Ball, Box, Door, Key
 from minigrid.core.world_object import Ball, Box, Door, Key
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 def reject_next_to(env, pos):
 def reject_next_to(env, pos):

+ 1 - 1
minigrid/envs/crossing.py

@@ -5,7 +5,7 @@ import numpy as np
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Goal, Lava
 from minigrid.core.world_object import Goal, Lava
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class CrossingEnv(MiniGridEnv):
 class CrossingEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/distshift.py

@@ -1,7 +1,7 @@
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Goal, Lava
 from minigrid.core.world_object import Goal, Lava
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class DistShiftEnv(MiniGridEnv):
 class DistShiftEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/doorkey.py

@@ -1,7 +1,7 @@
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Door, Goal, Key
 from minigrid.core.world_object import Door, Goal, Key
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class DoorKeyEnv(MiniGridEnv):
 class DoorKeyEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/dynamicobstacles.py

@@ -5,7 +5,7 @@ from gymnasium.spaces import Discrete
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Ball, Goal
 from minigrid.core.world_object import Ball, Goal
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class DynamicObstaclesEnv(MiniGridEnv):
 class DynamicObstaclesEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/empty.py

@@ -1,7 +1,7 @@
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Goal
 from minigrid.core.world_object import Goal
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class EmptyEnv(MiniGridEnv):
 class EmptyEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/fetch.py

@@ -2,7 +2,7 @@ from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Ball, Key
 from minigrid.core.world_object import Ball, Key
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class FetchEnv(MiniGridEnv):
 class FetchEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/fourrooms.py

@@ -1,7 +1,7 @@
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Goal
 from minigrid.core.world_object import Goal
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class FourRoomsEnv(MiniGridEnv):
 class FourRoomsEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/gotodoor.py

@@ -2,7 +2,7 @@ from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Door
 from minigrid.core.world_object import Door
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class GoToDoorEnv(MiniGridEnv):
 class GoToDoorEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/gotoobject.py

@@ -2,7 +2,7 @@ from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Ball, Box, Key
 from minigrid.core.world_object import Ball, Box, Key
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class GoToObjectEnv(MiniGridEnv):
 class GoToObjectEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/lavagap.py

@@ -3,7 +3,7 @@ import numpy as np
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Goal, Lava
 from minigrid.core.world_object import Goal, Lava
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class LavaGapEnv(MiniGridEnv):
 class LavaGapEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/lockedroom.py

@@ -2,7 +2,7 @@ from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Door, Goal, Key, Wall
 from minigrid.core.world_object import Door, Goal, Key, Wall
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class LockedRoom:
 class LockedRoom:

+ 1 - 1
minigrid/envs/memory.py

@@ -4,7 +4,7 @@ from minigrid.core.actions import Actions
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Ball, Key, Wall
 from minigrid.core.world_object import Ball, Key, Wall
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class MemoryEnv(MiniGridEnv):
 class MemoryEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/multiroom.py

@@ -2,7 +2,7 @@ from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Door, Goal, Wall
 from minigrid.core.world_object import Door, Goal, Wall
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class MultiRoom:
 class MultiRoom:

+ 1 - 1
minigrid/envs/playground.py

@@ -2,7 +2,7 @@ from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Ball, Box, Door, Key
 from minigrid.core.world_object import Ball, Box, Door, Key
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class PlaygroundEnv(MiniGridEnv):
 class PlaygroundEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/putnear.py

@@ -2,7 +2,7 @@ from minigrid.core.constants import COLOR_NAMES
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Ball, Box, Key
 from minigrid.core.world_object import Ball, Box, Key
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class PutNearEnv(MiniGridEnv):
 class PutNearEnv(MiniGridEnv):

+ 1 - 1
minigrid/envs/redbluedoors.py

@@ -1,7 +1,7 @@
 from minigrid.core.grid import Grid
 from minigrid.core.grid import Grid
 from minigrid.core.mission import MissionSpace
 from minigrid.core.mission import MissionSpace
 from minigrid.core.world_object import Door
 from minigrid.core.world_object import Door
-from minigrid.minigrid import MiniGridEnv
+from minigrid.minigrid_env import MiniGridEnv
 
 
 
 
 class RedBlueDoorEnv(MiniGridEnv):
 class RedBlueDoorEnv(MiniGridEnv):

minigrid/examples/manual_control.py → minigrid/manual_control.py


+ 0 - 716
minigrid/minigrid.py

@@ -1,716 +0,0 @@
-import hashlib
-import math
-from abc import abstractmethod
-from typing import Optional
-
-import gymnasium as gym
-import numpy as np
-from gymnasium import spaces
-
-from minigrid.core.actions import Actions
-from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
-from minigrid.core.grid import Grid
-from minigrid.core.mission import MissionSpace
-from minigrid.utils.window import Window
-
-
-class MiniGridEnv(gym.Env):
-    """
-    2D grid world game environment
-    """
-
-    metadata = {
-        "render_modes": ["human", "rgb_array"],
-        "render_fps": 10,
-    }
-
-    def __init__(
-        self,
-        mission_space: MissionSpace,
-        grid_size: int = None,
-        width: int = None,
-        height: int = None,
-        max_steps: int = 100,
-        see_through_walls: bool = False,
-        agent_view_size: int = 7,
-        render_mode: Optional[str] = None,
-        highlight: bool = True,
-        tile_size: int = TILE_PIXELS,
-        agent_pov: bool = False,
-    ):
-        # Initialize mission
-        self.mission = mission_space.sample()
-
-        # Can't set both grid_size and width/height
-        if grid_size:
-            assert width is None and height is None
-            width = grid_size
-            height = grid_size
-
-        # Action enumeration for this environment
-        self.actions = Actions
-
-        # Actions are discrete integer values
-        self.action_space = spaces.Discrete(len(self.actions))
-
-        # Number of cells (width and height) in the agent view
-        assert agent_view_size % 2 == 1
-        assert agent_view_size >= 3
-        self.agent_view_size = agent_view_size
-
-        # Observations are dictionaries containing an
-        # encoding of the grid and a textual 'mission' string
-        image_observation_space = spaces.Box(
-            low=0,
-            high=255,
-            shape=(self.agent_view_size, self.agent_view_size, 3),
-            dtype="uint8",
-        )
-        self.observation_space = spaces.Dict(
-            {
-                "image": image_observation_space,
-                "direction": spaces.Discrete(4),
-                "mission": mission_space,
-            }
-        )
-
-        # Range of possible rewards
-        self.reward_range = (0, 1)
-
-        self.window: Window = None
-
-        # Environment configuration
-        self.width = width
-        self.height = height
-        self.max_steps = max_steps
-        self.see_through_walls = see_through_walls
-
-        # Current position and direction of the agent
-        self.agent_pos: np.ndarray = None
-        self.agent_dir: int = None
-
-        # Current grid and carried object
-        self.grid = Grid(width, height)
-        self.carrying = None
-
-        # Rendering attributes
-        self.render_mode = render_mode
-        self.highlight = highlight
-        self.tile_size = tile_size
-        self.agent_pov = agent_pov
-
-    def reset(self, *, seed=None, options=None):
-        super().reset(seed=seed)
-
-        # Reinitialize episode-specific variables
-        self.agent_pos = (-1, -1)
-        self.agent_dir = -1
-
-        # Generate a new random grid at the start of each episode
-        self._gen_grid(self.width, self.height)
-
-        # These fields should be defined by _gen_grid
-        assert (
-            self.agent_pos >= (0, 0)
-            if isinstance(self.agent_pos, tuple)
-            else all(self.agent_pos >= 0) and self.agent_dir >= 0
-        )
-
-        # Check that the agent doesn't overlap with an object
-        start_cell = self.grid.get(*self.agent_pos)
-        assert start_cell is None or start_cell.can_overlap()
-
-        # Item picked up, being carried, initially nothing
-        self.carrying = None
-
-        # Step count since episode start
-        self.step_count = 0
-
-        if self.render_mode == "human":
-            self.render()
-
-        # Return first observation
-        obs = self.gen_obs()
-
-        return obs, {}
-
-    def hash(self, size=16):
-        """Compute a hash that uniquely identifies the current state of the environment.
-        :param size: Number of leading hex characters of the SHA-256 digest to return
-        """
-        sample_hash = hashlib.sha256()
-
-        to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
-        for item in to_encode:
-            sample_hash.update(str(item).encode("utf8"))
-
-        return sample_hash.hexdigest()[:size]
-
-    @property
-    def steps_remaining(self):
-        return self.max_steps - self.step_count
-
-    def __str__(self):
-        """
-        Produce a pretty string of the environment's grid along with the agent.
-        A grid cell is represented by 2-character string, the first one for
-        the object and the second one for the color.
-        """
-
-        # Map of object types to short string
-        OBJECT_TO_STR = {
-            "wall": "W",
-            "floor": "F",
-            "door": "D",
-            "key": "K",
-            "ball": "A",
-            "box": "B",
-            "goal": "G",
-            "lava": "V",
-        }
-
-        # Map agent's direction to short string
-        AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
-
-        str = ""
-
-        for j in range(self.grid.height):
-
-            for i in range(self.grid.width):
-                if i == self.agent_pos[0] and j == self.agent_pos[1]:
-                    str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
-                    continue
-
-                c = self.grid.get(i, j)
-
-                if c is None:
-                    str += "  "
-                    continue
-
-                if c.type == "door":
-                    if c.is_open:
-                        str += "__"
-                    elif c.is_locked:
-                        str += "L" + c.color[0].upper()
-                    else:
-                        str += "D" + c.color[0].upper()
-                    continue
-
-                str += OBJECT_TO_STR[c.type] + c.color[0].upper()
-
-            if j < self.grid.height - 1:
-                str += "\n"
-
-        return str
-
-    @abstractmethod
-    def _gen_grid(self, width, height):
-        pass
-
-    def _reward(self):
-        """
-        Compute the reward to be given upon success
-        """
-
-        return 1 - 0.9 * (self.step_count / self.max_steps)
-
-    def _rand_int(self, low, high):
-        """
-        Generate random integer in [low,high[
-        """
-
-        return self.np_random.integers(low, high)
-
-    def _rand_float(self, low, high):
-        """
-        Generate random float in [low,high[
-        """
-
-        return self.np_random.uniform(low, high)
-
-    def _rand_bool(self):
-        """
-        Generate random boolean value
-        """
-
-        return self.np_random.integers(0, 2) == 0
-
-    def _rand_elem(self, iterable):
-        """
-        Pick a random element in a list
-        """
-
-        lst = list(iterable)
-        idx = self._rand_int(0, len(lst))
-        return lst[idx]
-
-    def _rand_subset(self, iterable, num_elems):
-        """
-        Sample a random subset of distinct elements of a list
-        """
-
-        lst = list(iterable)
-        assert num_elems <= len(lst)
-
-        out = []
-
-        while len(out) < num_elems:
-            elem = self._rand_elem(lst)
-            lst.remove(elem)
-            out.append(elem)
-
-        return out
-
-    def _rand_color(self):
-        """
-        Generate a random color name (string)
-        """
-
-        return self._rand_elem(COLOR_NAMES)
-
-    def _rand_pos(self, xLow, xHigh, yLow, yHigh):
-        """
-        Generate a random (x,y) position tuple
-        """
-
-        return (
-            self.np_random.integers(xLow, xHigh),
-            self.np_random.integers(yLow, yHigh),
-        )
-
-    def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
-        """
-        Place an object at an empty position in the grid
-
-        :param top: top-left position of the rectangle where to place
-        :param size: size of the rectangle where to place
-        :param reject_fn: function to filter out potential positions
-        """
-
-        if top is None:
-            top = (0, 0)
-        else:
-            top = (max(top[0], 0), max(top[1], 0))
-
-        if size is None:
-            size = (self.grid.width, self.grid.height)
-
-        num_tries = 0
-
-        while True:
-            # Guard against the rare cases where rejection sampling
-            # would otherwise get stuck in an infinite loop
-            if num_tries > max_tries:
-                raise RecursionError("rejection sampling failed in place_obj")
-
-            num_tries += 1
-
-            pos = np.array(
-                (
-                    self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
-                    self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
-                )
-            )
-
-            pos = tuple(pos)
-
-            # Don't place the object on top of another object
-            if self.grid.get(*pos) is not None:
-                continue
-
-            # Don't place the object where the agent is
-            if np.array_equal(pos, self.agent_pos):
-                continue
-
-            # Check if there is a filtering criterion
-            if reject_fn and reject_fn(self, pos):
-                continue
-
-            break
-
-        self.grid.set(pos[0], pos[1], obj)
-
-        if obj is not None:
-            obj.init_pos = pos
-            obj.cur_pos = pos
-
-        return pos
-
-    def put_obj(self, obj, i, j):
-        """
-        Put an object at a specific position in the grid
-        """
-
-        self.grid.set(i, j, obj)
-        obj.init_pos = (i, j)
-        obj.cur_pos = (i, j)
-
-    def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
-        """
-        Set the agent's starting point at an empty position in the grid
-        """
-
-        self.agent_pos = (-1, -1)
-        pos = self.place_obj(None, top, size, max_tries=max_tries)
-        self.agent_pos = pos
-
-        if rand_dir:
-            self.agent_dir = self._rand_int(0, 4)
-
-        return pos
-
-    @property
-    def dir_vec(self):
-        """
-        Get the direction vector for the agent, pointing in the direction
-        of forward movement.
-        """
-
-        assert self.agent_dir >= 0 and self.agent_dir < 4
-        return DIR_TO_VEC[self.agent_dir]
-
-    @property
-    def right_vec(self):
-        """
-        Get the vector pointing to the right of the agent.
-        """
-
-        dx, dy = self.dir_vec
-        return np.array((-dy, dx))
-
-    @property
-    def front_pos(self):
-        """
-        Get the position of the cell that is right in front of the agent
-        """
-
-        return self.agent_pos + self.dir_vec
-
-    def get_view_coords(self, i, j):
-        """
-        Translate and rotate absolute grid coordinates (i, j) into the
-        agent's partially observable view (sub-grid). Note that the resulting
-        coordinates may be negative or outside of the agent's view size.
-        """
-
-        ax, ay = self.agent_pos
-        dx, dy = self.dir_vec
-        rx, ry = self.right_vec
-
-        # Compute the absolute coordinates of the top-left view corner
-        sz = self.agent_view_size
-        hs = self.agent_view_size // 2
-        tx = ax + (dx * (sz - 1)) - (rx * hs)
-        ty = ay + (dy * (sz - 1)) - (ry * hs)
-
-        lx = i - tx
-        ly = j - ty
-
-        # Project the coordinates of the object relative to the top-left
-        # corner onto the agent's own coordinate system
-        vx = rx * lx + ry * ly
-        vy = -(dx * lx + dy * ly)
-
-        return vx, vy
-
-    def get_view_exts(self, agent_view_size=None):
-        """
-        Get the extents of the square set of tiles visible to the agent
-        Note: the bottom extent indices are not included in the set
-        if agent_view_size is None, use self.agent_view_size
-        """
-
-        agent_view_size = agent_view_size or self.agent_view_size
-
-        # Facing right
-        if self.agent_dir == 0:
-            topX = self.agent_pos[0]
-            topY = self.agent_pos[1] - agent_view_size // 2
-        # Facing down
-        elif self.agent_dir == 1:
-            topX = self.agent_pos[0] - agent_view_size // 2
-            topY = self.agent_pos[1]
-        # Facing left
-        elif self.agent_dir == 2:
-            topX = self.agent_pos[0] - agent_view_size + 1
-            topY = self.agent_pos[1] - agent_view_size // 2
-        # Facing up
-        elif self.agent_dir == 3:
-            topX = self.agent_pos[0] - agent_view_size // 2
-            topY = self.agent_pos[1] - agent_view_size + 1
-        else:
-            assert False, "invalid agent direction"
-
-        botX = topX + agent_view_size
-        botY = topY + agent_view_size
-
-        return (topX, topY, botX, botY)
-
-    def relative_coords(self, x, y):
-        """
-        Check if a grid position belongs to the agent's field of view; return the corresponding view coordinates, or None if it does not
-        """
-
-        vx, vy = self.get_view_coords(x, y)
-
-        if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
-            return None
-
-        return vx, vy
-
-    def in_view(self, x, y):
-        """
-        check if a grid position is visible to the agent
-        """
-
-        return self.relative_coords(x, y) is not None
-
-    def agent_sees(self, x, y):
-        """
-        Check if a non-empty grid position is visible to the agent
-        """
-
-        coordinates = self.relative_coords(x, y)
-        if coordinates is None:
-            return False
-        vx, vy = coordinates
-
-        obs = self.gen_obs()
-        obs_grid, _ = Grid.decode(obs["image"])
-        obs_cell = obs_grid.get(vx, vy)
-        world_cell = self.grid.get(x, y)
-
-        assert world_cell is not None
-
-        return obs_cell is not None and obs_cell.type == world_cell.type
-
-    def step(self, action):
-        self.step_count += 1
-
-        reward = 0
-        terminated = False
-        truncated = False
-
-        # Get the position in front of the agent
-        fwd_pos = self.front_pos
-
-        # Get the contents of the cell in front of the agent
-        fwd_cell = self.grid.get(*fwd_pos)
-
-        # Rotate left
-        if action == self.actions.left:
-            self.agent_dir -= 1
-            if self.agent_dir < 0:
-                self.agent_dir += 4
-
-        # Rotate right
-        elif action == self.actions.right:
-            self.agent_dir = (self.agent_dir + 1) % 4
-
-        # Move forward
-        elif action == self.actions.forward:
-            if fwd_cell is None or fwd_cell.can_overlap():
-                self.agent_pos = tuple(fwd_pos)
-            if fwd_cell is not None and fwd_cell.type == "goal":
-                terminated = True
-                reward = self._reward()
-            if fwd_cell is not None and fwd_cell.type == "lava":
-                terminated = True
-
-        # Pick up an object
-        elif action == self.actions.pickup:
-            if fwd_cell and fwd_cell.can_pickup():
-                if self.carrying is None:
-                    self.carrying = fwd_cell
-                    self.carrying.cur_pos = np.array([-1, -1])
-                    self.grid.set(fwd_pos[0], fwd_pos[1], None)
-
-        # Drop an object
-        elif action == self.actions.drop:
-            if not fwd_cell and self.carrying:
-                self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
-                self.carrying.cur_pos = fwd_pos
-                self.carrying = None
-
-        # Toggle/activate an object
-        elif action == self.actions.toggle:
-            if fwd_cell:
-                fwd_cell.toggle(self, fwd_pos)
-
-        # Done action (not used by default)
-        elif action == self.actions.done:
-            pass
-
-        else:
-            raise ValueError(f"Unknown action: {action}")
-
-        if self.step_count >= self.max_steps:
-            truncated = True
-
-        if self.render_mode == "human":
-            self.render()
-
-        obs = self.gen_obs()
-
-        return obs, reward, terminated, truncated, {}
-
-    def gen_obs_grid(self, agent_view_size=None):
-        """
-        Generate the sub-grid observed by the agent.
-        This method also outputs a visibility mask telling us which grid
-        cells the agent can actually see.
-        if agent_view_size is None, self.agent_view_size is used
-        """
-
-        topX, topY, botX, botY = self.get_view_exts(agent_view_size)
-
-        agent_view_size = agent_view_size or self.agent_view_size
-
-        grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
-
-        for i in range(self.agent_dir + 1):
-            grid = grid.rotate_left()
-
-        # Process occluders and visibility
-        # Note that this incurs some performance cost
-        if not self.see_through_walls:
-            vis_mask = grid.process_vis(
-                agent_pos=(agent_view_size // 2, agent_view_size - 1)
-            )
-        else:
-            vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
-
-        # Make it so the agent sees what it's carrying
-        # We do this by placing the carried object at the agent's position
-        # in the agent's partially observable view
-        agent_pos = grid.width // 2, grid.height - 1
-        if self.carrying:
-            grid.set(*agent_pos, self.carrying)
-        else:
-            grid.set(*agent_pos, None)
-
-        return grid, vis_mask
-
-    def gen_obs(self):
-        """
-        Generate the agent's view (partially observable, low-resolution encoding)
-        """
-
-        grid, vis_mask = self.gen_obs_grid()
-
-        # Encode the partially observable view into a numpy array
-        image = grid.encode(vis_mask)
-
-        # Observations are dictionaries containing:
-        # - an image (partially observable view of the environment)
-        # - the agent's direction/orientation (acting as a compass)
-        # - a textual mission string (instructions for the agent)
-        obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}
-
-        return obs
-
-    def get_pov_render(self, tile_size):
-        """
-        Render an agent's POV observation for visualization
-        """
-        grid, vis_mask = self.gen_obs_grid()
-
-        # Render the whole grid
-        img = grid.render(
-            tile_size,
-            agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
-            agent_dir=3,
-            highlight_mask=vis_mask,
-        )
-
-        return img
-
-    def get_full_render(self, highlight, tile_size):
-        """
-        Render a non-partial observation for visualization
-        """
-        # Compute which cells are visible to the agent
-        _, vis_mask = self.gen_obs_grid()
-
-        # Compute the world coordinates of the top-left corner
-        # of the agent's view area
-        f_vec = self.dir_vec
-        r_vec = self.right_vec
-        top_left = (
-            self.agent_pos
-            + f_vec * (self.agent_view_size - 1)
-            - r_vec * (self.agent_view_size // 2)
-        )
-
-        # Mask of which cells to highlight
-        highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
-
-        # For each cell in the visibility mask
-        for vis_j in range(0, self.agent_view_size):
-            for vis_i in range(0, self.agent_view_size):
-                # If this cell is not visible, don't highlight it
-                if not vis_mask[vis_i, vis_j]:
-                    continue
-
-                # Compute the world coordinates of this cell
-                abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
-
-                if abs_i < 0 or abs_i >= self.width:
-                    continue
-                if abs_j < 0 or abs_j >= self.height:
-                    continue
-
-                # Mark this cell to be highlighted
-                highlight_mask[abs_i, abs_j] = True
-
-        # Render the whole grid
-        img = self.grid.render(
-            tile_size,
-            self.agent_pos,
-            self.agent_dir,
-            highlight_mask=highlight_mask if highlight else None,
-        )
-
-        return img
-
-    def get_frame(
-        self,
-        highlight: bool = True,
-        tile_size: int = TILE_PIXELS,
-        agent_pov: bool = False,
-    ):
-        """Returns an RGB image corresponding to the whole environment or the agent's point of view.
-
-        Args:
-
-            highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
-            tile_size (int): How many pixels will form a tile from the NxM grid.
-            agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
-
-        Returns:
-
-            frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
-
-        """
-
-        if agent_pov:
-            return self.get_pov_render(tile_size)
-        else:
-            return self.get_full_render(highlight, tile_size)
-
-    def render(self):
-
-        img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
-
-        if self.render_mode == "human":
-            if self.window is None:
-                self.window = Window("minigrid")
-                self.window.show(block=False)
-            self.window.set_caption(self.mission)
-            self.window.show_img(img)
-        elif self.render_mode == "rgb_array":
-            return img
-
-    def close(self):
-        if self.window:
-            self.window.close()

Filskillnaden har hållts tillbaka eftersom den är för stor
+ 1537 - 0
minigrid/minigrid_env.py


+ 2 - 2
tests/test_scripts.py

@@ -1,8 +1,8 @@
 import gymnasium as gym
 import gymnasium as gym
 import numpy as np
 import numpy as np
 
 
-from minigrid.examples.benchmark import benchmark
-from minigrid.examples.manual_control import key_handler, reset
+from minigrid.benchmark import benchmark
+from minigrid.manual_control import key_handler, reset
 from minigrid.utils.window import Window
 from minigrid.utils.window import Window