roomgrid.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class Room:
  4. def __init__(
  5. self,
  6. top,
  7. size
  8. ):
  9. # Top-left corner and size (tuples)
  10. self.top = top
  11. self.size = size
  12. # List of door objects and door positions
  13. # Order of the doors is right, down, left, up
  14. self.doors = [None] * 4
  15. self.door_pos = [None] * 4
  16. # List of rooms this is connected to
  17. # Order of the neighbors is right, down, left, up
  18. self.neighbors = [None] * 4
  19. # Indicates if this room is locked
  20. self.locked = False
  21. # List of objects contained
  22. self.objs = []
  23. def rand_pos(self, env):
  24. topX, topY = self.top
  25. sizeX, sizeY = self.size
  26. return env._randPos(
  27. topX + 1, topX + sizeX - 1,
  28. topY + 1, topY + sizeY - 1
  29. )
  30. class RoomGrid(MiniGridEnv):
  31. """
  32. Environment with multiple rooms and random objects.
  33. This is meant to serve as a base class for other environments.
  34. """
  35. def __init__(
  36. self,
  37. room_size=6,
  38. num_cols=4,
  39. max_steps=200
  40. ):
  41. assert room_size > 0
  42. assert room_size >= 4
  43. assert num_cols > 0
  44. self.room_size = room_size
  45. self.num_cols = num_cols
  46. self.num_rows = num_cols
  47. grid_size = (room_size - 1) * num_cols + 1
  48. super().__init__(gridSize=grid_size, maxSteps=max_steps)
  49. self.reward_range = (0, 1)
  50. def room_from_pos(self, x, y):
  51. """Get the room a given position maps to"""
  52. assert x >= 0
  53. assert y >= 0
  54. i = x // self.room_size
  55. j = y // self.room_size
  56. assert i < self.num_cols
  57. assert j < self.num_rows
  58. return self.room_grid[j][i]
  59. def get_room(self, i, j):
  60. assert i < self.num_cols
  61. assert j < self.num_rows
  62. return self.room_grid[j][i]
  63. def _genGrid(self, width, height):
  64. # Create the grid
  65. self.grid = Grid(width, height)
  66. self.room_grid = []
  67. # For each row of rooms
  68. for j in range(0, self.num_rows):
  69. row = []
  70. # For each column of rooms
  71. for i in range(0, self.num_cols):
  72. room = Room(
  73. (i * (self.room_size-1), j * (self.room_size-1)),
  74. (self.room_size, self.room_size)
  75. )
  76. row.append(room)
  77. # Generate the walls for this room
  78. self.grid.wallRect(*room.top, *room.size)
  79. self.room_grid.append(row)
  80. # For each row of rooms
  81. for j in range(0, self.num_rows):
  82. # For each column of rooms
  83. for i in range(0, self.num_cols):
  84. room = self.room_grid[j][i]
  85. x_l, y_l = room.top
  86. x_m, y_m = (room.top[0] + room.size[0] - 1, room.top[1] + room.size[1] - 1)
  87. # Door positions, order is right, down, left, up
  88. if i < self.num_cols - 1:
  89. room.neighbors[0] = self.room_grid[j][i+1]
  90. room.door_pos[0] = (x_m, self._randInt(y_l, y_m))
  91. if j < self.num_rows - 1:
  92. room.neighbors[1] = self.room_grid[j+1][i]
  93. room.door_pos[1] = (self._randInt(x_l, x_m), y_m)
  94. if i > 0:
  95. room.neighbors[2] = self.room_grid[j][i-1]
  96. room.door_pos[2] = room.neighbors[2].door_pos[0]
  97. if j > 0:
  98. room.neighbors[3] = self.room_grid[j-1][i]
  99. room.door_pos[3] = room.neighbors[3].door_pos[1]
  100. # The agent starts in the middle, facing right
  101. self.startPos = (
  102. (self.num_cols // 2) * (self.room_size-1) + (self.room_size // 2),
  103. (self.num_rows // 2) * (self.room_size-1) + (self.room_size // 2)
  104. )
  105. self.startDir = 0
  106. # By default, this environment has no mission
  107. self.mission = ''
  108. def add_object(self, i, j, kind, color):
  109. """
  110. Add a new object to room (i, j)
  111. """
  112. # TODO: we probably want to add an Object.make helper function
  113. assert kind in ['key', 'ball', 'box']
  114. if kind == 'key':
  115. obj = Key(color)
  116. elif kind == 'ball':
  117. obj = Ball(color)
  118. elif kind == 'box':
  119. obj = Box(color)
  120. room = self.get_room(i, j)
  121. self.placeObj(obj, room.top, room.size)
  122. room.objs.append(obj)
  123. return obj
  124. def add_door(self, i, j, k, color, locked=False):
  125. """
  126. Add a door to a room, connecting it to a neighbor
  127. """
  128. room = self.get_room(i, j)
  129. assert room.doors[k] is None, "door already exists"
  130. if locked:
  131. door = LockedDoor(color)
  132. room.locked = True
  133. else:
  134. door = Door(color)
  135. self.grid.set(*room.door_pos[k], door)
  136. neighbor = room.neighbors[k]
  137. room.doors[k] = door
  138. neighbor.doors[(k+2) % 4] = door
  139. def connect_all(self):
  140. """
  141. Make sure that all rooms are reachable by the agent from its
  142. starting position
  143. """
  144. start_room = self.room_from_pos(*self.startPos)
  145. def find_reach():
  146. reach = set()
  147. stack = [start_room]
  148. while len(stack) > 0:
  149. room = stack.pop()
  150. if room in reach:
  151. continue
  152. reach.add(room)
  153. for i in range(0, 4):
  154. if room.doors[i]:
  155. stack.append(room.neighbors[i])
  156. return reach
  157. while True:
  158. # If all rooms are reachable, stop
  159. reach = find_reach()
  160. if len(reach) == self.num_rows * self.num_cols:
  161. break
  162. # Pick a random room and door position
  163. i = self._randInt(0, self.num_cols)
  164. j = self._randInt(0, self.num_rows)
  165. k = self._randInt(0, 4)
  166. room = self.get_room(i, j)
  167. # If there is already a door there, skip
  168. if not room.door_pos[k] or room.doors[k]:
  169. continue
  170. if room.locked or room.neighbors[k].locked:
  171. continue
  172. color = self._randElem(COLOR_NAMES)
  173. self.add_door(i, j, k, color)
  174. def step(self, action):
  175. obs, reward, done, info = super().step(action)
  176. return obs, reward, done, info
  177. register(
  178. id='MiniGrid-RoomGrid-v0',
  179. entry_point='gym_minigrid.envs:RoomGrid'
  180. )