|
@@ -127,16 +127,18 @@ class ActionBonus(gym.Wrapper):
|
|
|
return self.env.reset(**kwargs)
|
|
|
|
|
|
|
|
|
-# Should be named PositionBonus
|
|
|
-class StateBonus(Wrapper):
|
|
|
+class PositionBonus(Wrapper):
|
|
|
"""
|
|
|
Adds an exploration bonus based on which positions
|
|
|
are visited on the grid.
|
|
|
|
|
|
+ Note:
|
|
|
+ This wrapper was previously called ``StateBonus``.
|
|
|
+
|
|
|
Example:
|
|
|
>>> import miniworld
|
|
|
>>> import gymnasium as gym
|
|
|
- >>> from minigrid.wrappers import StateBonus
|
|
|
+ >>> from minigrid.wrappers import PositionBonus
|
|
|
>>> env = gym.make("MiniGrid-Empty-5x5-v0")
|
|
|
>>> _, _ = env.reset(seed=0)
|
|
|
>>> _, reward, _, _, _ = env.step(1)
|
|
@@ -145,7 +147,7 @@ class StateBonus(Wrapper):
|
|
|
>>> _, reward, _, _, _ = env.step(1)
|
|
|
>>> print(reward)
|
|
|
0
|
|
|
- >>> env_bonus = StateBonus(env)
|
|
|
+ >>> env_bonus = PositionBonus(env)
|
|
|
>>> obs, _ = env_bonus.reset(seed=0)
|
|
|
>>> obs, reward, terminated, truncated, info = env_bonus.step(1)
|
|
|
>>> print(reward)
|
|
@@ -688,6 +690,17 @@ class DirectionObsWrapper(ObservationWrapper):
|
|
|
"""
|
|
|
Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
|
|
|
type = {slope , angle}
|
|
|
+
|
|
|
+ Example:
|
|
|
+ >>> import miniworld
|
|
|
+ >>> import gymnasium as gym
|
|
|
+ >>> import matplotlib.pyplot as plt
|
|
|
+ >>> from minigrid.wrappers import DirectionObsWrapper
|
|
|
+ >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
|
|
|
+ >>> env_obs = DirectionObsWrapper(env, type="slope")
|
|
|
+ >>> obs, _ = env_obs.reset()
|
|
|
+ >>> obs['goal_direction']
|
|
|
+ 1.0
|
|
|
"""
|
|
|
|
|
|
def __init__(self, env, type="slope"):
|
|
@@ -696,7 +709,8 @@ class DirectionObsWrapper(ObservationWrapper):
|
|
|
self.type = type
|
|
|
|
|
|
def reset(self):
|
|
|
- obs = self.env.reset()
|
|
|
+ obs, _ = self.env.reset()
|
|
|
+
|
|
|
if not self.goal_position:
|
|
|
self.goal_position = [
|
|
|
x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
|
|
@@ -707,14 +721,20 @@ class DirectionObsWrapper(ObservationWrapper):
|
|
|
int(self.goal_position[0] / self.height),
|
|
|
self.goal_position[0] % self.width,
|
|
|
)
|
|
|
- return obs
|
|
|
+
|
|
|
+ return self.observation(obs)
|
|
|
|
|
|
def observation(self, obs):
|
|
|
slope = np.divide(
|
|
|
self.goal_position[1] - self.agent_pos[1],
|
|
|
self.goal_position[0] - self.agent_pos[0],
|
|
|
)
|
|
|
- obs["goal_direction"] = np.arctan(slope) if self.type == "angle" else slope
|
|
|
+
|
|
|
+ if self.type == "angle":
|
|
|
+ obs["goal_direction"] = np.arctan(slope)
|
|
|
+ else:
|
|
|
+ obs["goal_direction"] = slope
|
|
|
+
|
|
|
return obs
|
|
|
|
|
|
|
|
@@ -723,6 +743,20 @@ class SymbolicObsWrapper(ObservationWrapper):
|
|
|
Fully observable grid with a symbolic state representation.
|
|
|
The symbol is a triple of (X, Y, IDX), where X and Y are
|
|
|
the coordinates on the grid, and IDX is the id of the object.
|
|
|
+
|
|
|
+ Example:
|
|
|
+ >>> import miniworld
|
|
|
+ >>> import gymnasium as gym
|
|
|
+ >>> import matplotlib.pyplot as plt
|
|
|
+ >>> from minigrid.wrappers import SymbolicObsWrapper
|
|
|
+ >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
|
|
|
+ >>> obs, _ = env.reset()
|
|
|
+ >>> obs['image'].shape
|
|
|
+ (7, 7, 3)
|
|
|
+ >>> env_obs = SymbolicObsWrapper(env)
|
|
|
+ >>> obs, _ = env_obs.reset()
|
|
|
+ >>> obs['image'].shape
|
|
|
+ (11, 11, 3)
|
|
|
"""
|
|
|
|
|
|
def __init__(self, env):
|
|
@@ -749,4 +783,5 @@ class SymbolicObsWrapper(ObservationWrapper):
|
|
|
grid = np.transpose(grid, (1, 2, 0))
|
|
|
grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
|
|
|
obs["image"] = grid
|
|
|
+
|
|
|
return obs
|