| 
					
				 | 
			
			
				@@ -45,6 +45,18 @@ class FourRoomQAEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # TODO: dictionary observation_space, to include question? 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _randPos(self, room, border=1): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self._randInt( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                room.top[0] + border, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                room.top[0] + room.size[0] - border 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self._randInt( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                room.top[1] + border, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                room.top[1] + room.size[1] - border 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _genGrid(self, width, height): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         grid = Grid(width, height) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -112,22 +124,35 @@ class FourRoomQAEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Select a random position for the agent to start at 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.startDir = self._randInt(0, 4) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         room = self._randElem(self.rooms) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.startPos = ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            self._randInt(room.top[0] + 1, room.top[0] + room.size[0] - 1), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            self._randInt(room.top[1] + 1, room.top[1] + room.size[1] - 1), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # TODO: pick a random room to be the subject of the question 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # TODO: identify unique objects 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # TODO: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # Generate a question and answer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.question = '' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.startPos = self._randPos(room) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Possible object types and colors 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        types = ['key', 'ball'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        colors = list(COLORS.keys()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Place a number of random objects 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        numObjs = self._randInt(1, 10) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for i in range(0, numObjs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # Generate a random object 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            objType = self._randElem(types) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            objColor = self._randElem(colors) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if objType == 'key': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                obj = Key(objColor) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            elif objType == 'ball': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                obj = Ball(objColor) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # Pick a random position that doesn't overlap with anything 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            while True: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                room = self._randElem(self.rooms) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                pos = self._randPos(room, border=2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if pos == self.startPos: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if grid.get(*pos) != None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                grid.set(*pos, obj) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            room.objects.append(obj) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Question examples: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # - What color is the X? 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -138,15 +163,33 @@ class FourRoomQAEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # - How many keys are there in the yellow room? 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # - How many <OBJs> in the <ROOM>? 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #self.answer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Pick a random room to be the subject of the question 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        room = self._randElem(self.rooms) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Pick a random object type 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        objType = self._randElem(types) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Count the number of objects of this type in the room 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        count = len(list(filter(lambda o: o.type == objType, room.objects))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # TODO: identify unique objects 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.question = "Are there any %ss in the %s room?" % (objType, room.color) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.answer = "yes" if count > 0 else "no" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # TODO: how many X in the Y room question type 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return grid 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print(self.question) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print(self.answer) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return grid 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _reset(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         obs = MiniGridEnv._reset(self) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -158,8 +201,18 @@ class FourRoomQAEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return obs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _step(self, action): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if isinstance(action, dict): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            answer = action['answer'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            action = action['action'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            answer = '' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         obs, reward, done, info = MiniGridEnv._step(self, action) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if answer == self.answer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            reward = 1000 - self.stepCount 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            done = True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         obs = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             'image': obs, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             'question': self.question 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -167,12 +220,6 @@ class FourRoomQAEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return obs, reward, done, info 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 register( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     id='MiniGrid-FourRoomQA-v0', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     entry_point='gym_minigrid.envs:FourRoomQAEnv' 
			 |