roomgrid.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class Room:
  4. def __init__(
  5. self,
  6. top,
  7. size
  8. ):
  9. # Top-left corner and size (tuples)
  10. self.top = top
  11. self.size = size
  12. # List of door objects and door positions
  13. # Order of the doors is right, down, left, up
  14. self.doors = [None] * 4
  15. self.door_pos = [None] * 4
  16. # List of rooms this is connected to
  17. # Order of the neighbors is right, down, left, up
  18. self.neighbors = [None] * 4
  19. # Indicates if this room is locked
  20. self.locked = False
  21. # List of objects contained
  22. self.objs = []
  23. def rand_pos(self, env):
  24. topX, topY = self.top
  25. sizeX, sizeY = self.size
  26. return env._randPos(
  27. topX + 1, topX + sizeX - 1,
  28. topY + 1, topY + sizeY - 1
  29. )
  30. class RoomGrid(MiniGridEnv):
  31. """
  32. Environment with multiple rooms and random objects.
  33. This is meant to serve as a base class for other environments.
  34. """
  35. def __init__(
  36. self,
  37. room_size=6,
  38. num_cols=4,
  39. lockedRooms=False
  40. ):
  41. assert room_size > 0
  42. assert room_size >= 4
  43. assert num_cols > 0
  44. self.room_size = room_size
  45. self.num_cols = num_cols
  46. self.num_rows = num_cols
  47. self.lockedRooms = False
  48. grid_size = (room_size - 1) * num_cols + 1
  49. super().__init__(gridSize=grid_size, maxSteps=6*grid_size)
  50. self.reward_range = (0, 1)
  51. def room_from_pos(self, x, y):
  52. """Get the room a given position maps to"""
  53. assert x >= 0
  54. assert y >= 0
  55. i = x // self.room_size
  56. j = y // self.room_size
  57. assert i < self.num_cols
  58. assert j < self.num_rows
  59. return self.room_grid[j][i]
  60. def get_room(self, i, j):
  61. assert i < self.num_cols
  62. assert j < self.num_rows
  63. return self.room_grid[j][i]
  64. def _genGrid(self, width, height):
  65. # Create the grid
  66. self.grid = Grid(width, height)
  67. self.room_grid = []
  68. # For each row of rooms
  69. for j in range(0, self.num_rows):
  70. row = []
  71. # For each column of rooms
  72. for i in range(0, self.num_cols):
  73. room = Room(
  74. (i * (self.room_size-1), j * (self.room_size-1)),
  75. (self.room_size, self.room_size)
  76. )
  77. row.append(room)
  78. # Generate the walls for this room
  79. self.grid.wallRect(*room.top, *room.size)
  80. self.room_grid.append(row)
  81. # For each row of rooms
  82. for j in range(0, self.num_rows):
  83. # For each column of rooms
  84. for i in range(0, self.num_cols):
  85. room = self.room_grid[j][i]
  86. x_l, y_l = room.top
  87. x_m, y_m = (room.top[0] + room.size[0] - 1, room.top[1] + room.size[1] - 1)
  88. # Door positions, order is right, down, left, up
  89. if i < self.num_cols - 1:
  90. room.neighbors[0] = self.room_grid[j][i+1]
  91. room.door_pos[0] = (x_m, self._randInt(y_l, y_m))
  92. if j < self.num_rows - 1:
  93. room.neighbors[1] = self.room_grid[j+1][i]
  94. room.door_pos[1] = (self._randInt(x_l, x_m), y_m)
  95. if i > 0:
  96. room.neighbors[2] = self.room_grid[j][i-1]
  97. room.door_pos[2] = room.neighbors[2].door_pos[0]
  98. if j > 0:
  99. room.neighbors[3] = self.room_grid[j-1][i]
  100. room.door_pos[3] = room.neighbors[3].door_pos[1]
  101. # The agent starts in the middle, facing right
  102. self.startPos = (
  103. (self.num_cols // 2) * (self.room_size-1) + (self.room_size // 2),
  104. (self.num_rows // 2) * (self.room_size-1) + (self.room_size // 2)
  105. )
  106. self.startDir = 0
  107. # By default, this environment has no mission
  108. self.mission = ''
  109. def add_object(self, i, j, kind, color):
  110. """
  111. Add a new object to room (i, j)
  112. """
  113. # TODO: we probably want to add an Object.make helper function
  114. assert kind in ['key', 'ball', 'box']
  115. if kind == 'key':
  116. obj = Key(color)
  117. elif kind == 'ball':
  118. obj = Ball(color)
  119. elif kind == 'box':
  120. obj = Box(color)
  121. room = self.get_room(i, j)
  122. self.placeObj(obj, room.top, room.size)
  123. room.objs.append(obj)
  124. return obj
  125. def add_door(self, i, j, k, color, locked=False):
  126. """
  127. Add a door to a room, connecting it to a neighbor
  128. """
  129. room = self.get_room(i, j)
  130. assert room.doors[k] is None, "door already exists"
  131. if locked:
  132. door = LockedDoor(color)
  133. room.locked = True
  134. else:
  135. door = Door(color)
  136. self.grid.set(*room.door_pos[k], door)
  137. neighbor = room.neighbors[k]
  138. room.doors[k] = door
  139. neighbor.doors[(k+2) % 4] = door
  140. def connect_all(self):
  141. """
  142. Make sure that all rooms are reachable by the agent from its
  143. starting position
  144. """
  145. start_room = self.room_from_pos(*self.startPos)
  146. def find_reach():
  147. reach = set()
  148. stack = [start_room]
  149. while len(stack) > 0:
  150. room = stack.pop()
  151. if room in reach:
  152. continue
  153. reach.add(room)
  154. for i in range(0, 4):
  155. if room.doors[i]:
  156. stack.append(room.neighbors[i])
  157. return reach
  158. while True:
  159. # If all rooms are reachable, stop
  160. reach = find_reach()
  161. if len(reach) == self.num_rows * self.num_cols:
  162. break
  163. # Pick a random room and door position
  164. i = self._randInt(0, self.num_cols)
  165. j = self._randInt(0, self.num_rows)
  166. k = self._randInt(0, 4)
  167. room = self.get_room(i, j)
  168. # If there is already a door there, skip
  169. if not room.door_pos[k] or room.doors[k]:
  170. continue
  171. if room.locked or room.neighbors[k].locked:
  172. continue
  173. color = self._randElem(COLOR_NAMES)
  174. self.add_door(i, j, k, color)
  175. def step(self, action):
  176. obs, reward, done, info = super().step(action)
  177. return obs, reward, done, info
  178. register(
  179. id='MiniGrid-RoomGrid-v0',
  180. entry_point='gym_minigrid.envs:RoomGrid'
  181. )