|
@@ -3,11 +3,12 @@ from __future__ import annotations
|
|
|
import math
|
|
|
import operator
|
|
|
from functools import reduce
|
|
|
+from typing import Any
|
|
|
|
|
|
import gymnasium as gym
|
|
|
import numpy as np
|
|
|
-from gymnasium import spaces
|
|
|
-from gymnasium.core import ObservationWrapper, Wrapper
|
|
|
+from gymnasium import logger, spaces
|
|
|
+from gymnasium.core import ObservationWrapper, ObsType, Wrapper
|
|
|
|
|
|
from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
|
|
|
from minigrid.core.world_object import Goal
|
|
@@ -24,8 +25,9 @@ class ReseedWrapper(Wrapper):
|
|
|
>>> import gymnasium as gym
|
|
|
>>> from minigrid.wrappers import ReseedWrapper
|
|
|
>>> env = gym.make("MiniGrid-Empty-5x5-v0")
|
|
|
+ >>> _ = env.reset(seed=123)
|
|
|
>>> [env.np_random.integers(10) for i in range(10)]
|
|
|
- [1, 9, 5, 8, 4, 3, 8, 8, 3, 1]
|
|
|
+ [0, 6, 5, 0, 9, 2, 2, 1, 3, 1]
|
|
|
>>> env = ReseedWrapper(env, seeds=[0, 1], seed_idx=0)
|
|
|
>>> _, _ = env.reset()
|
|
|
>>> [env.np_random.integers(10) for i in range(10)]
|
|
@@ -41,7 +43,7 @@ class ReseedWrapper(Wrapper):
|
|
|
[4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, env, seeds=[0], seed_idx=0):
|
|
|
+ def __init__(self, env, seeds=(0,), seed_idx=0):
|
|
|
"""A wrapper that always regenerate an environment with the same set of seeds.
|
|
|
|
|
|
Args:
|
|
@@ -53,15 +55,16 @@ class ReseedWrapper(Wrapper):
|
|
|
self.seed_idx = seed_idx
|
|
|
super().__init__(env)
|
|
|
|
|
|
- def reset(self, **kwargs):
|
|
|
- """Resets the environment with `kwargs`."""
|
|
|
+ def reset(
|
|
|
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
|
+ ) -> tuple[ObsType, dict[str, Any]]:
|
|
|
+ if seed is not None:
|
|
|
+ logger.warn(
|
|
|
+ "A seed has been passed to `ReseedWrapper.reset` which is ignored."
|
|
|
+ )
|
|
|
seed = self.seeds[self.seed_idx]
|
|
|
self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
|
|
|
- return self.env.reset(seed=seed, **kwargs)
|
|
|
-
|
|
|
- def step(self, action):
|
|
|
- """Steps through the environment with `action`."""
|
|
|
- return self.env.step(action)
|
|
|
+ return self.env.reset(seed=seed, options=options)
|
|
|
|
|
|
|
|
|
class ActionBonus(gym.Wrapper):
|
|
@@ -71,7 +74,6 @@ class ActionBonus(gym.Wrapper):
|
|
|
visited (state,action) pairs.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> from minigrid.wrappers import ActionBonus
|
|
|
>>> env = gym.make("MiniGrid-Empty-5x5-v0")
|
|
@@ -122,10 +124,6 @@ class ActionBonus(gym.Wrapper):
|
|
|
|
|
|
return obs, reward, terminated, truncated, info
|
|
|
|
|
|
- def reset(self, **kwargs):
|
|
|
- """Resets the environment with `kwargs`."""
|
|
|
- return self.env.reset(**kwargs)
|
|
|
-
|
|
|
|
|
|
class PositionBonus(Wrapper):
|
|
|
"""
|
|
@@ -136,7 +134,6 @@ class PositionBonus(Wrapper):
|
|
|
This wrapper was previously called ``StateBonus``.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> from minigrid.wrappers import PositionBonus
|
|
|
>>> env = gym.make("MiniGrid-Empty-5x5-v0")
|
|
@@ -189,17 +186,12 @@ class PositionBonus(Wrapper):
|
|
|
|
|
|
return obs, reward, terminated, truncated, info
|
|
|
|
|
|
- def reset(self, **kwargs):
|
|
|
- """Resets the environment with `kwargs`."""
|
|
|
- return self.env.reset(**kwargs)
|
|
|
-
|
|
|
|
|
|
class ImgObsWrapper(ObservationWrapper):
|
|
|
"""
|
|
|
Use the image as the only observation output, no language/mission.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> from minigrid.wrappers import ImgObsWrapper
|
|
|
>>> env = gym.make("MiniGrid-Empty-5x5-v0")
|
|
@@ -231,7 +223,6 @@ class OneHotPartialObsWrapper(ObservationWrapper):
|
|
|
agent view as observation.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> from minigrid.wrappers import OneHotPartialObsWrapper
|
|
|
>>> env = gym.make("MiniGrid-Empty-5x5-v0")
|
|
@@ -254,7 +245,7 @@ class OneHotPartialObsWrapper(ObservationWrapper):
|
|
|
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
|
|
|
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
|
|
|
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
|
|
|
- dtype=uint8)
|
|
|
+ dtype=uint8)
|
|
|
"""
|
|
|
|
|
|
def __init__(self, env, tile_size=8):
|
|
@@ -302,17 +293,16 @@ class RGBImgObsWrapper(ObservationWrapper):
|
|
|
This can be used to have the agent to solve the gridworld in pixel space.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> import matplotlib.pyplot as plt
|
|
|
>>> from minigrid.wrappers import RGBImgObsWrapper
|
|
|
>>> env = gym.make("MiniGrid-Empty-5x5-v0")
|
|
|
>>> obs, _ = env.reset()
|
|
|
- >>> plt.imshow(obs['image'])
|
|
|
+ >>> plt.imshow(obs['image']) # doctest: +SKIP
|
|
|

