浏览代码

Modified fetch env for flat encoding

Maxime Chevalier-Boisvert 7 年之前
父节点
当前提交
fad2fc94ee
共有 1 个文件被更改,包括 14 次插入和 1 次删除
  1. 14 1
      gym_minigrid/envs/fetch.py

+ 14 - 1
gym_minigrid/envs/fetch.py

@@ -15,6 +15,13 @@ class FetchEnv(MiniGridEnv):
         self.numObjs = numObjs
         self.numObjs = numObjs
         super().__init__(gridSize=size, maxSteps=5*size)
         super().__init__(gridSize=size, maxSteps=5*size)
 
 
+        obsSize = OBS_ARRAY_SIZE[0]*OBS_ARRAY_SIZE[1]*OBS_ARRAY_SIZE[2]
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=obsSize + 2
+        )
+
     def _genGrid(self, width, height):
     def _genGrid(self, width, height):
         assert width == height
         assert width == height
         gridSz = width
         gridSz = width
@@ -83,11 +90,17 @@ class FetchEnv(MiniGridEnv):
         Encode observations
         Encode observations
         """
         """
 
 
+        """
         obs = {
         obs = {
             'image': obs,
             'image': obs,
             'mission': self.mission,
             'mission': self.mission,
             'advice' : ''
             'advice' : ''
         }
         }
+        """
+
+        typeIdx = OBJECT_TO_IDX[self.targetType]
+        colorIdx= COLOR_TO_IDX[self.targetColor]
+        obs = np.hstack((obs.flatten(), [typeIdx, colorIdx]))
 
 
         return obs
         return obs
 
 
@@ -104,7 +117,7 @@ class FetchEnv(MiniGridEnv):
                 reward = 1000 - self.stepCount
                 reward = 1000 - self.stepCount
                 done = True
                 done = True
             else:
             else:
-                reward = -1
+                reward = -1000
                 done = True
                 done = True
 
 
         obs = self._observation(obs)
         obs = self._observation(obs)