|
@@ -1,8 +1,10 @@
|
|
|
|
+from __future__ import annotations
|
|
|
|
+
|
|
import hashlib
|
|
import hashlib
|
|
import math
|
|
import math
|
|
from abc import abstractmethod
|
|
from abc import abstractmethod
|
|
from enum import IntEnum
|
|
from enum import IntEnum
|
|
-from typing import Optional
|
|
|
|
|
|
+from typing import Iterable, TypeVar
|
|
|
|
|
|
import gymnasium as gym
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
import numpy as np
|
|
@@ -11,8 +13,11 @@ from gymnasium import spaces
|
|
from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
|
|
from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
|
|
from minigrid.core.grid import Grid
|
|
from minigrid.core.grid import Grid
|
|
from minigrid.core.mission import MissionSpace
|
|
from minigrid.core.mission import MissionSpace
|
|
|
|
+from minigrid.core.world_object import Point, WorldObj
|
|
from minigrid.utils.window import Window
|
|
from minigrid.utils.window import Window
|
|
|
|
|
|
|
|
+T = TypeVar("T")
|
|
|
|
+
|
|
|
|
|
|
class MiniGridEnv(gym.Env):
|
|
class MiniGridEnv(gym.Env):
|
|
"""
|
|
"""
|
|
@@ -43,13 +48,13 @@ class MiniGridEnv(gym.Env):
|
|
def __init__(
|
|
def __init__(
|
|
self,
|
|
self,
|
|
mission_space: MissionSpace,
|
|
mission_space: MissionSpace,
|
|
- grid_size: int = None,
|
|
|
|
- width: int = None,
|
|
|
|
- height: int = None,
|
|
|
|
|
|
+ grid_size: int | None = None,
|
|
|
|
+ width: int | None = None,
|
|
|
|
+ height: int | None = None,
|
|
max_steps: int = 100,
|
|
max_steps: int = 100,
|
|
see_through_walls: bool = False,
|
|
see_through_walls: bool = False,
|
|
agent_view_size: int = 7,
|
|
agent_view_size: int = 7,
|
|
- render_mode: Optional[str] = None,
|
|
|
|
|
|
+ render_mode: str | None = None,
|
|
highlight: bool = True,
|
|
highlight: bool = True,
|
|
tile_size: int = TILE_PIXELS,
|
|
tile_size: int = TILE_PIXELS,
|
|
agent_pov: bool = False,
|
|
agent_pov: bool = False,
|
|
@@ -62,6 +67,7 @@ class MiniGridEnv(gym.Env):
|
|
assert width is None and height is None
|
|
assert width is None and height is None
|
|
width = grid_size
|
|
width = grid_size
|
|
height = grid_size
|
|
height = grid_size
|
|
|
|
+ assert width is not None and height is not None
|
|
|
|
|
|
# Action enumeration for this environment
|
|
# Action enumeration for this environment
|
|
self.actions = MiniGridEnv.Actions
|
|
self.actions = MiniGridEnv.Actions
|
|
@@ -107,7 +113,7 @@ class MiniGridEnv(gym.Env):
|
|
self.see_through_walls = see_through_walls
|
|
self.see_through_walls = see_through_walls
|
|
|
|
|
|
# Current position and direction of the agent
|
|
# Current position and direction of the agent
|
|
- self.agent_pos: np.ndarray = None
|
|
|
|
|
|
+ self.agent_pos: np.ndarray | tuple[int, int] = None
|
|
self.agent_dir: int = None
|
|
self.agent_dir: int = None
|
|
|
|
|
|
# Current grid and mission and carrying
|
|
# Current grid and mission and carrying
|
|
@@ -228,35 +234,35 @@ class MiniGridEnv(gym.Env):
|
|
def _gen_grid(self, width, height):
|
|
def _gen_grid(self, width, height):
|
|
pass
|
|
pass
|
|
|
|
|
|
- def _reward(self):
|
|
|
|
|
|
+ def _reward(self) -> float:
|
|
"""
|
|
"""
|
|
Compute the reward to be given upon success
|
|
Compute the reward to be given upon success
|
|
"""
|
|
"""
|
|
|
|
|
|
return 1 - 0.9 * (self.step_count / self.max_steps)
|
|
return 1 - 0.9 * (self.step_count / self.max_steps)
|
|
|
|
|
|
- def _rand_int(self, low, high):
|
|
|
|
|
|
+ def _rand_int(self, low: int, high: int) -> int:
|
|
"""
|
|
"""
|
|
Generate random integer in [low,high[
|
|
Generate random integer in [low,high[
|
|
"""
|
|
"""
|
|
|
|
|
|
return self.np_random.integers(low, high)
|
|
return self.np_random.integers(low, high)
|
|
|
|
|
|
- def _rand_float(self, low, high):
|
|
|
|
|
|
+ def _rand_float(self, low: float, high: float) -> float:
|
|
"""
|
|
"""
|
|
Generate random float in [low,high[
|
|
Generate random float in [low,high[
|
|
"""
|
|
"""
|
|
|
|
|
|
return self.np_random.uniform(low, high)
|
|
return self.np_random.uniform(low, high)
|
|
|
|
|
|
- def _rand_bool(self):
|
|
|
|
|
|
+ def _rand_bool(self) -> bool:
|
|
"""
|
|
"""
|
|
Generate random boolean value
|
|
Generate random boolean value
|
|
"""
|
|
"""
|
|
|
|
|
|
return self.np_random.integers(0, 2) == 0
|
|
return self.np_random.integers(0, 2) == 0
|
|
|
|
|
|
- def _rand_elem(self, iterable):
|
|
|
|
|
|
+ def _rand_elem(self, iterable: Iterable[T]) -> T:
|
|
"""
|
|
"""
|
|
Pick a random element in a list
|
|
Pick a random element in a list
|
|
"""
|
|
"""
|
|
@@ -265,7 +271,7 @@ class MiniGridEnv(gym.Env):
|
|
idx = self._rand_int(0, len(lst))
|
|
idx = self._rand_int(0, len(lst))
|
|
return lst[idx]
|
|
return lst[idx]
|
|
|
|
|
|
- def _rand_subset(self, iterable, num_elems):
|
|
|
|
|
|
+ def _rand_subset(self, iterable: Iterable[T], num_elems: int) -> list[T]:
|
|
"""
|
|
"""
|
|
Sample a random subset of distinct elements of a list
|
|
Sample a random subset of distinct elements of a list
|
|
"""
|
|
"""
|
|
@@ -273,7 +279,7 @@ class MiniGridEnv(gym.Env):
|
|
lst = list(iterable)
|
|
lst = list(iterable)
|
|
assert num_elems <= len(lst)
|
|
assert num_elems <= len(lst)
|
|
|
|
|
|
- out = []
|
|
|
|
|
|
+ out: list[T] = []
|
|
|
|
|
|
while len(out) < num_elems:
|
|
while len(out) < num_elems:
|
|
elem = self._rand_elem(lst)
|
|
elem = self._rand_elem(lst)
|
|
@@ -282,24 +288,33 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return out
|
|
return out
|
|
|
|
|
|
- def _rand_color(self):
|
|
|
|
|
|
+ def _rand_color(self) -> str:
|
|
"""
|
|
"""
|
|
Generate a random color name (string)
|
|
Generate a random color name (string)
|
|
"""
|
|
"""
|
|
|
|
|
|
return self._rand_elem(COLOR_NAMES)
|
|
return self._rand_elem(COLOR_NAMES)
|
|
|
|
|
|
- def _rand_pos(self, xLow, xHigh, yLow, yHigh):
|
|
|
|
|
|
+ def _rand_pos(
|
|
|
|
+ self, x_low: int, x_high: int, y_low: int, y_high: int
|
|
|
|
+ ) -> tuple[int, int]:
|
|
"""
|
|
"""
|
|
Generate a random (x,y) position tuple
|
|
Generate a random (x,y) position tuple
|
|
"""
|
|
"""
|
|
|
|
|
|
return (
|
|
return (
|
|
- self.np_random.integers(xLow, xHigh),
|
|
|
|
- self.np_random.integers(yLow, yHigh),
|
|
|
|
|
|
+ self.np_random.integers(x_low, x_high),
|
|
|
|
+ self.np_random.integers(y_low, y_high),
|
|
)
|
|
)
|
|
|
|
|
|
- def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
|
|
|
|
|
|
+ def place_obj(
|
|
|
|
+ self,
|
|
|
|
+ obj: WorldObj | None,
|
|
|
|
+ top: Point = None,
|
|
|
|
+ size: tuple[int, int] = None,
|
|
|
|
+ reject_fn=None,
|
|
|
|
+ max_tries=math.inf,
|
|
|
|
+ ):
|
|
"""
|
|
"""
|
|
Place an object at an empty position in the grid
|
|
Place an object at an empty position in the grid
|
|
|
|
|
|
@@ -326,15 +341,11 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
num_tries += 1
|
|
num_tries += 1
|
|
|
|
|
|
- pos = np.array(
|
|
|
|
- (
|
|
|
|
- self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
|
|
|
|
- self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
|
|
|
|
- )
|
|
|
|
|
|
+ pos = (
|
|
|
|
+ self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
|
|
|
|
+ self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
|
|
)
|
|
)
|
|
|
|
|
|
- pos = tuple(pos)
|
|
|
|
-
|
|
|
|
# Don't place the object on top of another object
|
|
# Don't place the object on top of another object
|
|
if self.grid.get(*pos) is not None:
|
|
if self.grid.get(*pos) is not None:
|
|
continue
|
|
continue
|
|
@@ -357,7 +368,7 @@ class MiniGridEnv(gym.Env):
|
|
|
|
|
|
return pos
|
|
return pos
|
|
|
|
|
|
- def put_obj(self, obj, i, j):
|
|
|
|
|
|
+ def put_obj(self, obj: WorldObj, i: int, j: int):
|
|
"""
|
|
"""
|
|
Put an object at a specific position in the grid
|
|
Put an object at a specific position in the grid
|
|
"""
|
|
"""
|
|
@@ -387,7 +398,9 @@ class MiniGridEnv(gym.Env):
|
|
of forward movement.
|
|
of forward movement.
|
|
"""
|
|
"""
|
|
|
|
|
|
- assert self.agent_dir >= 0 and self.agent_dir < 4
|
|
|
|
|
|
+ assert (
|
|
|
|
+ self.agent_dir >= 0 and self.agent_dir < 4
|
|
|
|
+ ), f"Invalid agent_dir: {self.agent_dir} is not within range(0, 4)"
|
|
return DIR_TO_VEC[self.agent_dir]
|
|
return DIR_TO_VEC[self.agent_dir]
|
|
|
|
|
|
@property
|
|
@property
|