|
@@ -73,18 +73,24 @@ class PutNearEnv(MiniGridEnv):
|
|
|
objs.append((objType, objColor))
|
|
|
objPos.append(pos)
|
|
|
|
|
|
- # Choose a random object to be moved up
|
|
|
+ # Choose a random object to be moved
|
|
|
objIdx = self._randInt(0, len(objs))
|
|
|
self.moveType, self.moveColor = objs[objIdx]
|
|
|
self.movePos = objPos[objIdx]
|
|
|
|
|
|
+ # Choose a target object (to put the first object next to)
|
|
|
+ while True:
|
|
|
+ targetIdx = self._randInt(0, len(objs))
|
|
|
+ if targetIdx != objIdx:
|
|
|
+ break
|
|
|
+ self.targetType, self.targetColor = objs[targetIdx]
|
|
|
+ self.targetPos = objPos[targetIdx]
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- self.mission = 'put the %s %s near the Y' % (
|
|
|
+ self.mission = 'put the %s %s near the %s %s' % (
|
|
|
self.moveColor,
|
|
|
- self.moveType
|
|
|
+ self.moveType,
|
|
|
+ self.targetColor,
|
|
|
+ self.targetType
|
|
|
)
|
|
|
|
|
|
return grid
|
|
@@ -106,28 +112,42 @@ class PutNearEnv(MiniGridEnv):
|
|
|
return self._observation(obs)
|
|
|
|
|
|
def _step(self, action):
# Run one environment step through MiniGridEnv._step, then apply the
# PutNear-specific checks that follow the toggle (pickup/drop) test below.
# (ox, oy) is the agent position offset by the vector from getDirVec();
# (tx, ty) is self.targetPos.
# Returns (obs, reward, done, info), with obs passed through self._observation.
# NOTE(review): this text is a rendered diff -- '+'/'-' are patch markers and
# leading indentation has been lost, so statement nesting must be confirmed
# against the real file.
|
|
|
+ preCarrying = self.carrying  # captured before the base step: the object held prior to this action, if any
|
|
|
+
|
|
|
obs, reward, done, info = MiniGridEnv._step(self, action)
|
|
|
|
|
|
- """
|
|
|
- ax, ay = self.agentPos
|
|
|
+ u, v = self.getDirVec()
|
|
|
+ ox, oy = (self.agentPos[0] + u, self.agentPos[1] + v)
|
|
|
tx, ty = self.targetPos
|
|
|
|
|
|
- # Toggle/pickup action terminates the episode
|
|
|
+ # Pickup/drop action
|
|
|
if action == self.actions.toggle:
|
|
|
- done = True
|
|
|
-
|
|
|
- # Reward performing the wait action next to the target object
|
|
|
- if action == self.actions.wait:
|
|
|
- if abs(ax - tx) <= 1 and abs(ay - ty) <= 1:
|
|
|
- reward = 1
|
|
|
- done = self.waitEnds
|
|
|
- """
|
|
|
+ # If the object now carried does not match moveType/moveColor (the wrong one was picked up), terminate the episode
|
|
|
+ if self.carrying:
|
|
|
+ if self.carrying.type != self.moveType or self.carrying.color != self.moveColor:
|
|
|
+ done = True
|
|
|
+
|
|
|
+ # If the previously carried object is now in the grid cell at (ox, oy) (it was dropped) and that cell is within one cell of the target in both x and y, give reward 1
|
|
|
+ if preCarrying:
|
|
|
+ if self.grid.get(ox, oy) is preCarrying:
|
|
|
+ if abs(ox - tx) <= 1 and abs(oy - ty) <= 1:
|
|
|
+ reward = 1
|
|
|
+ done = True  # NOTE(review): indentation is lost here -- confirm whether this ends the episode on any drop or only on a drop near the target
|
|
|
|
|
|
obs = self._observation(obs)
|
|
|
|
|
|
return obs, reward, done, info
|
|
|
|
|
|
# Preconfigured PutNearEnv variant: its constructor takes no arguments and
# forwards size=8, numObjs=3 to the parent constructor. It is the entry point
# of the 'MiniGrid-PutNear-8x8-N3-v0' registration further down.
+class PutNear8x8N3(PutNearEnv):
|
|
|
+ def __init__(self):
|
|
|
+ super().__init__(size=8, numObjs=3)  # fixed configuration; nothing is exposed to the caller
|
|
|
+
|
|
|
# Register the base PutNearEnv class under the 6x6 / N2 id.
# NOTE(review): presumably PutNearEnv's constructor defaults are size=6 and
# numObjs=2 to match this id -- the constructor is not visible here, confirm.
register(
|
|
|
id='MiniGrid-PutNear-6x6-N2-v0',
|
|
|
entry_point='gym_minigrid.envs:PutNearEnv'
|
|
|
)
|
|
|
+
|
|
|
# Register the PutNear8x8N3 subclass (size=8, numObjs=3) under its own id.
+register(
|
|
|
+ id='MiniGrid-PutNear-8x8-N3-v0',
|
|
|
+ entry_point='gym_minigrid.envs:PutNear8x8N3'
|
|
|
+)
|