|
@@ -1,14 +1,13 @@
|
|
|
import hashlib
|
|
|
import math
|
|
|
-import string
|
|
|
from abc import abstractmethod
|
|
|
from enum import IntEnum
|
|
|
-from functools import partial
|
|
|
+from typing import Any, Callable, Optional, Union
|
|
|
|
|
|
import gym
|
|
|
import numpy as np
|
|
|
from gym import spaces
|
|
|
-from gym.utils.renderer import Renderer
|
|
|
+from gym.utils import seeding
|
|
|
|
|
|
# Size in pixels of a tile in the full-scale human view
|
|
|
from gym_minigrid.rendering import (
|
|
@@ -79,6 +78,197 @@ DIR_TO_VEC = [
|
|
|
]
|
|
|
|
|
|
|
|
|
+def check_if_no_duplicate(duplicate_list: list) -> bool:
|
|
|
+ """Check if given list contains any duplicates"""
|
|
|
+ return len(set(duplicate_list)) == len(duplicate_list)
|
|
|
+
|
|
|
+
|
|
|
+class MissionSpace(spaces.Space[str]):
|
|
|
+ r"""A space representing a mission for the Gym-Minigrid environments.
|
|
|
+ The space allows generating random mission strings constructed with an input placeholder list.
|
|
|
+ Example Usage::
|
|
|
+ >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
|
|
|
+ ordered_placeholders=[["green", "blue"]])
|
|
|
+ >>> observation_space.sample()
|
|
|
+ "Get the green ball."
|
|
|
+ >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
|
|
|
+ ordered_placeholders=None)
|
|
|
+ >>> observation_space.sample()
|
|
|
+ "Get the ball."
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ mission_func: Callable[..., str],
|
|
|
+ ordered_placeholders: Optional["list[list[str]]"] = None,
|
|
|
+ seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
|
|
|
+ ):
|
|
|
+ r"""Constructor of :class:`MissionSpace` space.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
|
|
|
+ ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
|
|
|
+ seed: seed: The seed for sampling from the space.
|
|
|
+ """
|
|
|
+ # Check that the ordered placeholders and mission function are well defined.
|
|
|
+ if ordered_placeholders is not None:
|
|
|
+ assert (
|
|
|
+ len(ordered_placeholders) == mission_func.__code__.co_argcount
|
|
|
+ ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
|
|
|
+ for placeholder_list in ordered_placeholders:
|
|
|
+ assert check_if_no_duplicate(
|
|
|
+ placeholder_list
|
|
|
+ ), "Make sure that the placeholders don't have any duplicate values."
|
|
|
+ else:
|
|
|
+ assert (
|
|
|
+ mission_func.__code__.co_argcount == 0
|
|
|
+ ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
|
|
|
+
|
|
|
+ self.ordered_placeholders = ordered_placeholders
|
|
|
+ self.mission_func = mission_func
|
|
|
+
|
|
|
+ super().__init__(dtype=str, seed=seed)
|
|
|
+
|
|
|
+ # Check that mission_func returns a string
|
|
|
+ sampled_mission = self.sample()
|
|
|
+ assert isinstance(
|
|
|
+ sampled_mission, str
|
|
|
+ ), f"mission_func must return type str not {type(sampled_mission)}"
|
|
|
+
|
|
|
+ def sample(self) -> str:
|
|
|
+ """Sample a random mission string."""
|
|
|
+ if self.ordered_placeholders is not None:
|
|
|
+ placeholders = []
|
|
|
+ for rand_var_list in self.ordered_placeholders:
|
|
|
+ idx = self.np_random.integers(0, len(rand_var_list))
|
|
|
+
|
|
|
+ placeholders.append(rand_var_list[idx])
|
|
|
+
|
|
|
+ return self.mission_func(*placeholders)
|
|
|
+ else:
|
|
|
+ return self.mission_func()
|
|
|
+
|
|
|
+ def contains(self, x: Any) -> bool:
|
|
|
+ """Return boolean specifying if x is a valid member of this space."""
|
|
|
+ # Store a list of all the placeholders from self.ordered_placeholders that appear in x
|
|
|
+ if self.ordered_placeholders is not None:
|
|
|
+ check_placeholder_list = []
|
|
|
+ for placeholder_list in self.ordered_placeholders:
|
|
|
+ for placeholder in placeholder_list:
|
|
|
+ if placeholder in x:
|
|
|
+ check_placeholder_list.append(placeholder)
|
|
|
+
|
|
|
+ # Remove duplicates from the list
|
|
|
+ check_placeholder_list = list(set(check_placeholder_list))
|
|
|
+
|
|
|
+ start_id_placeholder = []
|
|
|
+ end_id_placeholder = []
|
|
|
+ # Get the starting and ending id of the identified placeholders with possible duplicates
|
|
|
+ new_check_placeholder_list = []
|
|
|
+ for placeholder in check_placeholder_list:
|
|
|
+ new_start_id_placeholder = [
|
|
|
+ i for i in range(len(x)) if x.startswith(placeholder, i)
|
|
|
+ ]
|
|
|
+ new_check_placeholder_list += [placeholder] * len(
|
|
|
+ new_start_id_placeholder
|
|
|
+ )
|
|
|
+ end_id_placeholder += [
|
|
|
+ start_id + len(placeholder) - 1
|
|
|
+ for start_id in new_start_id_placeholder
|
|
|
+ ]
|
|
|
+ start_id_placeholder += new_start_id_placeholder
|
|
|
+
|
|
|
+ # Order by starting id the placeholders
|
|
|
+ ordered_placeholder_list = sorted(
|
|
|
+ zip(
|
|
|
+ start_id_placeholder, end_id_placeholder, new_check_placeholder_list
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check for repeated placeholders contained in each other
|
|
|
+ remove_placeholder_id = []
|
|
|
+ for i, placeholder_1 in enumerate(ordered_placeholder_list):
|
|
|
+ starting_id = i + 1
|
|
|
+ for j, placeholder_2 in enumerate(
|
|
|
+ ordered_placeholder_list[starting_id:]
|
|
|
+ ):
|
|
|
+ # Check if place holder ids overlap and keep the longest
|
|
|
+ if max(placeholder_1[0], placeholder_2[0]) < min(
|
|
|
+ placeholder_1[1], placeholder_2[1]
|
|
|
+ ):
|
|
|
+ remove_placeholder = min(
|
|
|
+ placeholder_1[2], placeholder_2[2], key=len
|
|
|
+ )
|
|
|
+ if remove_placeholder == placeholder_1[2]:
|
|
|
+ remove_placeholder_id.append(i)
|
|
|
+ else:
|
|
|
+ remove_placeholder_id.append(i + j + 1)
|
|
|
+ for id in remove_placeholder_id:
|
|
|
+ del ordered_placeholder_list[id]
|
|
|
+
|
|
|
+ final_placeholders = [
|
|
|
+ placeholder[2] for placeholder in ordered_placeholder_list
|
|
|
+ ]
|
|
|
+
|
|
|
+ # Check that the identified final placeholders are in the same order as the original placeholders.
|
|
|
+ for orered_placeholder, final_placeholder in zip(
|
|
|
+ self.ordered_placeholders, final_placeholders
|
|
|
+ ):
|
|
|
+ if final_placeholder in orered_placeholder:
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ return False
|
|
|
+ try:
|
|
|
+ mission_string_with_placeholders = self.mission_func(
|
|
|
+ *final_placeholders
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ print(
|
|
|
+ f"{x} is not contained in MissionSpace due to the following exception: {e}"
|
|
|
+ )
|
|
|
+ return False
|
|
|
+
|
|
|
+ return bool(mission_string_with_placeholders == x)
|
|
|
+
|
|
|
+ else:
|
|
|
+ return bool(self.mission_func() == x)
|
|
|
+
|
|
|
+ def __repr__(self) -> str:
|
|
|
+ """Gives a string representation of this space."""
|
|
|
+ return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
|
|
|
+
|
|
|
+ def __eq__(self, other) -> bool:
|
|
|
+ """Check whether ``other`` is equivalent to this instance."""
|
|
|
+ if isinstance(other, MissionSpace):
|
|
|
+
|
|
|
+ # Check that place holder lists are the same
|
|
|
+ if self.ordered_placeholders is not None:
|
|
|
+ # Check length
|
|
|
+ if (len(self.order_placeholder) == len(other.order_placeholder)) and (
|
|
|
+ all(
|
|
|
+ set(i) == set(j)
|
|
|
+ for i, j in zip(self.order_placeholder, other.order_placeholder)
|
|
|
+ )
|
|
|
+ ):
|
|
|
+ # Check mission string is the same with dummy space placeholders
|
|
|
+ test_placeholders = [""] * len(self.order_placeholder)
|
|
|
+ mission = self.mission_func(*test_placeholders)
|
|
|
+ other_mission = other.mission_func(*test_placeholders)
|
|
|
+ return mission == other_mission
|
|
|
+ else:
|
|
|
+
|
|
|
+ # Check that other is also None
|
|
|
+ if other.ordered_placeholders is None:
|
|
|
+
|
|
|
+ # Check mission string is the same
|
|
|
+ mission = self.mission_func()
|
|
|
+ other_mission = other.mission_func()
|
|
|
+ return mission == other_mission
|
|
|
+
|
|
|
+ # If none of the statements above return then False
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
class WorldObj:
|
|
|
"""
|
|
|
Base class for grid world objects
|
|
@@ -261,9 +451,7 @@ class Door(WorldObj):
|
|
|
state = 1
|
|
|
else:
|
|
|
raise ValueError(
|
|
|
- "There is no possible state encoding for the state:\n -Door Open: {}\n -Door Closed: {}\n -Door Locked: {}".format(
|
|
|
- self.is_open, not self.is_open, self.is_locked
|
|
|
- )
|
|
|
+ f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
|
|
|
)
|
|
|
|
|
|
return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
|
|
@@ -663,17 +851,21 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
+ mission_space: MissionSpace,
|
|
|
grid_size: int = None,
|
|
|
width: int = None,
|
|
|
height: int = None,
|
|
|
max_steps: int = 100,
|
|
|
see_through_walls: bool = False,
|
|
|
agent_view_size: int = 7,
|
|
|
- render_mode: str = None,
|
|
|
highlight: bool = True,
|
|
|
tile_size: int = TILE_PIXELS,
|
|
|
- **kwargs
|
|
|
+ **kwargs,
|
|
|
):
|
|
|
+
|
|
|
+ # Initialize mission
|
|
|
+ self.mission = mission_space.sample()
|
|
|
+
|
|
|
# Can't set both grid_size and width/height
|
|
|
if grid_size:
|
|
|
assert width is None and height is None
|
|
@@ -693,7 +885,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
# Observations are dictionaries containing an
|
|
|
# encoding of the grid and a textual 'mission' string
|
|
|
- self.observation_space = spaces.Box(
|
|
|
+ image_observation_space = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
shape=(self.agent_view_size, self.agent_view_size, 3),
|
|
@@ -701,24 +893,12 @@ class MiniGridEnv(gym.Env):
|
|
|
)
|
|
|
self.observation_space = spaces.Dict(
|
|
|
{
|
|
|
- "image": self.observation_space,
|
|
|
+ "image": image_observation_space,
|
|
|
"direction": spaces.Discrete(4),
|
|
|
- "mission": spaces.Text(
|
|
|
- max_length=200,
|
|
|
- charset=string.ascii_letters + string.digits + " .,!-",
|
|
|
- ),
|
|
|
+ "mission": mission_space,
|
|
|
}
|
|
|
)
|
|
|
|
|
|
- # render mode
|
|
|
- self.render_mode = render_mode
|
|
|
- render_frame = partial(
|
|
|
- self._render,
|
|
|
- highlight=highlight,
|
|
|
- tile_size=tile_size,
|
|
|
- )
|
|
|
- self.renderer = Renderer(self.render_mode, render_frame)
|
|
|
-
|
|
|
# Range of possible rewards
|
|
|
self.reward_range = (0, 1)
|
|
|
|
|
@@ -763,8 +943,6 @@ class MiniGridEnv(gym.Env):
|
|
|
# Return first observation
|
|
|
obs = self.gen_obs()
|
|
|
|
|
|
- self.renderer.reset()
|
|
|
- self.renderer.render_step()
|
|
|
if not return_info:
|
|
|
return obs
|
|
|
else:
|
|
@@ -1179,7 +1357,6 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
obs = self.gen_obs()
|
|
|
|
|
|
- self.renderer.render_step()
|
|
|
return obs, reward, done, {}
|
|
|
|
|
|
def gen_obs_grid(self, agent_view_size=None):
|
|
@@ -1258,7 +1435,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return img
|
|
|
|
|
|
- def _render(self, mode="human", highlight=True, tile_size=TILE_PIXELS):
|
|
|
+ def render(self, mode="human", highlight=True, tile_size=TILE_PIXELS):
|
|
|
assert mode in self.metadata["render_modes"]
|
|
|
"""
|
|
|
Render the whole-grid human view
|
|
@@ -1315,16 +1492,6 @@ class MiniGridEnv(gym.Env):
|
|
|
else:
|
|
|
return img
|
|
|
|
|
|
- def render(self, mode="human", close=False, highlight=True, tile_size=TILE_PIXELS):
|
|
|
- if close:
|
|
|
- raise Exception(
|
|
|
- "Please close the rendering window using env.close(). Closing the rendering window with the render method is no longer allowed."
|
|
|
- )
|
|
|
- if self.render_mode is not None:
|
|
|
- return self.renderer.get_renders()
|
|
|
- else:
|
|
|
- return self._render(mode, highlight=highlight, tile_size=tile_size)
|
|
|
-
|
|
|
def close(self):
|
|
|
if self.window:
|
|
|
self.window.close()
|