# fourroomqa.py
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class Room:
  4. def __init__(
  5. self,
  6. top,
  7. size,
  8. color,
  9. objects
  10. ):
  11. self.top = top
  12. self.size = size
  13. # Color of the room
  14. self.color = color
  15. # List of objects contained
  16. self.objects = objects
  17. class FourRoomQAEnv(MiniGridEnv):
  18. """
  19. Environment to experiment with embodied question answering
  20. https://arxiv.org/abs/1711.11543
  21. """
  22. # Enumeration of possible actions
  23. class Actions(IntEnum):
  24. left = 0
  25. right = 1
  26. forward = 2
  27. toggle = 3
  28. say = 4
  29. def __init__(self, size=16):
  30. assert size >= 10
  31. super(FourRoomQAEnv, self).__init__(gridSize=size, maxSteps=8*size)
  32. # Action enumeration for this environment
  33. self.actions = MiniGridEnv.Actions
  34. # TODO: dictionary action_space, to include answer sentence?
  35. # Actions are discrete integer values
  36. self.action_space = spaces.Discrete(len(self.actions))
  37. # TODO: dictionary observation_space, to include question?
  38. def _randPos(self, room, border=1):
  39. return (
  40. self._randInt(
  41. room.top[0] + border,
  42. room.top[0] + room.size[0] - border
  43. ),
  44. self._randInt(
  45. room.top[1] + border,
  46. room.top[1] + room.size[1] - border
  47. ),
  48. )
  49. def _genGrid(self, width, height):
  50. grid = Grid(width, height)
  51. # Horizontal and vertical split indices
  52. vSplitIdx = self._randInt(5, width-4)
  53. hSplitIdx = self._randInt(5, height-4)
  54. # Create the four rooms
  55. self.rooms = []
  56. self.rooms.append(Room(
  57. (0, 0),
  58. (vSplitIdx, hSplitIdx),
  59. 'red',
  60. []
  61. ))
  62. self.rooms.append(Room(
  63. (vSplitIdx, 0),
  64. (width - vSplitIdx, hSplitIdx),
  65. 'purple',
  66. []
  67. ))
  68. self.rooms.append(Room(
  69. (0, hSplitIdx),
  70. (vSplitIdx, height - hSplitIdx),
  71. 'blue',
  72. []
  73. ))
  74. self.rooms.append(Room(
  75. (vSplitIdx, hSplitIdx),
  76. (width - vSplitIdx, height - hSplitIdx),
  77. 'yellow',
  78. []
  79. ))
  80. # Place the room walls
  81. for room in self.rooms:
  82. x, y = room.top
  83. w, h = room.size
  84. # Horizontal walls
  85. for i in range(w):
  86. grid.set(x + i, y, Wall(room.color))
  87. grid.set(x + i, y + h - 1, Wall(room.color))
  88. # Vertical walls
  89. for j in range(h):
  90. grid.set(x, y + j, Wall(room.color))
  91. grid.set(x + w - 1, y + j, Wall(room.color))
  92. # Place wall openings connecting the rooms
  93. hIdx = self._randInt(1, hSplitIdx-1)
  94. grid.set(vSplitIdx, hIdx, None)
  95. grid.set(vSplitIdx-1, hIdx, None)
  96. hIdx = self._randInt(hSplitIdx+1, height-1)
  97. grid.set(vSplitIdx, hIdx, None)
  98. grid.set(vSplitIdx-1, hIdx, None)
  99. vIdx = self._randInt(1, vSplitIdx-1)
  100. grid.set(vIdx, hSplitIdx, None)
  101. grid.set(vIdx, hSplitIdx-1, None)
  102. vIdx = self._randInt(vSplitIdx+1, width-1)
  103. grid.set(vIdx, hSplitIdx, None)
  104. grid.set(vIdx, hSplitIdx-1, None)
  105. # Select a random position for the agent to start at
  106. self.startDir = self._randInt(0, 4)
  107. room = self._randElem(self.rooms)
  108. self.startPos = self._randPos(room)
  109. # Possible object types and colors
  110. types = ['key', 'ball']
  111. colors = list(COLORS.keys())
  112. # Place a number of random objects
  113. numObjs = self._randInt(1, 10)
  114. for i in range(0, numObjs):
  115. # Generate a random object
  116. objType = self._randElem(types)
  117. objColor = self._randElem(colors)
  118. if objType == 'key':
  119. obj = Key(objColor)
  120. elif objType == 'ball':
  121. obj = Ball(objColor)
  122. # Pick a random position that doesn't overlap with anything
  123. while True:
  124. room = self._randElem(self.rooms)
  125. pos = self._randPos(room, border=2)
  126. if pos == self.startPos:
  127. continue
  128. if grid.get(*pos) != None:
  129. continue
  130. grid.set(*pos, obj)
  131. break
  132. room.objects.append(obj)
  133. # Question examples:
  134. # - What color is the X?
  135. # - What color is the X in the ROOM?
  136. # - What room is the X located in?
  137. # - What color is the X in the blue room?
  138. # - How many rooms contain chairs?
  139. # - How many keys are there in the yellow room?
  140. # - How many <OBJs> in the <ROOM>?
  141. # Pick a random room to be the subject of the question
  142. room = self._randElem(self.rooms)
  143. # Pick a random object type
  144. objType = self._randElem(types)
  145. # Count the number of objects of this type in the room
  146. count = len(list(filter(lambda o: o.type == objType, room.objects)))
  147. # TODO: identify unique objects
  148. self.question = "Are there any %ss in the %s room?" % (objType, room.color)
  149. self.answer = "yes" if count > 0 else "no"
  150. # TODO: how many X in the Y room question type
  151. print(self.question)
  152. print(self.answer)
  153. return grid
  154. def _reset(self):
  155. obs = MiniGridEnv._reset(self)
  156. obs = {
  157. 'image': obs,
  158. 'question': self.question
  159. }
  160. return obs
  161. def _step(self, action):
  162. if isinstance(action, dict):
  163. answer = action['answer']
  164. action = action['action']
  165. else:
  166. answer = ''
  167. obs, reward, done, info = MiniGridEnv._step(self, action)
  168. if answer == self.answer:
  169. reward = 1000 - self.stepCount
  170. done = True
  171. obs = {
  172. 'image': obs,
  173. 'question': self.question
  174. }
  175. return obs, reward, done, info
  176. register(
  177. id='MiniGrid-FourRoomQA-v0',
  178. entry_point='gym_minigrid.envs:FourRoomQAEnv'
  179. )