fourroomqa.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class Room:
  4. def __init__(
  5. self,
  6. top,
  7. size,
  8. color,
  9. objects
  10. ):
  11. self.top = top
  12. self.size = size
  13. # Color of the room
  14. self.color = color
  15. # List of objects contained
  16. self.objects = objects
  17. class FourRoomQAEnv(MiniGridEnv):
  18. """
  19. Environment to experiment with embodied question answering
  20. https://arxiv.org/abs/1711.11543
  21. """
  22. # Enumeration of possible actions
  23. class Actions(IntEnum):
  24. left = 0
  25. right = 1
  26. forward = 2
  27. toggle = 3
  28. say = 4
  29. def __init__(self, size=16):
  30. assert size >= 10
  31. super(FourRoomQAEnv, self).__init__(gridSize=size, maxSteps=8*size)
  32. # Action enumeration for this environment
  33. self.actions = MiniGridEnv.Actions
  34. # TODO: dictionary action_space, to include answer sentence?
  35. # Actions are discrete integer values
  36. self.action_space = spaces.Discrete(len(self.actions))
  37. # TODO: dictionary observation_space, to include question?
  38. def _genGrid(self, width, height):
  39. grid = Grid(width, height)
  40. # Horizontal and vertical split indices
  41. vSplitIdx = self._randInt(5, width-4)
  42. hSplitIdx = self._randInt(5, height-4)
  43. # Create the four rooms
  44. self.rooms = []
  45. self.rooms.append(Room(
  46. (0, 0),
  47. (vSplitIdx, hSplitIdx),
  48. 'red',
  49. []
  50. ))
  51. self.rooms.append(Room(
  52. (vSplitIdx, 0),
  53. (width - vSplitIdx, hSplitIdx),
  54. 'purple',
  55. []
  56. ))
  57. self.rooms.append(Room(
  58. (0, hSplitIdx),
  59. (vSplitIdx, height - hSplitIdx),
  60. 'blue',
  61. []
  62. ))
  63. self.rooms.append(Room(
  64. (vSplitIdx, hSplitIdx),
  65. (width - vSplitIdx, height - hSplitIdx),
  66. 'yellow',
  67. []
  68. ))
  69. # Place the room walls
  70. for room in self.rooms:
  71. x, y = room.top
  72. w, h = room.size
  73. # Horizontal walls
  74. for i in range(w):
  75. grid.set(x + i, y, Wall(room.color))
  76. grid.set(x + i, y + h - 1, Wall(room.color))
  77. # Vertical walls
  78. for j in range(h):
  79. grid.set(x, y + j, Wall(room.color))
  80. grid.set(x + w - 1, y + j, Wall(room.color))
  81. # Place wall openings connecting the rooms
  82. hIdx = self._randInt(1, hSplitIdx-1)
  83. grid.set(vSplitIdx, hIdx, None)
  84. grid.set(vSplitIdx-1, hIdx, None)
  85. hIdx = self._randInt(hSplitIdx+1, height-1)
  86. grid.set(vSplitIdx, hIdx, None)
  87. grid.set(vSplitIdx-1, hIdx, None)
  88. # TODO: pick a random room to be the subject of the question
  89. # TODO: identify unique objects
  90. # TODO:
  91. # Generate a question and answer
  92. self.question = ''
  93. # Question examples:
  94. # - What color is the X?
  95. # - What color is the X in the ROOM?
  96. # - What room is the X located in?
  97. # - What color is the X in the blue room?
  98. # - How many rooms contain chairs?
  99. # - How many keys are there in the yellow room?
  100. # - How many <OBJs> in the <ROOM>?
  101. #self.answer
  102. return grid
  103. def _reset(self):
  104. obs = MiniGridEnv._reset(self)
  105. obs = {
  106. 'image': obs,
  107. 'question': self.question
  108. }
  109. return obs
  110. def _step(self, action):
  111. obs, reward, done, info = MiniGridEnv._step(self, action)
  112. obs = {
  113. 'image': obs,
  114. 'question': self.question
  115. }
  116. return obs, reward, done, info
  117. register(
  118. id='MiniGrid-FourRoomQA-v0',
  119. entry_point='gym_minigrid.envs:FourRoomQAEnv'
  120. )