|
@@ -10,12 +10,15 @@ class GoToObjectEnv(MiniGridEnv):
|
|
|
def __init__(
|
|
|
self,
|
|
|
size=6,
|
|
|
- numObjs=2
|
|
|
+ numObjs=2,
|
|
|
+ waitEnds=True
|
|
|
):
|
|
|
self.numObjs = numObjs
|
|
|
super().__init__(gridSize=size, maxSteps=5*size)
|
|
|
|
|
|
- self.reward_range = (-1000, 1000)
|
|
|
+ self.reward_range = (-1, 1)
|
|
|
+
|
|
|
+ self.waitEnds = waitEnds
|
|
|
|
|
|
def _genGrid(self, width, height):
|
|
|
assert width == height
|
|
@@ -58,10 +61,12 @@ class GoToObjectEnv(MiniGridEnv):
|
|
|
self._randInt(1, gridSz - 1),
|
|
|
self._randInt(1, gridSz - 1)
|
|
|
)
|
|
|
-
|
|
|
- if pos != self.startPos:
|
|
|
- grid.set(*pos, obj)
|
|
|
- break
|
|
|
+ if grid.get(*pos) != None:
|
|
|
+ continue
|
|
|
+ if pos == self.startPos:
|
|
|
+ continue
|
|
|
+ grid.set(*pos, obj)
|
|
|
+ break
|
|
|
|
|
|
objs.append((objType, objColor))
|
|
|
objPos.append(pos)
|
|
@@ -99,13 +104,15 @@ class GoToObjectEnv(MiniGridEnv):
|
|
|
ax, ay = self.agentPos
|
|
|
tx, ty = self.targetPos
|
|
|
|
|
|
- # Reward being next to the object
|
|
|
- # Double reward waiting next to the object
|
|
|
- if abs(ax - tx) <= 1 and abs(ay - ty) <= 1:
|
|
|
- if action == self.actions.wait:
|
|
|
- reward = 2
|
|
|
- else:
|
|
|
+ # Toggle/pickup action terminates the episode
|
|
|
+ if action == self.actions.toggle:
|
|
|
+ done = True
|
|
|
+
|
|
|
+ # Reward performing the wait action next to the target object
|
|
|
+ if action == self.actions.wait:
|
|
|
+ if abs(ax - tx) <= 1 and abs(ay - ty) <= 1:
|
|
|
reward = 1
|
|
|
+ done = self.waitEnds
|
|
|
|
|
|
obs = self._observation(obs)
|
|
|
|