Просмотр исходного кода

Made adjustments to GoToObject based on GoToDoor env

Maxime Chevalier-Boisvert 7 лет назад
Родитель
Сommit
73e6d3d2f1
1 измененных файлов с 19 добавлено и 12 удалено
  1. 19 12
      gym_minigrid/envs/gotoobject.py

+ 19 - 12
gym_minigrid/envs/gotoobject.py

@@ -10,12 +10,15 @@ class GoToObjectEnv(MiniGridEnv):
     def __init__(
         self,
         size=6,
-        numObjs=2
+        numObjs=2,
+        waitEnds=True
     ):
         self.numObjs = numObjs
         super().__init__(gridSize=size, maxSteps=5*size)
 
-        self.reward_range = (-1000, 1000)
+        self.reward_range = (-1, 1)
+
+        self.waitEnds = waitEnds
 
     def _genGrid(self, width, height):
         assert width == height
@@ -58,10 +61,12 @@ class GoToObjectEnv(MiniGridEnv):
                     self._randInt(1, gridSz - 1),
                     self._randInt(1, gridSz - 1)
                 )
-
-                if pos != self.startPos:
-                    grid.set(*pos, obj)
-                    break
+                if grid.get(*pos) != None:
+                    continue
+                if pos == self.startPos:
+                    continue
+                grid.set(*pos, obj)
+                break
 
             objs.append((objType, objColor))
             objPos.append(pos)
@@ -99,13 +104,15 @@ class GoToObjectEnv(MiniGridEnv):
         ax, ay = self.agentPos
         tx, ty = self.targetPos
 
-        # Reward being next to the object
-        # Double reward waiting next to the object
-        if abs(ax - tx) <= 1 and abs(ay - ty) <= 1:
-            if action == self.actions.wait:
-                reward = 2
-            else:
+        # Toggle/pickup action terminates the episode
+        if action == self.actions.toggle:
+            done = True
+
+        # Reward performing the wait action next to the target object
+        if action == self.actions.wait:
+            if abs(ax - tx) <= 1 and abs(ay - ty) <= 1:
                 reward = 1
+            done = self.waitEnds
 
         obs = self._observation(obs)