浏览代码

FourRoomQA environment now generates one question type

Maxime Chevalier-Boisvert 7 年之前
父节点
当前提交
ebdcd5aba2
共有 1 个文件被更改,包括 71 次插入24 次删除
  1. 71 24
      gym_minigrid/envs/fourroomqa.py

+ 71 - 24
gym_minigrid/envs/fourroomqa.py

@@ -45,6 +45,18 @@ class FourRoomQAEnv(MiniGridEnv):
 
         # TODO: dictionary observation_space, to include question?
 
+    def _randPos(self, room, border=1):
+        return (
+            self._randInt(
+                room.top[0] + border,
+                room.top[0] + room.size[0] - border
+            ),
+            self._randInt(
+                room.top[1] + border,
+                room.top[1] + room.size[1] - border
+            ),
+        )
+
     def _genGrid(self, width, height):
         grid = Grid(width, height)
 
@@ -112,22 +124,35 @@ class FourRoomQAEnv(MiniGridEnv):
         # Select a random position for the agent to start at
         self.startDir = self._randInt(0, 4)
         room = self._randElem(self.rooms)
-        self.startPos = (
-            self._randInt(room.top[0] + 1, room.top[0] + room.size[0] - 1),
-            self._randInt(room.top[1] + 1, room.top[1] + room.size[1] - 1),
-        )
-
-
-
-
-
-        # TODO: pick a random room to be the subject of the question
-        # TODO: identify unique objects
-
-
-        # TODO:
-        # Generate a question and answer
-        self.question = ''
+        self.startPos = self._randPos(room)
+
+        # Possible object types and colors
+        types = ['key', 'ball']
+        colors = list(COLORS.keys())
+
+        # Place a number of random objects
+        numObjs = self._randInt(1, 10)
+        for i in range(0, numObjs):
+            # Generate a random object
+            objType = self._randElem(types)
+            objColor = self._randElem(colors)
+            if objType == 'key':
+                obj = Key(objColor)
+            elif objType == 'ball':
+                obj = Ball(objColor)
+
+            # Pick a random position that doesn't overlap with anything
+            while True:
+                room = self._randElem(self.rooms)
+                pos = self._randPos(room, border=2)
+                if pos == self.startPos:
+                    continue
+                if grid.get(*pos) != None:
+                    continue
+                grid.set(*pos, obj)
+                break
+
+            room.objects.append(obj)
 
         # Question examples:
         # - What color is the X?
@@ -138,15 +163,33 @@ class FourRoomQAEnv(MiniGridEnv):
         # - How many keys are there in the yellow room?
         # - How many <OBJs> in the <ROOM>?
 
-        #self.answer
+        # Pick a random room to be the subject of the question
+        room = self._randElem(self.rooms)
 
+        # Pick a random object type
+        objType = self._randElem(types)
 
+        # Count the number of objects of this type in the room
+        count = len(list(filter(lambda o: o.type == objType, room.objects)))
 
+        # TODO: identify unique objects
 
+        self.question = "Are there any %ss in the %s room?" % (objType, room.color)
+        self.answer = "yes" if count > 0 else "no"
+
+        # TODO: how many X in the Y room question type
 
-        return grid
 
 
+
+
+
+
+        print(self.question)
+        print(self.answer)
+
+        return grid
+
     def _reset(self):
         obs = MiniGridEnv._reset(self)
 
@@ -158,8 +201,18 @@ class FourRoomQAEnv(MiniGridEnv):
         return obs
 
     def _step(self, action):
+        if isinstance(action, dict):
+            answer = action['answer']
+            action = action['action']
+        else:
+            answer = ''
+
         obs, reward, done, info = MiniGridEnv._step(self, action)
 
+        if answer == self.answer:
+            reward = 1000 - self.stepCount
+            done = True
+
         obs = {
             'image': obs,
             'question': self.question
@@ -167,12 +220,6 @@ class FourRoomQAEnv(MiniGridEnv):
 
         return obs, reward, done, info
 
-
-
-
-
-
-
 register(
     id='MiniGrid-FourRoomQA-v0',
     entry_point='gym_minigrid.envs:FourRoomQAEnv'