|
@@ -5,7 +5,7 @@ from functools import reduce
|
|
|
import numpy as np
|
|
|
import gym
|
|
|
from gym import error, spaces, utils
|
|
|
-from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
|
|
|
+from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal
|
|
|
|
|
|
class ReseedWrapper(gym.core.Wrapper):
|
|
|
"""
|
|
@@ -331,7 +331,6 @@ class ViewSizeWrapper(gym.core.Wrapper):
|
|
|
def step(self, action):
|
|
|
return self.env.step(action)
|
|
|
|
|
|
-from .minigrid import Goal
|
|
|
class DirectionObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
|
|
@@ -354,3 +353,31 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
|
|
|
slope = np.divide( self.goal_position[1] - self.agent_pos[1] , self.goal_position[0] - self.agent_pos[0])
|
|
|
obs['goal_direction'] = np.arctan( slope ) if self.type == 'angle' else slope
|
|
|
return obs
|
|
|
+
|
|
|
+class SymbolicObsWrapper(gym.core.ObservationWrapper):
|
|
|
+ """
|
|
|
+ Fully observable grid with a symbolic state representation.
|
|
|
+ The symbol is a triple of (X, Y, IDX), where X and Y are
|
|
|
+ the coordinates on the grid, and IDX is the id of the object.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, env):
|
|
|
+ super().__init__(env)
|
|
|
+
|
|
|
+ self.observation_space.spaces["image"] = spaces.Box(
|
|
|
+ low=0,
|
|
|
+ high=max(OBJECT_TO_IDX.values()),
|
|
|
+ shape=(self.env.width, self.env.height, 3),
|
|
|
+ dtype="uint8",
|
|
|
+ )
|
|
|
+
|
|
|
+ def observation(self, obs):
|
|
|
+ objects = np.array(
|
|
|
+ [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
|
|
|
+ )
|
|
|
+ w, h = self.width, self.height
|
|
|
+ grid = np.mgrid[:w, :h]
|
|
|
+ grid = np.concatenate([grid, objects.reshape(1, w, h)])
|
|
|
+ grid = np.transpose(grid, (1, 2, 0))
|
|
|
+ obs['image'] = grid
|
|
|
+ return obs
|