|
@@ -59,7 +59,7 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
)
|
|
|
|
|
|
def _genGrid(self, width, height):
|
|
|
- grid = Grid(width, height)
|
|
|
+ self.grid = Grid(width, height)
|
|
|
|
|
|
# Horizontal and vertical split indices
|
|
|
vSplitIdx = self._randInt(5, width-4)
|
|
@@ -99,28 +99,28 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
|
|
|
# Horizontal walls
|
|
|
for i in range(w):
|
|
|
- grid.set(x + i, y, Wall(room.color))
|
|
|
- grid.set(x + i, y + h - 1, Wall(room.color))
|
|
|
+ self.grid.set(x + i, y, Wall(room.color))
|
|
|
+ self.grid.set(x + i, y + h - 1, Wall(room.color))
|
|
|
|
|
|
# Vertical walls
|
|
|
for j in range(h):
|
|
|
- grid.set(x, y + j, Wall(room.color))
|
|
|
- grid.set(x + w - 1, y + j, Wall(room.color))
|
|
|
+ self.grid.set(x, y + j, Wall(room.color))
|
|
|
+ self.grid.set(x + w - 1, y + j, Wall(room.color))
|
|
|
|
|
|
# Place wall openings connecting the rooms
|
|
|
hIdx = self._randInt(1, hSplitIdx-1)
|
|
|
- grid.set(vSplitIdx, hIdx, None)
|
|
|
- grid.set(vSplitIdx-1, hIdx, None)
|
|
|
+ self.grid.set(vSplitIdx, hIdx, None)
|
|
|
+ self.grid.set(vSplitIdx-1, hIdx, None)
|
|
|
hIdx = self._randInt(hSplitIdx+1, height-1)
|
|
|
- grid.set(vSplitIdx, hIdx, None)
|
|
|
- grid.set(vSplitIdx-1, hIdx, None)
|
|
|
+ self.grid.set(vSplitIdx, hIdx, None)
|
|
|
+ self.grid.set(vSplitIdx-1, hIdx, None)
|
|
|
|
|
|
vIdx = self._randInt(1, vSplitIdx-1)
|
|
|
- grid.set(vIdx, hSplitIdx, None)
|
|
|
- grid.set(vIdx, hSplitIdx-1, None)
|
|
|
+ self.grid.set(vIdx, hSplitIdx, None)
|
|
|
+ self.grid.set(vIdx, hSplitIdx-1, None)
|
|
|
vIdx = self._randInt(vSplitIdx+1, width-1)
|
|
|
- grid.set(vIdx, hSplitIdx, None)
|
|
|
- grid.set(vIdx, hSplitIdx-1, None)
|
|
|
+ self.grid.set(vIdx, hSplitIdx, None)
|
|
|
+ self.grid.set(vIdx, hSplitIdx-1, None)
|
|
|
|
|
|
# Select a random position for the agent to start at
|
|
|
self.startDir = self._randInt(0, 4)
|
|
@@ -150,9 +150,9 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
pos = self._randPos(room, border=2)
|
|
|
if pos == self.startPos:
|
|
|
continue
|
|
|
- if grid.get(*pos) != None:
|
|
|
+ if self.grid.get(*pos) != None:
|
|
|
continue
|
|
|
- grid.set(*pos, obj)
|
|
|
+ self.grid.set(*pos, obj)
|
|
|
break
|
|
|
|
|
|
room.objects.append(obj)
|
|
@@ -182,11 +182,6 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
|
|
|
# TODO: how many X in the Y room question type
|
|
|
|
|
|
- #print(self.mission)
|
|
|
- #print(self.answer)
|
|
|
-
|
|
|
- return grid
|
|
|
-
|
|
|
def step(self, action):
|
|
|
if isinstance(action, dict):
|
|
|
answer = action['answer']
|