|
@@ -15,6 +15,13 @@ class FetchEnv(MiniGridEnv):
|
|
|
self.numObjs = numObjs
|
|
|
super().__init__(gridSize=size, maxSteps=5*size)
|
|
|
|
|
|
+ obsSize = OBS_ARRAY_SIZE[0]*OBS_ARRAY_SIZE[1]*OBS_ARRAY_SIZE[2]
|
|
|
+ self.observation_space = spaces.Box(
|
|
|
+ low=0,
|
|
|
+ high=255,
|
|
|
+ shape=obsSize + 2
|
|
|
+ )
|
|
|
+
|
|
|
def _genGrid(self, width, height):
|
|
|
assert width == height
|
|
|
gridSz = width
|
|
@@ -83,11 +90,17 @@ class FetchEnv(MiniGridEnv):
|
|
|
Encode observations
|
|
|
"""
|
|
|
|
|
|
+ """
|
|
|
obs = {
|
|
|
'image': obs,
|
|
|
'mission': self.mission,
|
|
|
'advice' : ''
|
|
|
}
|
|
|
+ """
|
|
|
+
|
|
|
+ typeIdx = OBJECT_TO_IDX[self.targetType]
|
|
|
+ colorIdx= COLOR_TO_IDX[self.targetColor]
|
|
|
+ obs = np.hstack((obs.flatten(), [typeIdx, colorIdx]))
|
|
|
|
|
|
return obs
|
|
|
|
|
@@ -104,7 +117,7 @@ class FetchEnv(MiniGridEnv):
|
|
|
reward = 1000 - self.stepCount
|
|
|
done = True
|
|
|
else:
|
|
|
- reward = -1
|
|
|
+ reward = -1000
|
|
|
done = True
|
|
|
|
|
|
obs = self._observation(obs)
|