فهرست منبع

Fixed bug, improved reward function in GoToDoor env

Maxime Chevalier-Boisvert 7 سال پیش
والد
کامیت
1a80488ad0
1 فایلهای تغییر یافته به همراه 10 افزوده شده و 8 حذف شده
  1. 10 8
      gym_minigrid/envs/gotodoor.py

+ 10 - 8
gym_minigrid/envs/gotodoor.py

@@ -13,7 +13,10 @@ class GoToDoorEnv(MiniGridEnv):
     ):
     ):
         super().__init__(gridSize=size, maxSteps=10*size)
         super().__init__(gridSize=size, maxSteps=10*size)
 
 
-        self.reward_range = (-1, self.maxSteps)
+        self.reward_range = (-1000, 1000)
+
+        # Flag determining whether the wait action ends the episode
+        self.waitEnds = True
 
 
     def _genGrid(self, width, height):
     def _genGrid(self, width, height):
         assert width == height
         assert width == height
@@ -58,8 +61,8 @@ class GoToDoorEnv(MiniGridEnv):
 
 
         # Select a random target door
         # Select a random target door
         doorIdx = self._randInt(0, len(doorPos))
         doorIdx = self._randInt(0, len(doorPos))
-        self.targetPos = doorPos[idx]
-        self.targetColor = doorColors[idx]
+        self.targetPos = doorPos[doorIdx]
+        self.targetColor = doorColors[doorIdx]
 
 
         # Generate the mission string
         # Generate the mission string
         self.mission = 'go to the %s door' % self.targetColor
         self.mission = 'go to the %s door' % self.targetColor
@@ -97,12 +100,11 @@ class GoToDoorEnv(MiniGridEnv):
 
 
         # Reward waiting in front of the target door
         # Reward waiting in front of the target door
         if action == self.actions.wait:
         if action == self.actions.wait:
-            if ax == tx and abs(ay - ty) == 1:
-                reward = 1
-            elif ay == ty and abs(ax - tx) == 1:
+            if (ax == tx and abs(ay - ty) == 1) or (ay == ty and abs(ax - tx) == 1):
                 reward = 1
                 reward = 1
-            #else:
-            #    reward = -0.1
+            else:
+                reward = 0
+            done = self.waitEnds
 
 
         obs = self._observation(obs)
         obs = self._observation(obs)