|
@@ -2,826 +2,16 @@ import hashlib
|
|
|
import math
|
|
|
from abc import abstractmethod
|
|
|
from enum import IntEnum
|
|
|
-from typing import Any, Callable, Optional, Union
|
|
|
+from typing import Optional
|
|
|
|
|
|
import gymnasium as gym
|
|
|
import numpy as np
|
|
|
from gymnasium import spaces
|
|
|
-from gymnasium.utils import seeding
|
|
|
-
|
|
|
-# Size in pixels of a tile in the full-scale human view
|
|
|
-from minigrid.utils.rendering import (
|
|
|
- downsample,
|
|
|
- fill_coords,
|
|
|
- highlight_img,
|
|
|
- point_in_circle,
|
|
|
- point_in_line,
|
|
|
- point_in_rect,
|
|
|
- point_in_triangle,
|
|
|
- rotate_fn,
|
|
|
-)
|
|
|
-from minigrid.utils.window import Window
|
|
|
-
|
|
|
-TILE_PIXELS = 32
|
|
|
-
|
|
|
-# Map of color names to RGB values
|
|
|
-COLORS = {
|
|
|
- "red": np.array([255, 0, 0]),
|
|
|
- "green": np.array([0, 255, 0]),
|
|
|
- "blue": np.array([0, 0, 255]),
|
|
|
- "purple": np.array([112, 39, 195]),
|
|
|
- "yellow": np.array([255, 255, 0]),
|
|
|
- "grey": np.array([100, 100, 100]),
|
|
|
-}
|
|
|
-
|
|
|
-COLOR_NAMES = sorted(list(COLORS.keys()))
|
|
|
-
|
|
|
-# Used to map colors to integers
|
|
|
-COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
|
|
|
-
|
|
|
-IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
|
|
|
-
|
|
|
-# Map of object type to integers
|
|
|
-OBJECT_TO_IDX = {
|
|
|
- "unseen": 0,
|
|
|
- "empty": 1,
|
|
|
- "wall": 2,
|
|
|
- "floor": 3,
|
|
|
- "door": 4,
|
|
|
- "key": 5,
|
|
|
- "ball": 6,
|
|
|
- "box": 7,
|
|
|
- "goal": 8,
|
|
|
- "lava": 9,
|
|
|
- "agent": 10,
|
|
|
-}
|
|
|
-
|
|
|
-IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
|
|
|
-
|
|
|
-# Map of state names to integers
|
|
|
-STATE_TO_IDX = {
|
|
|
- "open": 0,
|
|
|
- "closed": 1,
|
|
|
- "locked": 2,
|
|
|
-}
|
|
|
-
|
|
|
-# Map of agent direction indices to vectors
|
|
|
-DIR_TO_VEC = [
|
|
|
- # Pointing right (positive X)
|
|
|
- np.array((1, 0)),
|
|
|
- # Down (positive Y)
|
|
|
- np.array((0, 1)),
|
|
|
- # Pointing left (negative X)
|
|
|
- np.array((-1, 0)),
|
|
|
- # Up (negative Y)
|
|
|
- np.array((0, -1)),
|
|
|
-]
|
|
|
-
|
|
|
-
|
|
|
-def check_if_no_duplicate(duplicate_list: list) -> bool:
|
|
|
- """Check if given list contains any duplicates"""
|
|
|
- return len(set(duplicate_list)) == len(duplicate_list)
|
|
|
-
|
|
|
-
|
|
|
-class MissionSpace(spaces.Space[str]):
|
|
|
- r"""A space representing a mission for the Minigrid environments.
|
|
|
- The space allows generating random mission strings constructed with an input placeholder list.
|
|
|
- Example Usage::
|
|
|
- >>> def _gen_mission() -> str:
|
|
|
- >>> return "Get the ball."
|
|
|
- >>> observation_space = MissionSpace(mission_func=_gen_mission)
|
|
|
- >>> observation_space.sample()
|
|
|
- "Get the ball."
|
|
|
- >>> def _gen_mission(color: str, object_type:str) -> str:
|
|
|
- >>> return f"Get the {color} {object_type}."
|
|
|
- >>> observation_space = MissionSpace(
|
|
|
- >>> mission_func=_gen_mission,
|
|
|
- >>> ordered_placeholders=[["green", "blue"], ["ball", "key"]],
|
|
|
- >>> )
|
|
|
- >>> observation_space.sample()
|
|
|
- "Get the green ball."
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- mission_func: Callable[..., str],
|
|
|
- ordered_placeholders: Optional["list[list[str]]"] = None,
|
|
|
- seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
|
|
|
- ):
|
|
|
- r"""Constructor of :class:`MissionSpace` space.
|
|
|
-
|
|
|
- Args:
|
|
|
- mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
|
|
|
- ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
|
|
|
- seed: seed: The seed for sampling from the space.
|
|
|
- """
|
|
|
- # Check that the ordered placeholders and mission function are well defined.
|
|
|
- if ordered_placeholders is not None:
|
|
|
- assert (
|
|
|
- len(ordered_placeholders) == mission_func.__code__.co_argcount
|
|
|
- ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
|
|
|
- for placeholder_list in ordered_placeholders:
|
|
|
- assert check_if_no_duplicate(
|
|
|
- placeholder_list
|
|
|
- ), "Make sure that the placeholders don't have any duplicate values."
|
|
|
- else:
|
|
|
- assert (
|
|
|
- mission_func.__code__.co_argcount == 0
|
|
|
- ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
|
|
|
-
|
|
|
- self.ordered_placeholders = ordered_placeholders
|
|
|
- self.mission_func = mission_func
|
|
|
-
|
|
|
- super().__init__(dtype=str, seed=seed)
|
|
|
-
|
|
|
- # Check that mission_func returns a string
|
|
|
- sampled_mission = self.sample()
|
|
|
- assert isinstance(
|
|
|
- sampled_mission, str
|
|
|
- ), f"mission_func must return type str not {type(sampled_mission)}"
|
|
|
-
|
|
|
- def sample(self) -> str:
|
|
|
- """Sample a random mission string."""
|
|
|
- if self.ordered_placeholders is not None:
|
|
|
- placeholders = []
|
|
|
- for rand_var_list in self.ordered_placeholders:
|
|
|
- idx = self.np_random.integers(0, len(rand_var_list))
|
|
|
-
|
|
|
- placeholders.append(rand_var_list[idx])
|
|
|
-
|
|
|
- return self.mission_func(*placeholders)
|
|
|
- else:
|
|
|
- return self.mission_func()
|
|
|
-
|
|
|
- def contains(self, x: Any) -> bool:
|
|
|
- """Return boolean specifying if x is a valid member of this space."""
|
|
|
- # Store a list of all the placeholders from self.ordered_placeholders that appear in x
|
|
|
- if self.ordered_placeholders is not None:
|
|
|
- check_placeholder_list = []
|
|
|
- for placeholder_list in self.ordered_placeholders:
|
|
|
- for placeholder in placeholder_list:
|
|
|
- if placeholder in x:
|
|
|
- check_placeholder_list.append(placeholder)
|
|
|
-
|
|
|
- # Remove duplicates from the list
|
|
|
- check_placeholder_list = list(set(check_placeholder_list))
|
|
|
-
|
|
|
- start_id_placeholder = []
|
|
|
- end_id_placeholder = []
|
|
|
- # Get the starting and ending id of the identified placeholders with possible duplicates
|
|
|
- new_check_placeholder_list = []
|
|
|
- for placeholder in check_placeholder_list:
|
|
|
- new_start_id_placeholder = [
|
|
|
- i for i in range(len(x)) if x.startswith(placeholder, i)
|
|
|
- ]
|
|
|
- new_check_placeholder_list += [placeholder] * len(
|
|
|
- new_start_id_placeholder
|
|
|
- )
|
|
|
- end_id_placeholder += [
|
|
|
- start_id + len(placeholder) - 1
|
|
|
- for start_id in new_start_id_placeholder
|
|
|
- ]
|
|
|
- start_id_placeholder += new_start_id_placeholder
|
|
|
-
|
|
|
- # Order by starting id the placeholders
|
|
|
- ordered_placeholder_list = sorted(
|
|
|
- zip(
|
|
|
- start_id_placeholder, end_id_placeholder, new_check_placeholder_list
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- # Check for repeated placeholders contained in each other
|
|
|
- remove_placeholder_id = []
|
|
|
- for i, placeholder_1 in enumerate(ordered_placeholder_list):
|
|
|
- starting_id = i + 1
|
|
|
- for j, placeholder_2 in enumerate(
|
|
|
- ordered_placeholder_list[starting_id:]
|
|
|
- ):
|
|
|
- # Check if place holder ids overlap and keep the longest
|
|
|
- if max(placeholder_1[0], placeholder_2[0]) < min(
|
|
|
- placeholder_1[1], placeholder_2[1]
|
|
|
- ):
|
|
|
- remove_placeholder = min(
|
|
|
- placeholder_1[2], placeholder_2[2], key=len
|
|
|
- )
|
|
|
- if remove_placeholder == placeholder_1[2]:
|
|
|
- remove_placeholder_id.append(i)
|
|
|
- else:
|
|
|
- remove_placeholder_id.append(i + j + 1)
|
|
|
- for id in remove_placeholder_id:
|
|
|
- del ordered_placeholder_list[id]
|
|
|
-
|
|
|
- final_placeholders = [
|
|
|
- placeholder[2] for placeholder in ordered_placeholder_list
|
|
|
- ]
|
|
|
-
|
|
|
- # Check that the identified final placeholders are in the same order as the original placeholders.
|
|
|
- for orered_placeholder, final_placeholder in zip(
|
|
|
- self.ordered_placeholders, final_placeholders
|
|
|
- ):
|
|
|
- if final_placeholder in orered_placeholder:
|
|
|
- continue
|
|
|
- else:
|
|
|
- return False
|
|
|
- try:
|
|
|
- mission_string_with_placeholders = self.mission_func(
|
|
|
- *final_placeholders
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- print(
|
|
|
- f"{x} is not contained in MissionSpace due to the following exception: {e}"
|
|
|
- )
|
|
|
- return False
|
|
|
-
|
|
|
- return bool(mission_string_with_placeholders == x)
|
|
|
-
|
|
|
- else:
|
|
|
- return bool(self.mission_func() == x)
|
|
|
-
|
|
|
- def __repr__(self) -> str:
|
|
|
- """Gives a string representation of this space."""
|
|
|
- return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
|
|
|
-
|
|
|
- def __eq__(self, other) -> bool:
|
|
|
- """Check whether ``other`` is equivalent to this instance."""
|
|
|
- if isinstance(other, MissionSpace):
|
|
|
-
|
|
|
- # Check that place holder lists are the same
|
|
|
- if self.ordered_placeholders is not None:
|
|
|
- # Check length
|
|
|
- if (len(self.order_placeholder) == len(other.order_placeholder)) and (
|
|
|
- all(
|
|
|
- set(i) == set(j)
|
|
|
- for i, j in zip(self.order_placeholder, other.order_placeholder)
|
|
|
- )
|
|
|
- ):
|
|
|
- # Check mission string is the same with dummy space placeholders
|
|
|
- test_placeholders = [""] * len(self.order_placeholder)
|
|
|
- mission = self.mission_func(*test_placeholders)
|
|
|
- other_mission = other.mission_func(*test_placeholders)
|
|
|
- return mission == other_mission
|
|
|
- else:
|
|
|
-
|
|
|
- # Check that other is also None
|
|
|
- if other.ordered_placeholders is None:
|
|
|
-
|
|
|
- # Check mission string is the same
|
|
|
- mission = self.mission_func()
|
|
|
- other_mission = other.mission_func()
|
|
|
- return mission == other_mission
|
|
|
-
|
|
|
- # If none of the statements above return then False
|
|
|
- return False
|
|
|
-
|
|
|
-
|
|
|
-class WorldObj:
|
|
|
- """
|
|
|
- Base class for grid world objects
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self, type, color):
|
|
|
- assert type in OBJECT_TO_IDX, type
|
|
|
- assert color in COLOR_TO_IDX, color
|
|
|
- self.type = type
|
|
|
- self.color = color
|
|
|
- self.contains = None
|
|
|
-
|
|
|
- # Initial position of the object
|
|
|
- self.init_pos = None
|
|
|
-
|
|
|
- # Current position of the object
|
|
|
- self.cur_pos = None
|
|
|
-
|
|
|
- def can_overlap(self):
|
|
|
- """Can the agent overlap with this?"""
|
|
|
- return False
|
|
|
-
|
|
|
- def can_pickup(self):
|
|
|
- """Can the agent pick this up?"""
|
|
|
- return False
|
|
|
-
|
|
|
- def can_contain(self):
|
|
|
- """Can this contain another object?"""
|
|
|
- return False
|
|
|
-
|
|
|
- def see_behind(self):
|
|
|
- """Can the agent see behind this object?"""
|
|
|
- return True
|
|
|
-
|
|
|
- def toggle(self, env, pos):
|
|
|
- """Method to trigger/toggle an action this object performs"""
|
|
|
- return False
|
|
|
-
|
|
|
- def encode(self):
|
|
|
- """Encode the a description of this object as a 3-tuple of integers"""
|
|
|
- return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def decode(type_idx, color_idx, state):
|
|
|
- """Create an object from a 3-tuple state description"""
|
|
|
-
|
|
|
- obj_type = IDX_TO_OBJECT[type_idx]
|
|
|
- color = IDX_TO_COLOR[color_idx]
|
|
|
-
|
|
|
- if obj_type == "empty" or obj_type == "unseen":
|
|
|
- return None
|
|
|
-
|
|
|
- # State, 0: open, 1: closed, 2: locked
|
|
|
- is_open = state == 0
|
|
|
- is_locked = state == 2
|
|
|
-
|
|
|
- if obj_type == "wall":
|
|
|
- v = Wall(color)
|
|
|
- elif obj_type == "floor":
|
|
|
- v = Floor(color)
|
|
|
- elif obj_type == "ball":
|
|
|
- v = Ball(color)
|
|
|
- elif obj_type == "key":
|
|
|
- v = Key(color)
|
|
|
- elif obj_type == "box":
|
|
|
- v = Box(color)
|
|
|
- elif obj_type == "door":
|
|
|
- v = Door(color, is_open, is_locked)
|
|
|
- elif obj_type == "goal":
|
|
|
- v = Goal()
|
|
|
- elif obj_type == "lava":
|
|
|
- v = Lava()
|
|
|
- else:
|
|
|
- assert False, "unknown object type in decode '%s'" % obj_type
|
|
|
-
|
|
|
- return v
|
|
|
-
|
|
|
- def render(self, r):
|
|
|
- """Draw this object with the given renderer"""
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
-
|
|
|
-class Goal(WorldObj):
|
|
|
- def __init__(self):
|
|
|
- super().__init__("goal", "green")
|
|
|
-
|
|
|
- def can_overlap(self):
|
|
|
- return True
|
|
|
-
|
|
|
- def render(self, img):
|
|
|
- fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
|
|
|
-
|
|
|
-
|
|
|
-class Floor(WorldObj):
|
|
|
- """
|
|
|
- Colored floor tile the agent can walk over
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self, color="blue"):
|
|
|
- super().__init__("floor", color)
|
|
|
-
|
|
|
- def can_overlap(self):
|
|
|
- return True
|
|
|
-
|
|
|
- def render(self, img):
|
|
|
- # Give the floor a pale color
|
|
|
- color = COLORS[self.color] / 2
|
|
|
- fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
|
|
|
-
|
|
|
-
|
|
|
-class Lava(WorldObj):
|
|
|
- def __init__(self):
|
|
|
- super().__init__("lava", "red")
|
|
|
-
|
|
|
- def can_overlap(self):
|
|
|
- return True
|
|
|
-
|
|
|
- def render(self, img):
|
|
|
- c = (255, 128, 0)
|
|
|
-
|
|
|
- # Background color
|
|
|
- fill_coords(img, point_in_rect(0, 1, 0, 1), c)
|
|
|
-
|
|
|
- # Little waves
|
|
|
- for i in range(3):
|
|
|
- ylo = 0.3 + 0.2 * i
|
|
|
- yhi = 0.4 + 0.2 * i
|
|
|
- fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
|
|
|
- fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
|
|
|
- fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
|
|
|
- fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
|
|
|
-
|
|
|
-
|
|
|
-class Wall(WorldObj):
|
|
|
- def __init__(self, color="grey"):
|
|
|
- super().__init__("wall", color)
|
|
|
-
|
|
|
- def see_behind(self):
|
|
|
- return False
|
|
|
-
|
|
|
- def render(self, img):
|
|
|
- fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
|
|
|
-
|
|
|
-
|
|
|
-class Door(WorldObj):
|
|
|
- def __init__(self, color, is_open=False, is_locked=False):
|
|
|
- super().__init__("door", color)
|
|
|
- self.is_open = is_open
|
|
|
- self.is_locked = is_locked
|
|
|
-
|
|
|
- def can_overlap(self):
|
|
|
- """The agent can only walk over this cell when the door is open"""
|
|
|
- return self.is_open
|
|
|
-
|
|
|
- def see_behind(self):
|
|
|
- return self.is_open
|
|
|
-
|
|
|
- def toggle(self, env, pos):
|
|
|
- # If the player has the right key to open the door
|
|
|
- if self.is_locked:
|
|
|
- if isinstance(env.carrying, Key) and env.carrying.color == self.color:
|
|
|
- self.is_locked = False
|
|
|
- self.is_open = True
|
|
|
- return True
|
|
|
- return False
|
|
|
-
|
|
|
- self.is_open = not self.is_open
|
|
|
- return True
|
|
|
-
|
|
|
- def encode(self):
|
|
|
- """Encode the a description of this object as a 3-tuple of integers"""
|
|
|
-
|
|
|
- # State, 0: open, 1: closed, 2: locked
|
|
|
- if self.is_open:
|
|
|
- state = 0
|
|
|
- elif self.is_locked:
|
|
|
- state = 2
|
|
|
- # if door is closed and unlocked
|
|
|
- elif not self.is_open:
|
|
|
- state = 1
|
|
|
- else:
|
|
|
- raise ValueError(
|
|
|
- f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
|
|
|
- )
|
|
|
-
|
|
|
- return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
|
|
|
-
|
|
|
- def render(self, img):
|
|
|
- c = COLORS[self.color]
|
|
|
-
|
|
|
- if self.is_open:
|
|
|
- fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
|
|
|
- fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
|
|
|
- return
|
|
|
-
|
|
|
- # Door frame and door
|
|
|
- if self.is_locked:
|
|
|
- fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
|
|
|
- fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
|
|
|
-
|
|
|
- # Draw key slot
|
|
|
- fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
|
|
|
- else:
|
|
|
- fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
|
|
|
- fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
|
|
|
- fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
|
|
|
- fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
|
|
|
-
|
|
|
- # Draw door handle
|
|
|
- fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
|
|
|
-
|
|
|
-
|
|
|
-class Key(WorldObj):
|
|
|
- def __init__(self, color="blue"):
|
|
|
- super().__init__("key", color)
|
|
|
-
|
|
|
- def can_pickup(self):
|
|
|
- return True
|
|
|
-
|
|
|
- def render(self, img):
|
|
|
- c = COLORS[self.color]
|
|
|
-
|
|
|
- # Vertical quad
|
|
|
- fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
|
|
|
-
|
|
|
- # Teeth
|
|
|
- fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
|
|
|
- fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
|
|
|
-
|
|
|
- # Ring
|
|
|
- fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
|
|
|
- fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
|
|
|
-
|
|
|
-
|
|
|
-class Ball(WorldObj):
|
|
|
- def __init__(self, color="blue"):
|
|
|
- super().__init__("ball", color)
|
|
|
-
|
|
|
- def can_pickup(self):
|
|
|
- return True
|
|
|
|
|
|
- def render(self, img):
|
|
|
- fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
|
|
|
-
|
|
|
-
|
|
|
-class Box(WorldObj):
|
|
|
- def __init__(self, color, contains=None):
|
|
|
- super().__init__("box", color)
|
|
|
- self.contains = contains
|
|
|
-
|
|
|
- def can_pickup(self):
|
|
|
- return True
|
|
|
-
|
|
|
- def render(self, img):
|
|
|
- c = COLORS[self.color]
|
|
|
-
|
|
|
- # Outline
|
|
|
- fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
|
|
|
- fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
|
|
|
-
|
|
|
- # Horizontal slit
|
|
|
- fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
|
|
|
-
|
|
|
- def toggle(self, env, pos):
|
|
|
- # Replace the box by its contents
|
|
|
- env.grid.set(pos[0], pos[1], self.contains)
|
|
|
- return True
|
|
|
-
|
|
|
-
|
|
|
-class Grid:
|
|
|
- """
|
|
|
- Represent a grid and operations on it
|
|
|
- """
|
|
|
-
|
|
|
- # Static cache of pre-renderer tiles
|
|
|
- tile_cache = {}
|
|
|
-
|
|
|
- def __init__(self, width, height):
|
|
|
- assert width >= 3
|
|
|
- assert height >= 3
|
|
|
-
|
|
|
- self.width = width
|
|
|
- self.height = height
|
|
|
-
|
|
|
- self.grid = [None] * width * height
|
|
|
-
|
|
|
- def __contains__(self, key):
|
|
|
- if isinstance(key, WorldObj):
|
|
|
- for e in self.grid:
|
|
|
- if e is key:
|
|
|
- return True
|
|
|
- elif isinstance(key, tuple):
|
|
|
- for e in self.grid:
|
|
|
- if e is None:
|
|
|
- continue
|
|
|
- if (e.color, e.type) == key:
|
|
|
- return True
|
|
|
- if key[0] is None and key[1] == e.type:
|
|
|
- return True
|
|
|
- return False
|
|
|
-
|
|
|
- def __eq__(self, other):
|
|
|
- grid1 = self.encode()
|
|
|
- grid2 = other.encode()
|
|
|
- return np.array_equal(grid2, grid1)
|
|
|
-
|
|
|
- def __ne__(self, other):
|
|
|
- return not self == other
|
|
|
-
|
|
|
- def copy(self):
|
|
|
- from copy import deepcopy
|
|
|
-
|
|
|
- return deepcopy(self)
|
|
|
-
|
|
|
- def set(self, i, j, v):
|
|
|
- assert i >= 0 and i < self.width
|
|
|
- assert j >= 0 and j < self.height
|
|
|
- self.grid[j * self.width + i] = v
|
|
|
-
|
|
|
- def get(self, i, j):
|
|
|
- assert i >= 0 and i < self.width
|
|
|
- assert j >= 0 and j < self.height
|
|
|
- return self.grid[j * self.width + i]
|
|
|
-
|
|
|
- def horz_wall(self, x, y, length=None, obj_type=Wall):
|
|
|
- if length is None:
|
|
|
- length = self.width - x
|
|
|
- for i in range(0, length):
|
|
|
- self.set(x + i, y, obj_type())
|
|
|
-
|
|
|
- def vert_wall(self, x, y, length=None, obj_type=Wall):
|
|
|
- if length is None:
|
|
|
- length = self.height - y
|
|
|
- for j in range(0, length):
|
|
|
- self.set(x, y + j, obj_type())
|
|
|
-
|
|
|
- def wall_rect(self, x, y, w, h):
|
|
|
- self.horz_wall(x, y, w)
|
|
|
- self.horz_wall(x, y + h - 1, w)
|
|
|
- self.vert_wall(x, y, h)
|
|
|
- self.vert_wall(x + w - 1, y, h)
|
|
|
-
|
|
|
- def rotate_left(self):
|
|
|
- """
|
|
|
- Rotate the grid to the left (counter-clockwise)
|
|
|
- """
|
|
|
-
|
|
|
- grid = Grid(self.height, self.width)
|
|
|
-
|
|
|
- for i in range(self.width):
|
|
|
- for j in range(self.height):
|
|
|
- v = self.get(i, j)
|
|
|
- grid.set(j, grid.height - 1 - i, v)
|
|
|
-
|
|
|
- return grid
|
|
|
-
|
|
|
- def slice(self, topX, topY, width, height):
|
|
|
- """
|
|
|
- Get a subset of the grid
|
|
|
- """
|
|
|
-
|
|
|
- grid = Grid(width, height)
|
|
|
-
|
|
|
- for j in range(0, height):
|
|
|
- for i in range(0, width):
|
|
|
- x = topX + i
|
|
|
- y = topY + j
|
|
|
-
|
|
|
- if x >= 0 and x < self.width and y >= 0 and y < self.height:
|
|
|
- v = self.get(x, y)
|
|
|
- else:
|
|
|
- v = Wall()
|
|
|
-
|
|
|
- grid.set(i, j, v)
|
|
|
-
|
|
|
- return grid
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def render_tile(
|
|
|
- cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
|
|
|
- ):
|
|
|
- """
|
|
|
- Render a tile and cache the result
|
|
|
- """
|
|
|
-
|
|
|
- # Hash map lookup key for the cache
|
|
|
- key = (agent_dir, highlight, tile_size)
|
|
|
- key = obj.encode() + key if obj else key
|
|
|
-
|
|
|
- if key in cls.tile_cache:
|
|
|
- return cls.tile_cache[key]
|
|
|
-
|
|
|
- img = np.zeros(
|
|
|
- shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
|
|
|
- )
|
|
|
-
|
|
|
- # Draw the grid lines (top and left edges)
|
|
|
- fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
|
|
|
- fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
|
|
|
-
|
|
|
- if obj is not None:
|
|
|
- obj.render(img)
|
|
|
-
|
|
|
- # Overlay the agent on top
|
|
|
- if agent_dir is not None:
|
|
|
- tri_fn = point_in_triangle(
|
|
|
- (0.12, 0.19),
|
|
|
- (0.87, 0.50),
|
|
|
- (0.12, 0.81),
|
|
|
- )
|
|
|
-
|
|
|
- # Rotate the agent based on its direction
|
|
|
- tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
|
|
|
- fill_coords(img, tri_fn, (255, 0, 0))
|
|
|
-
|
|
|
- # Highlight the cell if needed
|
|
|
- if highlight:
|
|
|
- highlight_img(img)
|
|
|
-
|
|
|
- # Downsample the image to perform supersampling/anti-aliasing
|
|
|
- img = downsample(img, subdivs)
|
|
|
-
|
|
|
- # Cache the rendered tile
|
|
|
- cls.tile_cache[key] = img
|
|
|
-
|
|
|
- return img
|
|
|
-
|
|
|
- def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
|
|
|
- """
|
|
|
- Render this grid at a given scale
|
|
|
- :param r: target renderer object
|
|
|
- :param tile_size: tile size in pixels
|
|
|
- """
|
|
|
-
|
|
|
- if highlight_mask is None:
|
|
|
- highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|
|
|
-
|
|
|
- # Compute the total grid size
|
|
|
- width_px = self.width * tile_size
|
|
|
- height_px = self.height * tile_size
|
|
|
-
|
|
|
- img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
|
|
|
-
|
|
|
- # Render the grid
|
|
|
- for j in range(0, self.height):
|
|
|
- for i in range(0, self.width):
|
|
|
- cell = self.get(i, j)
|
|
|
-
|
|
|
- agent_here = np.array_equal(agent_pos, (i, j))
|
|
|
- tile_img = Grid.render_tile(
|
|
|
- cell,
|
|
|
- agent_dir=agent_dir if agent_here else None,
|
|
|
- highlight=highlight_mask[i, j],
|
|
|
- tile_size=tile_size,
|
|
|
- )
|
|
|
-
|
|
|
- ymin = j * tile_size
|
|
|
- ymax = (j + 1) * tile_size
|
|
|
- xmin = i * tile_size
|
|
|
- xmax = (i + 1) * tile_size
|
|
|
- img[ymin:ymax, xmin:xmax, :] = tile_img
|
|
|
-
|
|
|
- return img
|
|
|
-
|
|
|
- def encode(self, vis_mask=None):
|
|
|
- """
|
|
|
- Produce a compact numpy encoding of the grid
|
|
|
- """
|
|
|
-
|
|
|
- if vis_mask is None:
|
|
|
- vis_mask = np.ones((self.width, self.height), dtype=bool)
|
|
|
-
|
|
|
- array = np.zeros((self.width, self.height, 3), dtype="uint8")
|
|
|
-
|
|
|
- for i in range(self.width):
|
|
|
- for j in range(self.height):
|
|
|
- if vis_mask[i, j]:
|
|
|
- v = self.get(i, j)
|
|
|
-
|
|
|
- if v is None:
|
|
|
- array[i, j, 0] = OBJECT_TO_IDX["empty"]
|
|
|
- array[i, j, 1] = 0
|
|
|
- array[i, j, 2] = 0
|
|
|
-
|
|
|
- else:
|
|
|
- array[i, j, :] = v.encode()
|
|
|
-
|
|
|
- return array
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def decode(array):
|
|
|
- """
|
|
|
- Decode an array grid encoding back into a grid
|
|
|
- """
|
|
|
-
|
|
|
- width, height, channels = array.shape
|
|
|
- assert channels == 3
|
|
|
-
|
|
|
- vis_mask = np.ones(shape=(width, height), dtype=bool)
|
|
|
-
|
|
|
- grid = Grid(width, height)
|
|
|
- for i in range(width):
|
|
|
- for j in range(height):
|
|
|
- type_idx, color_idx, state = array[i, j]
|
|
|
- v = WorldObj.decode(type_idx, color_idx, state)
|
|
|
- grid.set(i, j, v)
|
|
|
- vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
|
|
|
-
|
|
|
- return grid, vis_mask
|
|
|
-
|
|
|
- def process_vis(self, agent_pos):
|
|
|
- mask = np.zeros(shape=(self.width, self.height), dtype=bool)
|
|
|
-
|
|
|
- mask[agent_pos[0], agent_pos[1]] = True
|
|
|
-
|
|
|
- for j in reversed(range(0, self.height)):
|
|
|
- for i in range(0, self.width - 1):
|
|
|
- if not mask[i, j]:
|
|
|
- continue
|
|
|
-
|
|
|
- cell = self.get(i, j)
|
|
|
- if cell and not cell.see_behind():
|
|
|
- continue
|
|
|
-
|
|
|
- mask[i + 1, j] = True
|
|
|
- if j > 0:
|
|
|
- mask[i + 1, j - 1] = True
|
|
|
- mask[i, j - 1] = True
|
|
|
-
|
|
|
- for i in reversed(range(1, self.width)):
|
|
|
- if not mask[i, j]:
|
|
|
- continue
|
|
|
-
|
|
|
- cell = self.get(i, j)
|
|
|
- if cell and not cell.see_behind():
|
|
|
- continue
|
|
|
-
|
|
|
- mask[i - 1, j] = True
|
|
|
- if j > 0:
|
|
|
- mask[i - 1, j - 1] = True
|
|
|
- mask[i, j - 1] = True
|
|
|
-
|
|
|
- for j in range(0, self.height):
|
|
|
- for i in range(0, self.width):
|
|
|
- if not mask[i, j]:
|
|
|
- self.set(i, j, None)
|
|
|
-
|
|
|
- return mask
|
|
|
+from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
|
|
|
+from minigrid.core.grid import Grid
|
|
|
+from minigrid.core.mission import MissionSpace
|
|
|
+from minigrid.utils.window import Window
|
|
|
|
|
|
|
|
|
class MiniGridEnv(gym.Env):
|
|
@@ -920,7 +110,7 @@ class MiniGridEnv(gym.Env):
|
|
|
self.agent_pos: np.ndarray = None
|
|
|
self.agent_dir: int = None
|
|
|
|
|
|
- # Current grid and mission and carryinh
|
|
|
+ # Current grid and mission and carrying
|
|
|
self.grid = Grid(width, height)
|
|
|
self.carrying = None
|
|
|
|
|
@@ -1307,6 +497,7 @@ class MiniGridEnv(gym.Env):
|
|
|
vx, vy = coordinates
|
|
|
|
|
|
obs = self.gen_obs()
|
|
|
+
|
|
|
obs_grid, _ = Grid.decode(obs["image"])
|
|
|
obs_cell = obs_grid.get(vx, vy)
|
|
|
world_cell = self.grid.get(x, y)
|