fourroomqa.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class Room:
  4. def __init__(
  5. self,
  6. top,
  7. size,
  8. color,
  9. objects
  10. ):
  11. self.top = top
  12. self.size = size
  13. # Color of the room
  14. self.color = color
  15. # List of objects contained
  16. self.objects = objects
  17. class FourRoomQAEnv(MiniGridEnv):
  18. """
  19. Environment to experiment with embodied question answering
  20. https://arxiv.org/abs/1711.11543
  21. """
  22. # Enumeration of possible actions
  23. class Actions(IntEnum):
  24. left = 0
  25. right = 1
  26. forward = 2
  27. toggle = 3
  28. wait = 4
  29. answer = 5
  30. def __init__(self, size=16):
  31. assert size >= 10
  32. super(FourRoomQAEnv, self).__init__(gridSize=size, maxSteps=8*size)
  33. # Action enumeration for this environment
  34. self.actions = FourRoomQAEnv.Actions
  35. # TODO: dictionary action_space, to include answer sentence?
  36. # Actions are discrete integer values
  37. self.action_space = spaces.Discrete(len(self.actions))
  38. self.reward_range = (-1000, 1000)
  39. def _randPos(self, room, border=1):
  40. return (
  41. self._randInt(
  42. room.top[0] + border,
  43. room.top[0] + room.size[0] - border
  44. ),
  45. self._randInt(
  46. room.top[1] + border,
  47. room.top[1] + room.size[1] - border
  48. ),
  49. )
  50. def _genGrid(self, width, height):
  51. self.grid = Grid(width, height)
  52. # Horizontal and vertical split indices
  53. vSplitIdx = self._randInt(5, width-4)
  54. hSplitIdx = self._randInt(5, height-4)
  55. # Create the four rooms
  56. self.rooms = []
  57. self.rooms.append(Room(
  58. (0, 0),
  59. (vSplitIdx, hSplitIdx),
  60. 'red',
  61. []
  62. ))
  63. self.rooms.append(Room(
  64. (vSplitIdx, 0),
  65. (width - vSplitIdx, hSplitIdx),
  66. 'purple',
  67. []
  68. ))
  69. self.rooms.append(Room(
  70. (0, hSplitIdx),
  71. (vSplitIdx, height - hSplitIdx),
  72. 'blue',
  73. []
  74. ))
  75. self.rooms.append(Room(
  76. (vSplitIdx, hSplitIdx),
  77. (width - vSplitIdx, height - hSplitIdx),
  78. 'yellow',
  79. []
  80. ))
  81. # Place the room walls
  82. for room in self.rooms:
  83. x, y = room.top
  84. w, h = room.size
  85. # Horizontal walls
  86. for i in range(w):
  87. self.grid.set(x + i, y, Wall(room.color))
  88. self.grid.set(x + i, y + h - 1, Wall(room.color))
  89. # Vertical walls
  90. for j in range(h):
  91. self.grid.set(x, y + j, Wall(room.color))
  92. self.grid.set(x + w - 1, y + j, Wall(room.color))
  93. # Place wall openings connecting the rooms
  94. hIdx = self._randInt(1, hSplitIdx-1)
  95. self.grid.set(vSplitIdx, hIdx, None)
  96. self.grid.set(vSplitIdx-1, hIdx, None)
  97. hIdx = self._randInt(hSplitIdx+1, height-1)
  98. self.grid.set(vSplitIdx, hIdx, None)
  99. self.grid.set(vSplitIdx-1, hIdx, None)
  100. vIdx = self._randInt(1, vSplitIdx-1)
  101. self.grid.set(vIdx, hSplitIdx, None)
  102. self.grid.set(vIdx, hSplitIdx-1, None)
  103. vIdx = self._randInt(vSplitIdx+1, width-1)
  104. self.grid.set(vIdx, hSplitIdx, None)
  105. self.grid.set(vIdx, hSplitIdx-1, None)
  106. # Select a random position for the agent to start at
  107. self.startDir = self._randInt(0, 4)
  108. room = self._randElem(self.rooms)
  109. self.startPos = self._randPos(room)
  110. # Possible object types and colors
  111. types = ['key', 'ball', 'box']
  112. colors = list(COLORS.keys())
  113. # Place a number of random objects
  114. numObjs = self._randInt(1, 10)
  115. for i in range(0, numObjs):
  116. # Generate a random object
  117. objType = self._randElem(types)
  118. objColor = self._randElem(colors)
  119. if objType == 'key':
  120. obj = Key(objColor)
  121. elif objType == 'ball':
  122. obj = Ball(objColor)
  123. elif objType == 'box':
  124. obj = Box(objColor)
  125. # Pick a random position that doesn't overlap with anything
  126. while True:
  127. room = self._randElem(self.rooms)
  128. pos = self._randPos(room, border=2)
  129. if pos == self.startPos:
  130. continue
  131. if self.grid.get(*pos) != None:
  132. continue
  133. self.grid.set(*pos, obj)
  134. break
  135. room.objects.append(obj)
  136. # Question examples:
  137. # - What color is the X?
  138. # - What color is the X in the ROOM?
  139. # - What room is the X located in?
  140. # - What color is the X in the blue room?
  141. # - How many rooms contain chairs?
  142. # - How many keys are there in the yellow room?
  143. # - How many <OBJs> in the <ROOM>?
  144. # Pick a random room to be the subject of the question
  145. room = self._randElem(self.rooms)
  146. # Pick a random object type
  147. objType = self._randElem(types)
  148. # Count the number of objects of this type in the room
  149. count = len(list(filter(lambda o: o.type == objType, room.objects)))
  150. # TODO: identify unique objects
  151. self.mission = "Are there any %ss in the %s room?" % (objType, room.color)
  152. self.answer = "yes" if count > 0 else "no"
  153. # TODO: how many X in the Y room question type
  154. def step(self, action):
  155. if isinstance(action, dict):
  156. answer = action['answer']
  157. action = action['action']
  158. else:
  159. answer = ''
  160. if action == self.actions.answer:
  161. # To the superclass, this action behaves like a noop
  162. obs, reward, done, info = MiniGridEnv.step(self, self.actions.wait)
  163. done = True
  164. if answer == self.mission:
  165. reward = 1000 - self.stepCount
  166. else:
  167. reward = -1000
  168. else:
  169. # Let the superclass handle the action
  170. obs, reward, done, info = MiniGridEnv.step(self, action)
  171. return obs, reward, done, info
  172. register(
  173. id='MiniGrid-FourRoomQA-v0',
  174. entry_point='gym_minigrid.envs:FourRoomQAEnv'
  175. )