|
@@ -8,10 +8,8 @@ from gym import spaces
|
|
|
|
|
|
from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
|
|
|
|
|
|
-from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
|
|
|
-
|
|
|
|
|
|
-class ReseedWrapper(Wrapper):
|
|
|
+class ReseedWrapper(gym.Wrapper):
|
|
|
"""
|
|
|
Wrapper to always regenerate an environment with the same set of seeds.
|
|
|
This can be used to force an environment to always keep the same
|
|
@@ -33,7 +31,7 @@ class ReseedWrapper(Wrapper):
|
|
|
return obs, reward, done, info
|
|
|
|
|
|
|
|
|
-class ActionBonus(Wrapper):
|
|
|
+class ActionBonus(gym.Wrapper):
|
|
|
"""
|
|
|
Wrapper which adds an exploration bonus.
|
|
|
This is a reward to encourage exploration of less
|
|
@@ -68,7 +66,7 @@ class ActionBonus(Wrapper):
|
|
|
return self.env.reset(**kwargs)
|
|
|
|
|
|
|
|
|
-class StateBonus(Wrapper):
|
|
|
+class StateBonus(gym.Wrapper):
|
|
|
"""
|
|
|
Adds an exploration bonus based on which positions
|
|
|
are visited on the grid.
|
|
@@ -104,7 +102,7 @@ class StateBonus(Wrapper):
|
|
|
return self.env.reset(**kwargs)
|
|
|
|
|
|
|
|
|
-class ImgObsWrapper(ObservationWrapper):
|
|
|
+class ImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Use the image as the only observation output, no language/mission.
|
|
|
"""
|
|
@@ -117,7 +115,7 @@ class ImgObsWrapper(ObservationWrapper):
|
|
|
return obs["image"]
|
|
|
|
|
|
|
|
|
-class OneHotPartialObsWrapper(ObservationWrapper):
|
|
|
+class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Wrapper to get a one-hot encoding of a partially observable
|
|
|
agent view as observation.
|
|
@@ -157,7 +155,7 @@ class OneHotPartialObsWrapper(ObservationWrapper):
|
|
|
return {**obs, "image": out}
|
|
|
|
|
|
|
|
|
-class RGBImgObsWrapper(ObservationWrapper):
|
|
|
+class RGBImgObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Wrapper to use fully observable RGB image as observation,
|
|
|
This can be used to have the agent to solve the gridworld in pixel space.
|
|
@@ -189,7 +187,7 @@ class RGBImgObsWrapper(ObservationWrapper):
|
|
|
return {**obs, "image": rgb_img}
|
|
|
|
|
|
|
|
|
-class RGBImgPartialObsWrapper(ObservationWrapper):
|
|
|
+class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Wrapper to use partially observable RGB image as observation.
|
|
|
This can be used to have the agent to solve the gridworld in pixel space.
|
|
@@ -220,7 +218,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
|
|
|
return {**obs, "image": rgb_img_partial}
|
|
|
|
|
|
|
|
|
-class FullyObsWrapper(ObservationWrapper):
|
|
|
+class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Fully observable gridworld using a compact grid encoding
|
|
|
"""
|
|
@@ -249,7 +247,7 @@ class FullyObsWrapper(ObservationWrapper):
|
|
|
return {**obs, "image": full_grid}
|
|
|
|
|
|
|
|
|
-class DictObservationSpaceWrapper(ObservationWrapper):
|
|
|
+class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Transforms the observation space (that has a textual component) to a fully numerical observation space,
|
|
|
where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
|
|
@@ -367,7 +365,7 @@ class DictObservationSpaceWrapper(ObservationWrapper):
|
|
|
return obs
|
|
|
|
|
|
|
|
|
-class FlatObsWrapper(ObservationWrapper):
|
|
|
+class FlatObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Encode mission strings using a one-hot scheme,
|
|
|
and combine these with observed images into one flat array
|
|
@@ -411,6 +409,10 @@ class FlatObsWrapper(ObservationWrapper):
|
|
|
chNo = ord(ch) - ord("a")
|
|
|
elif ch == " ":
|
|
|
chNo = ord("z") - ord("a") + 1
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ f"Character {ch} is not available in mission string."
|
|
|
+ )
|
|
|
assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
|
|
|
strArray[idx, chNo] = 1
|
|
|
|
|
@@ -422,7 +424,7 @@ class FlatObsWrapper(ObservationWrapper):
|
|
|
return obs
|
|
|
|
|
|
|
|
|
-class ViewSizeWrapper(Wrapper):
|
|
|
+class ViewSizeWrapper(gym.Wrapper):
|
|
|
"""
|
|
|
Wrapper to customize the agent field of view size.
|
|
|
This cannot be used with fully observable wrappers.
|
|
@@ -457,7 +459,7 @@ class ViewSizeWrapper(Wrapper):
|
|
|
return {**obs, "image": image}
|
|
|
|
|
|
|
|
|
-class DirectionObsWrapper(ObservationWrapper):
|
|
|
+class DirectionObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
|
|
|
type = {slope , angle}
|
|
@@ -491,7 +493,7 @@ class DirectionObsWrapper(ObservationWrapper):
|
|
|
return obs
|
|
|
|
|
|
|
|
|
-class SymbolicObsWrapper(ObservationWrapper):
|
|
|
+class SymbolicObsWrapper(gym.core.ObservationWrapper):
|
|
|
"""
|
|
|
Fully observable grid with a symbolic state representation.
|
|
|
The symbol is a triple of (X, Y, IDX), where X and Y are
|