|
|
|
>>> env = RGBImgObsWrapper(env)
|
|
|
>>> obs, _ = env.reset()
|
|
|
- >>> plt.imshow(obs['image'])
|
|
|
+ >>> plt.imshow(obs['image']) # doctest: +SKIP
|
|
|

|
|
|
"""
|
|
|
|
|
@@ -344,21 +334,20 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
|
|
|
This can be used to have the agent to solve the gridworld in pixel space.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> import matplotlib.pyplot as plt
|
|
|
>>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
|
|
|
>>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
|
|
|
>>> obs, _ = env.reset()
|
|
|
- >>> plt.imshow(obs["image"])
|
|
|
+ >>> plt.imshow(obs["image"]) # doctest: +SKIP
|
|
|

|
|
|
>>> env_obs = RGBImgObsWrapper(env)
|
|
|
>>> obs, _ = env_obs.reset()
|
|
|
- >>> plt.imshow(obs["image"])
|
|
|
+ >>> plt.imshow(obs["image"]) # doctest: +SKIP
|
|
|

|
|
|
>>> env_obs = RGBImgPartialObsWrapper(env)
|
|
|
>>> obs, _ = env_obs.reset()
|
|
|
- >>> plt.imshow(obs["image"])
|
|
|
+ >>> plt.imshow(obs["image"]) # doctest: +SKIP
|
|
|

|
|
|
"""
|
|
|
|
|
@@ -391,7 +380,6 @@ class FullyObsWrapper(ObservationWrapper):
|
|
|
Fully observable gridworld using a compact grid encoding instead of the agent view.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> import matplotlib.pyplot as plt
|
|
|
>>> from minigrid.wrappers import FullyObsWrapper
|
|
@@ -437,7 +425,6 @@ class DictObservationSpaceWrapper(ObservationWrapper):
|
|
|
This wrapper is not applicable to BabyAI environments, given that these have their own language component.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> import matplotlib.pyplot as plt
|
|
|
>>> from minigrid.wrappers import DictObservationSpaceWrapper
|
|
@@ -571,7 +558,6 @@ class FlatObsWrapper(ObservationWrapper):
|
|
|
This wrapper is not applicable to BabyAI environments, given that these have their own language component.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> import matplotlib.pyplot as plt
|
|
|
>>> from minigrid.wrappers import FlatObsWrapper
|
|
@@ -643,9 +629,7 @@ class ViewSizeWrapper(ObservationWrapper):
|
|
|
This cannot be used with fully observable wrappers.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
- >>> import matplotlib.pyplot as plt
|
|
|
>>> from minigrid.wrappers import ViewSizeWrapper
|
|
|
>>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
|
|
|
>>> obs, _ = env.reset()
|
|
@@ -692,7 +676,6 @@ class DirectionObsWrapper(ObservationWrapper):
|
|
|
type = {slope , angle}
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
>>> import matplotlib.pyplot as plt
|
|
|
>>> from minigrid.wrappers import DirectionObsWrapper
|
|
@@ -708,8 +691,10 @@ class DirectionObsWrapper(ObservationWrapper):
|
|
|
self.goal_position: tuple = None
|
|
|
self.type = type
|
|
|
|
|
|
- def reset(self):
|
|
|
- obs, _ = self.env.reset()
|
|
|
+ def reset(
|
|
|
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
|
+ ) -> tuple[ObsType, dict[str, Any]]:
|
|
|
+ obs, info = self.env.reset(seed=seed, options=options)  # forward seed/options
|
|
|
|
|
|
if not self.goal_position:
|
|
|
self.goal_position = [
|
|
@@ -722,7 +707,7 @@ class DirectionObsWrapper(ObservationWrapper):
|
|
|
self.goal_position[0] % self.width,
|
|
|
)
|
|
|
|
|
|
- return self.observation(obs)
|
|
|
+ return self.observation(obs), info
|
|
|
|
|
|
def observation(self, obs):
|
|
|
slope = np.divide(
|
|
@@ -745,9 +730,7 @@ class SymbolicObsWrapper(ObservationWrapper):
|
|
|
the coordinates on the grid, and IDX is the id of the object.
|
|
|
|
|
|
Example:
|
|
|
- >>> import minigrid
|
|
|
>>> import gymnasium as gym
|
|
|
- >>> import matplotlib.pyplot as plt
|
|
|
>>> from minigrid.wrappers import SymbolicObsWrapper
|
|
|
>>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
|
|
|
>>> obs, _ = env.reset()
|