roomgrid.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class Room:
  4. def __init__(
  5. self,
  6. top,
  7. size
  8. ):
  9. # Top-left corner and size (tuples)
  10. self.top = top
  11. self.size = size
  12. # List of door objects and door positions
  13. # Order of the doors is right, down, left, up
  14. self.doors = [None] * 4
  15. self.door_pos = [None] * 4
  16. # List of rooms this is connected to
  17. # Order of the neighbors is right, down, left, up
  18. self.neighbors = [None] * 4
  19. # Indicates if this room is locked
  20. self.locked = False
  21. # List of objects contained
  22. self.objs = []
  23. def rand_pos(self, env):
  24. topX, topY = self.top
  25. sizeX, sizeY = self.size
  26. return env._randPos(
  27. topX + 1, topX + sizeX - 1,
  28. topY + 1, topY + sizeY - 1
  29. )
  30. class RoomGrid(MiniGridEnv):
  31. """
  32. Environment with multiple rooms and random objects.
  33. This is meant to serve as a base class for other environments.
  34. """
  35. def __init__(
  36. self,
  37. room_size=6,
  38. num_cols=4,
  39. lockedRooms=False
  40. ):
  41. assert room_size > 0
  42. assert room_size >= 4
  43. assert num_cols > 0
  44. self.room_size = room_size
  45. self.num_cols = num_cols
  46. self.num_rows = num_cols
  47. self.lockedRooms = False
  48. grid_size = (room_size - 1) * num_cols + 1
  49. super().__init__(gridSize=grid_size, maxSteps=6*grid_size)
  50. self.reward_range = (0, 1)
  51. def room_from_pos(self, x, y):
  52. """Get the room a given position maps to"""
  53. assert x >= 0
  54. assert y >= 0
  55. i = x // self.room_size
  56. j = y // self.room_size
  57. assert i < self.num_cols
  58. assert j < self.num_rows
  59. return self.room_grid[j][i]
  60. def get_room(self, i, j):
  61. assert i < self.num_cols
  62. assert j < self.num_rows
  63. return self.room_grid[j][i]
  64. def _genGrid(self, width, height):
  65. # Create the grid
  66. self.grid = Grid(width, height)
  67. self.room_grid = []
  68. # For each row of rooms
  69. for j in range(0, self.num_rows):
  70. row = []
  71. # For each column of rooms
  72. for i in range(0, self.num_cols):
  73. room = Room(
  74. (i * (self.room_size-1), j * (self.room_size-1)),
  75. (self.room_size, self.room_size)
  76. )
  77. row.append(room)
  78. # Generate the walls for this room
  79. self.grid.wallRect(*room.top, *room.size)
  80. self.room_grid.append(row)
  81. # For each row of rooms
  82. for j in range(0, self.num_rows):
  83. # For each column of rooms
  84. for i in range(0, self.num_cols):
  85. room = self.room_grid[j][i]
  86. # Door positions, order is right, down, left, up
  87. if i < self.num_cols - 1:
  88. room.door_pos[0] = (room.top[0] + self.room_size - 1, room.top[1] + self.room_size // 2)
  89. room.neighbors[0] = self.room_grid[j][i+1]
  90. if j < self.num_rows - 1:
  91. room.door_pos[1] = (room.top[0] + self.room_size // 2, room.top[1] + self.room_size - 1)
  92. room.neighbors[1] = self.room_grid[j+1][i]
  93. if i > 0:
  94. room.door_pos[2] = (room.top[0], room.top[1] + self.room_size // 2)
  95. room.neighbors[2] = self.room_grid[j][i-1]
  96. if j > 0:
  97. room.door_pos[3] = (room.top[0] + self.room_size // 2, room.top[1])
  98. room.neighbors[3] = self.room_grid[j-1][i]
  99. # The agent starts in the middle, facing right
  100. self.startPos = (
  101. (self.num_cols // 2) * (self.room_size-1) + (self.room_size // 2),
  102. (self.num_rows // 2) * (self.room_size-1) + (self.room_size // 2)
  103. )
  104. self.startDir = 0
  105. # By default, this environment has no mission
  106. self.mission = ''
  107. def add_object(self, i, j, kind, color):
  108. """
  109. Add a new object to room (i, j)
  110. """
  111. # TODO: we probably want to add an Object.make helper function
  112. assert kind in ['key', 'ball', 'box']
  113. if kind == 'key':
  114. obj = Key(color)
  115. elif kind == 'ball':
  116. obj = Ball(color)
  117. elif kind == 'box':
  118. obj = Box(color)
  119. room = self.get_room(i, j)
  120. self.placeObj(obj, room.top, room.size)
  121. room.objs.append(obj)
  122. return obj
  123. def add_door(self, i, j, k, color, locked=False):
  124. """
  125. Add a door to a room, connecting it to a neighbor
  126. """
  127. room = self.get_room(i, j)
  128. assert room.doors[k] is None, "door already exists"
  129. if locked:
  130. door = LockedDoor(color)
  131. room.locked = True
  132. else:
  133. door = Door(color)
  134. self.grid.set(*room.door_pos[k], door)
  135. neighbor = room.neighbors[k]
  136. room.doors[k] = door
  137. neighbor.doors[(k+2) % 4] = door
  138. def connect_all(self):
  139. """
  140. Make sure that all rooms are reachable by the agent from its
  141. starting position
  142. """
  143. start_room = self.room_from_pos(*self.startPos)
  144. def find_reach():
  145. reach = set()
  146. stack = [start_room]
  147. while len(stack) > 0:
  148. room = stack.pop()
  149. if room in reach:
  150. continue
  151. reach.add(room)
  152. for i in range(0, 4):
  153. if room.doors[i]:
  154. stack.append(room.neighbors[i])
  155. return reach
  156. while True:
  157. # If all rooms are reachable, stop
  158. reach = find_reach()
  159. if len(reach) == self.num_rows * self.num_cols:
  160. break
  161. # Pick a random room and door position
  162. i = self._randInt(0, self.num_cols)
  163. j = self._randInt(0, self.num_rows)
  164. k = self._randInt(0, 4)
  165. room = self.get_room(i, j)
  166. # If there is already a door there, skip
  167. if not room.door_pos[k] or room.doors[k]:
  168. continue
  169. if room.locked or room.neighbors[k].locked:
  170. continue
  171. color = self._randElem(COLOR_NAMES)
  172. self.add_door(i, j, k, color)
  173. def step(self, action):
  174. obs, reward, done, info = super().step(action)
  175. return obs, reward, done, info
  176. register(
  177. id='MiniGrid-RoomGrid-v0',
  178. entry_point='gym_minigrid.envs:RoomGrid'
  179. )