|
@@ -45,6 +45,18 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
|
|
|
# TODO: dictionary observation_space, to include question?
|
|
|
|
|
|
+ def _randPos(self, room, border=1):
|
|
|
+ return (
|
|
|
+ self._randInt(
|
|
|
+ room.top[0] + border,
|
|
|
+ room.top[0] + room.size[0] - border
|
|
|
+ ),
|
|
|
+ self._randInt(
|
|
|
+ room.top[1] + border,
|
|
|
+ room.top[1] + room.size[1] - border
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
def _genGrid(self, width, height):
|
|
|
grid = Grid(width, height)
|
|
|
|
|
@@ -112,22 +124,35 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
# Select a random position for the agent to start at
|
|
|
self.startDir = self._randInt(0, 4)
|
|
|
room = self._randElem(self.rooms)
|
|
|
- self.startPos = (
|
|
|
- self._randInt(room.top[0] + 1, room.top[0] + room.size[0] - 1),
|
|
|
- self._randInt(room.top[1] + 1, room.top[1] + room.size[1] - 1),
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- # TODO: pick a random room to be the subject of the question
|
|
|
- # TODO: identify unique objects
|
|
|
-
|
|
|
-
|
|
|
- # TODO:
|
|
|
- # Generate a question and answer
|
|
|
- self.question = ''
|
|
|
+ self.startPos = self._randPos(room)
|
|
|
+
|
|
|
+ # Possible object types and colors
|
|
|
+ types = ['key', 'ball']
|
|
|
+ colors = list(COLORS.keys())
|
|
|
+
|
|
|
+ # Place a number of random objects
|
|
|
+ numObjs = self._randInt(1, 10)
|
|
|
+ for i in range(0, numObjs):
|
|
|
+ # Generate a random object
|
|
|
+ objType = self._randElem(types)
|
|
|
+ objColor = self._randElem(colors)
|
|
|
+ if objType == 'key':
|
|
|
+ obj = Key(objColor)
|
|
|
+ elif objType == 'ball':
|
|
|
+ obj = Ball(objColor)
|
|
|
+
|
|
|
+ # Pick a random position that doesn't overlap with anything
|
|
|
+ while True:
|
|
|
+ room = self._randElem(self.rooms)
|
|
|
+ pos = self._randPos(room, border=2)
|
|
|
+ if pos == self.startPos:
|
|
|
+ continue
|
|
|
+ if grid.get(*pos) != None:
|
|
|
+ continue
|
|
|
+ grid.set(*pos, obj)
|
|
|
+ break
|
|
|
+
|
|
|
+ room.objects.append(obj)
|
|
|
|
|
|
# Question examples:
|
|
|
# - What color is the X?
|
|
@@ -138,15 +163,33 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
# - How many keys are there in the yellow room?
|
|
|
# - How many <OBJs> in the <ROOM>?
|
|
|
|
|
|
- #self.answer
|
|
|
+ # Pick a random room to be the subject of the question
|
|
|
+ room = self._randElem(self.rooms)
|
|
|
|
|
|
+ # Pick a random object type
|
|
|
+ objType = self._randElem(types)
|
|
|
|
|
|
+ # Count the number of objects of this type in the room
|
|
|
+ count = len(list(filter(lambda o: o.type == objType, room.objects)))
|
|
|
|
|
|
+ # TODO: identify unique objects
|
|
|
|
|
|
+ self.question = "Are there any %ss in the %s room?" % (objType, room.color)
|
|
|
+ self.answer = "yes" if count > 0 else "no"
|
|
|
+
|
|
|
+ # TODO: how many X in the Y room question type
|
|
|
|
|
|
- return grid
|
|
|
|
|
|
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ print(self.question)
|
|
|
+ print(self.answer)
|
|
|
+
|
|
|
+ return grid
|
|
|
+
|
|
|
def _reset(self):
|
|
|
obs = MiniGridEnv._reset(self)
|
|
|
|
|
@@ -158,8 +201,18 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
return obs
|
|
|
|
|
|
def _step(self, action):
    """
    Advance the environment by one step, optionally checking an answer.

    :param action: either a plain action that is forwarded unchanged to
        ``MiniGridEnv._step``, or a dict with an ``'action'`` entry (the
        action to forward) and an optional ``'answer'`` entry (the
        agent's answer to ``self.question``)
    :return: tuple ``(obs, reward, done, info)`` where ``obs`` is a dict
        holding the base observation under ``'image'`` and the current
        question under ``'question'``
    """
    if isinstance(action, dict):
        # A dict action without an answer is treated like a plain action
        # instead of raising KeyError.
        answer = action.get('answer', '')
        action = action['action']
    else:
        answer = ''

    obs, reward, done, info = MiniGridEnv._step(self, action)

    # A matching answer ends the episode; the reward is 1000 minus the
    # current step count.
    if answer == self.answer:
        reward = 1000 - self.stepCount
        done = True

    obs = {
        'image': obs,
        'question': self.question
    }

    return obs, reward, done, info
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
register(
|
|
|
id='MiniGrid-FourRoomQA-v0',
|
|
|
entry_point='gym_minigrid.envs:FourRoomQAEnv'
|