roomgrid.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. import numpy as np
  2. from gym_minigrid.minigrid import COLOR_NAMES, Ball, Box, Door, Grid, Key, MiniGridEnv
  3. def reject_next_to(env, pos):
  4. """
  5. Function to filter out object positions that are right next to
  6. the agent's starting point
  7. """
  8. sx, sy = env.agent_pos
  9. x, y = pos
  10. d = abs(sx - x) + abs(sy - y)
  11. return d < 2
  12. class Room:
  13. def __init__(self, top, size):
  14. # Top-left corner and size (tuples)
  15. self.top = top
  16. self.size = size
  17. # List of door objects and door positions
  18. # Order of the doors is right, down, left, up
  19. self.doors = [None] * 4
  20. self.door_pos = [None] * 4
  21. # List of rooms adjacent to this one
  22. # Order of the neighbors is right, down, left, up
  23. self.neighbors = [None] * 4
  24. # Indicates if this room is behind a locked door
  25. self.locked = False
  26. # List of objects contained
  27. self.objs = []
  28. def rand_pos(self, env):
  29. topX, topY = self.top
  30. sizeX, sizeY = self.size
  31. return env._randPos(topX + 1, topX + sizeX - 1, topY + 1, topY + sizeY - 1)
  32. def pos_inside(self, x, y):
  33. """
  34. Check if a position is within the bounds of this room
  35. """
  36. topX, topY = self.top
  37. sizeX, sizeY = self.size
  38. if x < topX or y < topY:
  39. return False
  40. if x >= topX + sizeX or y >= topY + sizeY:
  41. return False
  42. return True
  43. class RoomGrid(MiniGridEnv):
  44. """
  45. Environment with multiple rooms and random objects.
  46. This is meant to serve as a base class for other environments.
  47. """
  48. def __init__(
  49. self,
  50. room_size=7,
  51. num_rows=3,
  52. num_cols=3,
  53. max_steps=100,
  54. agent_view_size=7,
  55. **kwargs,
  56. ):
  57. assert room_size > 0
  58. assert room_size >= 3
  59. assert num_rows > 0
  60. assert num_cols > 0
  61. self.room_size = room_size
  62. self.num_rows = num_rows
  63. self.num_cols = num_cols
  64. height = (room_size - 1) * num_rows + 1
  65. width = (room_size - 1) * num_cols + 1
  66. # By default, this environment has no mission
  67. self.mission = ""
  68. super().__init__(
  69. width=width,
  70. height=height,
  71. max_steps=max_steps,
  72. see_through_walls=False,
  73. agent_view_size=agent_view_size,
  74. **kwargs,
  75. )
  76. def room_from_pos(self, x, y):
  77. """Get the room a given position maps to"""
  78. assert x >= 0
  79. assert y >= 0
  80. i = x // (self.room_size - 1)
  81. j = y // (self.room_size - 1)
  82. assert i < self.num_cols
  83. assert j < self.num_rows
  84. return self.room_grid[j][i]
  85. def get_room(self, i, j):
  86. assert i < self.num_cols
  87. assert j < self.num_rows
  88. return self.room_grid[j][i]
  89. def _gen_grid(self, width, height):
  90. # Create the grid
  91. self.grid = Grid(width, height)
  92. self.room_grid = []
  93. # For each row of rooms
  94. for j in range(0, self.num_rows):
  95. row = []
  96. # For each column of rooms
  97. for i in range(0, self.num_cols):
  98. room = Room(
  99. (i * (self.room_size - 1), j * (self.room_size - 1)),
  100. (self.room_size, self.room_size),
  101. )
  102. row.append(room)
  103. # Generate the walls for this room
  104. self.grid.wall_rect(*room.top, *room.size)
  105. self.room_grid.append(row)
  106. # For each row of rooms
  107. for j in range(0, self.num_rows):
  108. # For each column of rooms
  109. for i in range(0, self.num_cols):
  110. room = self.room_grid[j][i]
  111. x_l, y_l = (room.top[0] + 1, room.top[1] + 1)
  112. x_m, y_m = (
  113. room.top[0] + room.size[0] - 1,
  114. room.top[1] + room.size[1] - 1,
  115. )
  116. # Door positions, order is right, down, left, up
  117. if i < self.num_cols - 1:
  118. room.neighbors[0] = self.room_grid[j][i + 1]
  119. room.door_pos[0] = (x_m, self._rand_int(y_l, y_m))
  120. if j < self.num_rows - 1:
  121. room.neighbors[1] = self.room_grid[j + 1][i]
  122. room.door_pos[1] = (self._rand_int(x_l, x_m), y_m)
  123. if i > 0:
  124. room.neighbors[2] = self.room_grid[j][i - 1]
  125. room.door_pos[2] = room.neighbors[2].door_pos[0]
  126. if j > 0:
  127. room.neighbors[3] = self.room_grid[j - 1][i]
  128. room.door_pos[3] = room.neighbors[3].door_pos[1]
  129. # The agent starts in the middle, facing right
  130. self.agent_pos = np.array(
  131. (
  132. (self.num_cols // 2) * (self.room_size - 1) + (self.room_size // 2),
  133. (self.num_rows // 2) * (self.room_size - 1) + (self.room_size // 2),
  134. )
  135. )
  136. self.agent_dir = 0
  137. def place_in_room(self, i, j, obj):
  138. """
  139. Add an existing object to room (i, j)
  140. """
  141. room = self.get_room(i, j)
  142. pos = self.place_obj(
  143. obj, room.top, room.size, reject_fn=reject_next_to, max_tries=1000
  144. )
  145. room.objs.append(obj)
  146. return obj, pos
  147. def add_object(self, i, j, kind=None, color=None):
  148. """
  149. Add a new object to room (i, j)
  150. """
  151. if kind is None:
  152. kind = self._rand_elem(["key", "ball", "box"])
  153. if color is None:
  154. color = self._rand_color()
  155. # TODO: we probably want to add an Object.make helper function
  156. assert kind in ["key", "ball", "box"]
  157. if kind == "key":
  158. obj = Key(color)
  159. elif kind == "ball":
  160. obj = Ball(color)
  161. elif kind == "box":
  162. obj = Box(color)
  163. else:
  164. raise ValueError(
  165. f"{kind} object kind is not available in this environment."
  166. )
  167. return self.place_in_room(i, j, obj)
  168. def add_door(self, i, j, door_idx=None, color=None, locked=None):
  169. """
  170. Add a door to a room, connecting it to a neighbor
  171. """
  172. room = self.get_room(i, j)
  173. if door_idx is None:
  174. # Need to make sure that there is a neighbor along this wall
  175. # and that there is not already a door
  176. while True:
  177. door_idx = self._rand_int(0, 4)
  178. if room.neighbors[door_idx] and room.doors[door_idx] is None:
  179. break
  180. if color is None:
  181. color = self._rand_color()
  182. if locked is None:
  183. locked = self._rand_bool()
  184. assert room.doors[door_idx] is None, "door already exists"
  185. room.locked = locked
  186. door = Door(color, is_locked=locked)
  187. pos = room.door_pos[door_idx]
  188. self.grid.set(pos[0], pos[1], door)
  189. door.cur_pos = pos
  190. neighbor = room.neighbors[door_idx]
  191. room.doors[door_idx] = door
  192. neighbor.doors[(door_idx + 2) % 4] = door
  193. return door, pos
  194. def remove_wall(self, i, j, wall_idx):
  195. """
  196. Remove a wall between two rooms
  197. """
  198. room = self.get_room(i, j)
  199. assert wall_idx >= 0 and wall_idx < 4
  200. assert room.doors[wall_idx] is None, "door exists on this wall"
  201. assert room.neighbors[wall_idx], "invalid wall"
  202. neighbor = room.neighbors[wall_idx]
  203. tx, ty = room.top
  204. w, h = room.size
  205. # Ordering of walls is right, down, left, up
  206. if wall_idx == 0:
  207. for i in range(1, h - 1):
  208. self.grid.set(tx + w - 1, ty + i, None)
  209. elif wall_idx == 1:
  210. for i in range(1, w - 1):
  211. self.grid.set(tx + i, ty + h - 1, None)
  212. elif wall_idx == 2:
  213. for i in range(1, h - 1):
  214. self.grid.set(tx, ty + i, None)
  215. elif wall_idx == 3:
  216. for i in range(1, w - 1):
  217. self.grid.set(tx + i, ty, None)
  218. else:
  219. assert False, "invalid wall index"
  220. # Mark the rooms as connected
  221. room.doors[wall_idx] = True
  222. neighbor.doors[(wall_idx + 2) % 4] = True
  223. def place_agent(self, i=None, j=None, rand_dir=True):
  224. """
  225. Place the agent in a room
  226. """
  227. if i is None:
  228. i = self._rand_int(0, self.num_cols)
  229. if j is None:
  230. j = self._rand_int(0, self.num_rows)
  231. room = self.room_grid[j][i]
  232. # Find a position that is not right in front of an object
  233. while True:
  234. super().place_agent(room.top, room.size, rand_dir, max_tries=1000)
  235. front_cell = self.grid.get(*self.front_pos)
  236. if front_cell is None or front_cell.type == "wall":
  237. break
  238. return self.agent_pos
  239. def connect_all(self, door_colors=COLOR_NAMES, max_itrs=5000):
  240. """
  241. Make sure that all rooms are reachable by the agent from its
  242. starting position
  243. """
  244. start_room = self.room_from_pos(*self.agent_pos)
  245. added_doors = []
  246. def find_reach():
  247. reach = set()
  248. stack = [start_room]
  249. while len(stack) > 0:
  250. room = stack.pop()
  251. if room in reach:
  252. continue
  253. reach.add(room)
  254. for i in range(0, 4):
  255. if room.doors[i]:
  256. stack.append(room.neighbors[i])
  257. return reach
  258. num_itrs = 0
  259. while True:
  260. # This is to handle rare situations where random sampling produces
  261. # a level that cannot be connected, producing in an infinite loop
  262. if num_itrs > max_itrs:
  263. raise RecursionError("connect_all failed")
  264. num_itrs += 1
  265. # If all rooms are reachable, stop
  266. reach = find_reach()
  267. if len(reach) == self.num_rows * self.num_cols:
  268. break
  269. # Pick a random room and door position
  270. i = self._rand_int(0, self.num_cols)
  271. j = self._rand_int(0, self.num_rows)
  272. k = self._rand_int(0, 4)
  273. room = self.get_room(i, j)
  274. # If there is already a door there, skip
  275. if not room.door_pos[k] or room.doors[k]:
  276. continue
  277. if room.locked or room.neighbors[k].locked:
  278. continue
  279. color = self._rand_elem(door_colors)
  280. door, _ = self.add_door(i, j, k, color, False)
  281. added_doors.append(door)
  282. return added_doors
  283. def add_distractors(self, i=None, j=None, num_distractors=10, all_unique=True):
  284. """
  285. Add random objects that can potentially distract/confuse the agent.
  286. """
  287. # Collect a list of existing objects
  288. objs = []
  289. for row in self.room_grid:
  290. for room in row:
  291. for obj in room.objs:
  292. objs.append((obj.type, obj.color))
  293. # List of distractors added
  294. dists = []
  295. while len(dists) < num_distractors:
  296. color = self._rand_elem(COLOR_NAMES)
  297. type = self._rand_elem(["key", "ball", "box"])
  298. obj = (type, color)
  299. if all_unique and obj in objs:
  300. continue
  301. # Add the object to a random room if no room specified
  302. room_i = i
  303. room_j = j
  304. if room_i is None:
  305. room_i = self._rand_int(0, self.num_cols)
  306. if room_j is None:
  307. room_j = self._rand_int(0, self.num_rows)
  308. dist, pos = self.add_object(room_i, room_j, *obj)
  309. objs.append(obj)
  310. dists.append(dist)
  311. return dists