# fourroomqa.py

  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  class Room:
      """Record describing one rectangular room of the grid."""

      def __init__(self, top, size, color, objects):
          # Top-left corner of the room as an (x, y) pair, and its
          # (width, height) extent; both are indexed as pairs by the
          # code that builds the grid.
          self.top, self.size = top, size

          # Color of the room
          self.color = color

          # List of objects contained (stored by reference, not copied)
          self.objects = objects
  class FourRoomQAEnv(MiniGridEnv):
      """
      Environment to experiment with embodied question answering
      https://arxiv.org/abs/1711.11543

      The grid is divided into four colored rooms connected by wall
      openings. A handful of random objects is scattered through the rooms
      and a yes/no question about the contents of one room is generated.
      Observations are dicts with the superclass observation under 'image'
      and the question text under 'question'. The episode ends when the
      agent issues the `answer` action.
      """

      # Enumeration of possible actions
      class Actions(IntEnum):
          left = 0
          right = 1
          forward = 2
          toggle = 3
          wait = 4
          answer = 5

      def __init__(self, size=16):
          """
          Create the environment.

          :param size: side length passed to the superclass as `gridSize`;
              must be at least 10. The step limit is 8 * size.
          """
          assert size >= 10
          super().__init__(gridSize=size, maxSteps=8*size)

          # Action enumeration for this environment
          self.actions = FourRoomQAEnv.Actions

          # TODO: dictionary action_space, to include answer sentence?
          # Actions are discrete integer values
          self.action_space = spaces.Discrete(len(self.actions))

          # TODO: dictionary observation_space, to include question?

          # Matches the extreme rewards handed out by _step for an answer
          self.reward_range = (-1000, 1000)

      def _randPos(self, room, border=1):
          """
          Pick a random (x, y) position inside `room`.

          :param room: Room whose `top` and `size` bound the sampling area
          :param border: number of cells to stay away from the room's edges
          :return: (x, y) tuple drawn with `self._randInt`; the exact bound
              semantics are those of `_randInt`, which is defined outside
              this class
          """
          return (
              self._randInt(
                  room.top[0] + border,
                  room.top[0] + room.size[0] - border
              ),
              self._randInt(
                  room.top[1] + border,
                  room.top[1] + room.size[1] - border
              ),
          )

      def _withQuestion(self, image):
          """Wrap a superclass observation together with the question."""
          return {
              'image': image,
              'question': self.question
          }

      def _genGrid(self, width, height):
          """
          Build a fresh grid: four rooms, wall openings, the agent start
          pose, random objects, and the question/answer pair.

          Sets `self.rooms`, `self.startDir`, `self.startPos`,
          `self.question` and `self.answer`.

          :param width: grid width in cells
          :param height: grid height in cells
          :return: the populated Grid
          """
          grid = Grid(width, height)

          # Horizontal and vertical split indices
          vSplitIdx = self._randInt(5, width-4)
          hSplitIdx = self._randInt(5, height-4)

          # Create the four rooms
          self.rooms = [
              Room(
                  (0, 0),
                  (vSplitIdx, hSplitIdx),
                  'red',
                  []
              ),
              Room(
                  (vSplitIdx, 0),
                  (width - vSplitIdx, hSplitIdx),
                  'purple',
                  []
              ),
              Room(
                  (0, hSplitIdx),
                  (vSplitIdx, height - hSplitIdx),
                  'blue',
                  []
              ),
              Room(
                  (vSplitIdx, hSplitIdx),
                  (width - vSplitIdx, height - hSplitIdx),
                  'yellow',
                  []
              ),
          ]

          # Place the room walls
          for room in self.rooms:
              x, y = room.top
              w, h = room.size

              # Horizontal walls
              for i in range(w):
                  grid.set(x + i, y, Wall(room.color))
                  grid.set(x + i, y + h - 1, Wall(room.color))

              # Vertical walls
              for j in range(h):
                  grid.set(x, y + j, Wall(room.color))
                  grid.set(x + w - 1, y + j, Wall(room.color))

          # Place wall openings connecting the rooms. Adjacent rooms each
          # have their own wall, so every opening clears two cells.
          # The order of the random draws is kept as: upper vertical
          # opening, lower vertical opening, left horizontal opening,
          # right horizontal opening.
          for low, high in ((1, hSplitIdx-1), (hSplitIdx+1, height-1)):
              hIdx = self._randInt(low, high)
              grid.set(vSplitIdx, hIdx, None)
              grid.set(vSplitIdx-1, hIdx, None)

          for low, high in ((1, vSplitIdx-1), (vSplitIdx+1, width-1)):
              vIdx = self._randInt(low, high)
              grid.set(vIdx, hSplitIdx, None)
              grid.set(vIdx, hSplitIdx-1, None)

          # Select a random position for the agent to start at
          self.startDir = self._randInt(0, 4)
          room = self._randElem(self.rooms)
          self.startPos = self._randPos(room)

          # Possible object types and colors
          types = ['key', 'ball', 'box']
          constructors = {'key': Key, 'ball': Ball, 'box': Box}
          colors = list(COLORS.keys())

          # Place a number of random objects
          numObjs = self._randInt(1, 10)
          for _ in range(numObjs):
              # Generate a random object
              objType = self._randElem(types)
              objColor = self._randElem(colors)
              obj = constructors[objType](objColor)

              # Pick a random position that doesn't overlap with anything
              while True:
                  room = self._randElem(self.rooms)
                  pos = self._randPos(room, border=2)
                  if pos != self.startPos and grid.get(*pos) is None:
                      break

              grid.set(*pos, obj)
              room.objects.append(obj)

          # Question examples:
          # - What color is the X?
          # - What color is the X in the ROOM?
          # - What room is the X located in?
          # - What color is the X in the blue room?
          # - How many rooms contain chairs?
          # - How many keys are there in the yellow room?
          # - How many <OBJs> in the <ROOM>?

          # Pick a random room to be the subject of the question
          room = self._randElem(self.rooms)

          # Pick a random object type
          objType = self._randElem(types)

          # Check whether the room holds any object of this type
          present = any(o.type == objType for o in room.objects)

          # TODO: identify unique objects

          # NOTE: the plural is formed by appending 's', so 'box' yields
          # "boxs"; the text is left as-is since it is part of the
          # observation.
          self.question = "Are there any %ss in the %s room?" % (objType, room.color)
          self.answer = "yes" if present else "no"

          # TODO: how many X in the Y room question type

          return grid

      def _reset(self):
          """Reset via the superclass and attach the question."""
          return self._withQuestion(MiniGridEnv._reset(self))

      def _step(self, action):
          """
          Advance the environment by one action.

          :param action: either a plain action value, or a dict with keys
              'action' (the action value) and 'answer' (the answer string)
          :return: (obs, reward, done, info), where obs is a dict with
              'image' and 'question'. For the `answer` action the episode
              ends; the reward is 1000 - stepCount for a correct answer
              and -1000 otherwise.
          """
          if isinstance(action, dict):
              answer = action['answer']
              action = action['action']
          else:
              answer = ''

          if action == self.actions.answer:
              # To the superclass, this action behaves like a noop
              obs, reward, done, info = MiniGridEnv._step(self, self.actions.wait)
              done = True

              if answer == self.answer:
                  reward = 1000 - self.stepCount
              else:
                  reward = -1000
          else:
              # Let the superclass handle the action
              obs, reward, done, info = MiniGridEnv._step(self, action)

          return self._withQuestion(obs), reward, done, info
  # Make the environment available under its gym id
  register(
      entry_point='gym_minigrid.envs:FourRoomQAEnv',
      id='MiniGrid-FourRoomQA-v0',
  )