Sfoglia il codice sorgente

Completed PutNear environment

Maxime Chevalier-Boisvert 7 anni fa
parent
commit
4267b1d39e
1 file modificato con 37 aggiunte e 17 eliminazioni
  1. 37 17
      gym_minigrid/envs/putnear.py

+ 37 - 17
gym_minigrid/envs/putnear.py

@@ -73,18 +73,24 @@ class PutNearEnv(MiniGridEnv):
             objs.append((objType, objColor))
             objPos.append(pos)
 
-        # Choose a random object to be moved up
+        # Choose a random object to be moved
         objIdx = self._randInt(0, len(objs))
         self.moveType, self.moveColor = objs[objIdx]
         self.movePos = objPos[objIdx]
 
+        # Choose a target object (to put the first object next to)
+        while True:
+            targetIdx = self._randInt(0, len(objs))
+            if targetIdx != objIdx:
+                break
+        self.targetType, self.targetColor = objs[targetIdx]
+        self.targetPos = objPos[targetIdx]
 
-
-
-
-        self.mission = 'put the %s %s near the Y' % (
+        self.mission = 'put the %s %s near the %s %s' % (
             self.moveColor,
-            self.moveType
+            self.moveType,
+            self.targetColor,
+            self.targetType
         )
 
         return grid
@@ -106,28 +112,42 @@ class PutNearEnv(MiniGridEnv):
         return self._observation(obs)
 
     def _step(self, action):
+        preCarrying = self.carrying
+
         obs, reward, done, info = MiniGridEnv._step(self, action)
 
-        """
-        ax, ay = self.agentPos
+        u, v = self.getDirVec()
+        ox, oy = (self.agentPos[0] + u, self.agentPos[1] + v)
         tx, ty = self.targetPos
 
-        # Toggle/pickup action terminates the episode
+        # Pickup/drop action
         if action == self.actions.toggle:
-            done = True
-
-        # Reward performing the wait action next to the target object
-        if action == self.actions.wait:
-            if abs(ax - tx) <= 1 and abs(ay - ty) <= 1:
-                reward = 1
-            done = self.waitEnds
-        """
+            # If we picked up the wrong object, terminate the episode
+            if self.carrying:
+                if self.carrying.type != self.moveType or self.carrying.color != self.moveColor:
+                    done = True
+
+            # If successfully dropping an object near the target
+            if preCarrying:
+                if self.grid.get(ox, oy) is preCarrying:
+                    if abs(ox - tx) <= 1 and abs(oy - ty) <= 1:
+                        reward = 1
+                done = True
 
         obs = self._observation(obs)
 
         return obs, reward, done, info
 
+class PutNear8x8N3(PutNearEnv):
+    def __init__(self):
+        super().__init__(size=8, numObjs=3)
+
 register(
     id='MiniGrid-PutNear-6x6-N2-v0',
     entry_point='gym_minigrid.envs:PutNearEnv'
 )
+
+register(
+    id='MiniGrid-PutNear-8x8-N3-v0',
+    entry_point='gym_minigrid.envs:PutNear8x8N3'
+)