Browse Source

Modified fetch env for flat encoding

Maxime Chevalier-Boisvert 7 years ago
parent
commit
fad2fc94ee
1 changed files with 14 additions and 1 deletions
  1. 14 1
      gym_minigrid/envs/fetch.py

+ 14 - 1
gym_minigrid/envs/fetch.py

@@ -15,6 +15,13 @@ class FetchEnv(MiniGridEnv):
         self.numObjs = numObjs
         super().__init__(gridSize=size, maxSteps=5*size)
 
+        obsSize = OBS_ARRAY_SIZE[0]*OBS_ARRAY_SIZE[1]*OBS_ARRAY_SIZE[2]
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=obsSize + 2
+        )
+
     def _genGrid(self, width, height):
         assert width == height
         gridSz = width
@@ -83,11 +90,17 @@ class FetchEnv(MiniGridEnv):
         Encode observations
         """
 
+        """
         obs = {
             'image': obs,
             'mission': self.mission,
             'advice' : ''
         }
+        """
+
+        typeIdx = OBJECT_TO_IDX[self.targetType]
+        colorIdx= COLOR_TO_IDX[self.targetColor]
+        obs = np.hstack((obs.flatten(), [typeIdx, colorIdx]))
 
         return obs
 
@@ -104,7 +117,7 @@ class FetchEnv(MiniGridEnv):
                 reward = 1000 - self.stepCount
                 done = True
             else:
-                reward = -1
+                reward = -1000
                 done = True
 
         obs = self._observation(obs)