roomgrid.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. from __future__ import annotations
  2. import numpy as np
  3. from minigrid.core.constants import COLOR_NAMES
  4. from minigrid.core.grid import Grid
  5. from minigrid.core.world_object import Ball, Box, Door, Key, WorldObj
  6. from minigrid.minigrid_env import MiniGridEnv
  def reject_next_to(env: MiniGridEnv, pos: tuple[int, int]):
      """
      Rejection filter for object placement.

      Returns a truthy value when ``pos`` is the agent's own cell or one of
      its four orthogonal neighbours, i.e. when the Manhattan distance to
      ``env.agent_pos`` is below 2, so objects are not dropped right next
      to the agent's starting point.
      """
      agent_x, agent_y = env.agent_pos
      obj_x, obj_y = pos
      return abs(agent_x - obj_x) + abs(agent_y - obj_y) < 2
  class Room:
      """
      A single rectangular room of a RoomGrid.

      ``top`` is the (x, y) grid cell of the room's top-left wall corner and
      ``size`` its (width, height), walls included.
      """

      def __init__(self, top: tuple[int, int], size: tuple[int, int]):
          # Top-left corner and size (tuples)
          self.top = top
          self.size = size

          # Per-side slots, all ordered right, down, left, up:
          # - doors: the Door on that wall, True once the wall has been
          #   removed (open passage), or None while nothing connects there
          # - door_pos: grid cell where a door on that wall goes
          # - neighbors: the adjacent Room on that side, if any
          self.doors: list[bool | Door | None] = [None] * 4
          self.door_pos: list[tuple[int, int] | None] = [None] * 4
          self.neighbors: list[Room | None] = [None] * 4

          # Indicates if this room is behind a locked door
          self.locked: bool = False

          # List of objects contained
          self.objs: list[WorldObj] = []

      def rand_pos(self, env: MiniGridEnv) -> tuple[int, int]:
          """
          Sample a random position inside this room, excluding its walls.

          The low bounds are the first interior row/column; the high bounds
          are the far wall's row/column, which excludes the wall assuming the
          env sampler treats the high bound as exclusive (as ``_rand_int`` is
          used elsewhere in this file) — NOTE(review): confirm.
          """
          top_x, top_y = self.top
          size_x, size_y = self.size
          # MiniGridEnv names its sampler ``_rand_pos``; ``_randPos`` is the
          # legacy spelling and is only used as a fallback for older envs.
          sampler = getattr(env, "_rand_pos", None)
          if sampler is None:
              sampler = env._randPos
          return sampler(top_x + 1, top_x + size_x - 1, top_y + 1, top_y + size_y - 1)

      def pos_inside(self, x: int, y: int) -> bool:
          """
          Check if a position is within the bounds of this room
          (wall cells included).
          """
          top_x, top_y = self.top
          size_x, size_y = self.size
          return bool(top_x <= x < top_x + size_x and top_y <= y < top_y + size_y)
  class RoomGrid(MiniGridEnv):
      """
      Environment with multiple rooms and random objects.
      This is meant to serve as a base class for other environments.

      The map is a ``num_rows`` x ``num_cols`` grid of square rooms, each
      ``room_size`` cells wide including walls; neighbouring rooms share their
      boundary wall. Room (i, j) is column ``i``, row ``j``. Every per-side
      list on a Room is ordered right, down, left, up.
      """

      def __init__(
          self,
          room_size: int = 7,
          num_rows: int = 3,
          num_cols: int = 3,
          max_steps: int = 100,
          agent_view_size: int = 7,
          **kwargs,
      ):
          # Each room needs at least one interior cell between its two walls.
          assert room_size >= 3
          assert num_rows > 0
          assert num_cols > 0

          self.room_size = room_size
          self.num_rows = num_rows
          self.num_cols = num_cols

          # Adjacent rooms share a wall: room_size - 1 cells per room, plus
          # one for the closing outer wall.
          height = (room_size - 1) * num_rows + 1
          width = (room_size - 1) * num_cols + 1

          # By default, this environment has no mission
          self.mission = ""

          super().__init__(
              width=width,
              height=height,
              max_steps=max_steps,
              see_through_walls=False,
              agent_view_size=agent_view_size,
              **kwargs,
          )

      def room_from_pos(self, x: int, y: int) -> Room:
          """
          Get the room a given grid position maps to.

          A cell on a wall shared by two rooms maps to the room on its
          right / below it.
          """
          assert x >= 0
          assert y >= 0

          i = x // (self.room_size - 1)
          j = y // (self.room_size - 1)

          assert i < self.num_cols
          assert j < self.num_rows

          return self.room_grid[j][i]

      def get_room(self, i: int, j: int) -> Room:
          """Return the room at column ``i``, row ``j``."""
          # Check the lower bound as well: a negative index would otherwise
          # wrap around and silently return a room from the opposite edge.
          assert 0 <= i < self.num_cols
          assert 0 <= j < self.num_rows
          return self.room_grid[j][i]

      def _gen_grid(self, width, height):
          """
          Build the walled room layout, link neighbouring rooms, choose the
          cell each potential door would occupy, and put the agent in the
          centre room facing right. No doors are placed here.
          """
          # Create the grid
          self.grid = Grid(width, height)

          # Create every room (row-major) and draw its walls.
          step = self.room_size - 1
          self.room_grid = []
          for j in range(self.num_rows):
              row = []
              for i in range(self.num_cols):
                  room = Room((i * step, j * step), (self.room_size, self.room_size))
                  row.append(room)
                  self.grid.wall_rect(*room.top, *room.size)
              self.room_grid.append(row)

          # Link neighbours and pick door positions, order is right, down,
          # left, up. Right/down positions are sampled; left/up ones are copied
          # from the neighbour, which was already visited in row-major order,
          # so both rooms agree on the shared door cell.
          for j in range(self.num_rows):
              for i in range(self.num_cols):
                  room = self.room_grid[j][i]

                  # First interior cell and far wall coordinate of this room.
                  x_l, y_l = (room.top[0] + 1, room.top[1] + 1)
                  x_m, y_m = (
                      room.top[0] + room.size[0] - 1,
                      room.top[1] + room.size[1] - 1,
                  )

                  if i < self.num_cols - 1:
                      room.neighbors[0] = self.room_grid[j][i + 1]
                      room.door_pos[0] = (x_m, self._rand_int(y_l, y_m))
                  if j < self.num_rows - 1:
                      room.neighbors[1] = self.room_grid[j + 1][i]
                      room.door_pos[1] = (self._rand_int(x_l, x_m), y_m)
                  if i > 0:
                      room.neighbors[2] = self.room_grid[j][i - 1]
                      room.door_pos[2] = room.neighbors[2].door_pos[0]
                  if j > 0:
                      room.neighbors[3] = self.room_grid[j - 1][i]
                      room.door_pos[3] = room.neighbors[3].door_pos[1]

          # The agent starts in the middle, facing right
          self.agent_pos = np.array(
              (
                  (self.num_cols // 2) * step + (self.room_size // 2),
                  (self.num_rows // 2) * step + (self.room_size // 2),
              )
          )
          self.agent_dir = 0

      def place_in_room(
          self, i: int, j: int, obj: WorldObj
      ) -> tuple[WorldObj, tuple[int, int]]:
          """
          Add an existing object to room (i, j)

          Positions adjacent to the agent are rejected (``reject_next_to``).
          The object is recorded in the room's ``objs`` list. Returns the
          object and the position it was placed at.
          """
          room = self.get_room(i, j)

          pos = self.place_obj(
              obj, room.top, room.size, reject_fn=reject_next_to, max_tries=1000
          )

          room.objs.append(obj)

          return obj, pos

      def add_object(
          self,
          i: int,
          j: int,
          kind: str | None = None,
          color: str | None = None,
      ) -> tuple[WorldObj, tuple[int, int]]:
          """
          Add a new object to room (i, j)

          ``kind`` must be "key", "ball" or "box"; ``kind`` and ``color`` are
          drawn at random when omitted. Returns the object and its position.

          Raises:
              ValueError: if ``kind`` is not a supported object kind (reached
                  when asserts are disabled).
          """
          if kind is None:
              kind = self._rand_elem(["key", "ball", "box"])

          if color is None:
              color = self._rand_color()

          # TODO: we probably want to add an Object.make helper function
          assert kind in ["key", "ball", "box"]
          if kind == "key":
              obj = Key(color)
          elif kind == "ball":
              obj = Ball(color)
          elif kind == "box":
              obj = Box(color)
          else:
              raise ValueError(
                  f"{kind} object kind is not available in this environment."
              )

          return self.place_in_room(i, j, obj)

      def add_door(
          self,
          i: int,
          j: int,
          door_idx: int | None = None,
          color: str | None = None,
          locked: bool | None = None,
      ) -> tuple[Door, tuple[int, int]]:
          """
          Add a door to a room, connecting it to a neighbor

          Args:
              i, j: column / row of the room.
              door_idx: wall to use (0 right, 1 down, 2 left, 3 up). When
                  omitted, a random wall that has a neighbour and no door yet
                  is chosen.
              color: door colour, random when omitted.
              locked: whether the door is locked, random when omitted.

          Returns:
              The new door and its grid position.

          Raises:
              RecursionError: if ``door_idx`` is omitted and the room has no
                  wall that can take a new door (the same failure signal
                  ``connect_all`` uses for generation that cannot complete).
          """
          room = self.get_room(i, j)

          if door_idx is None:
              # Need a wall with a neighbour behind it and no door yet. Check
              # one exists first; otherwise the sampling loop never ends.
              if not any(
                  room.neighbors[k] and room.doors[k] is None for k in range(4)
              ):
                  raise RecursionError(
                      f"add_door failed: room ({i}, {j}) has no free wall"
                  )
              while True:
                  door_idx = self._rand_int(0, 4)
                  if room.neighbors[door_idx] and room.doors[door_idx] is None:
                      break

          if color is None:
              color = self._rand_color()

          if locked is None:
              locked = self._rand_bool()

          assert room.doors[door_idx] is None, "door already exists"

          # Validate everything before mutating any state.
          pos = room.door_pos[door_idx]
          assert pos is not None
          neighbor = room.neighbors[door_idx]
          assert neighbor is not None

          # NOTE(review): this overwrites the flag, so adding an unlocked door
          # after a locked one clears it again — confirm this is intended.
          room.locked = locked
          door = Door(color, is_locked=locked)
          self.grid.set(pos[0], pos[1], door)
          door.cur_pos = pos

          # Both rooms reference the same door, on opposite sides.
          room.doors[door_idx] = door
          neighbor.doors[(door_idx + 2) % 4] = door

          return door, pos

      def remove_wall(self, i: int, j: int, wall_idx: int):
          """
          Remove a wall between two rooms

          Clears the cells of wall ``wall_idx`` (0 right, 1 down, 2 left,
          3 up) of room (i, j), keeping the two corner cells. It then marks
          both rooms as connected on that side by storing ``True`` in their
          ``doors`` slots.
          """
          room = self.get_room(i, j)

          assert 0 <= wall_idx < 4
          assert room.doors[wall_idx] is None, "door exists on this wall"
          assert room.neighbors[wall_idx], "invalid wall"

          neighbor = room.neighbors[wall_idx]

          tx, ty = room.top
          w, h = room.size

          # Ordering of walls is right, down, left, up. The loop variable is
          # ``k`` so it does not shadow the room index parameter ``i``.
          if wall_idx == 0:
              for k in range(1, h - 1):
                  self.grid.set(tx + w - 1, ty + k, None)
          elif wall_idx == 1:
              for k in range(1, w - 1):
                  self.grid.set(tx + k, ty + h - 1, None)
          elif wall_idx == 2:
              for k in range(1, h - 1):
                  self.grid.set(tx, ty + k, None)
          elif wall_idx == 3:
              for k in range(1, w - 1):
                  self.grid.set(tx + k, ty, None)
          else:
              assert False, "invalid wall index"

          # Mark the rooms as connected
          room.doors[wall_idx] = True
          assert neighbor is not None
          neighbor.doors[(wall_idx + 2) % 4] = True

      def place_agent(
          self, i: int | None = None, j: int | None = None, rand_dir: bool = True
      ) -> np.ndarray:
          """
          Place the agent in a room

          Room (i, j) is drawn at random when either index is omitted.
          Placement is retried until the cell in front of the agent is empty
          or a wall. Returns ``self.agent_pos``.
          """
          if i is None:
              i = self._rand_int(0, self.num_cols)
          if j is None:
              j = self._rand_int(0, self.num_rows)

          room = self.room_grid[j][i]

          # Find a position that is not right in front of an object
          while True:
              super().place_agent(room.top, room.size, rand_dir, max_tries=1000)
              front_cell = self.grid.get(*self.front_pos)
              if front_cell is None or front_cell.type == "wall":
                  break

          return self.agent_pos

      def connect_all(
          self, door_colors: list[str] = COLOR_NAMES, max_itrs: int = 5000
      ) -> list[Door]:
          """
          Make sure that all rooms are reachable by the agent from its
          starting position

          Unlocked doors, coloured from ``door_colors``, are added at random
          until every room is reachable from the agent's room through doors
          or removed walls. No door is added next to a room flagged
          ``locked``.

          Returns:
              The doors that were added.

          Raises:
              RecursionError: if the rooms are still not all connected after
                  ``max_itrs`` sampling attempts.
          """
          start_room = self.room_from_pos(*self.agent_pos)
          num_rooms = self.num_rows * self.num_cols

          added_doors = []

          def find_reach():
              # Depth-first search over rooms linked by a door or removed wall
              reach = set()
              stack = [start_room]
              while stack:
                  room = stack.pop()
                  if room in reach:
                      continue
                  reach.add(room)
                  for k in range(4):
                      if room.doors[k]:
                          stack.append(room.neighbors[k])
              return reach

          num_itrs = 0

          while True:
              # This handles rare situations where random sampling produces a
              # level that cannot be connected, which would otherwise result
              # in an infinite loop
              if num_itrs > max_itrs:
                  raise RecursionError("connect_all failed")

              num_itrs += 1

              # If all rooms are reachable, stop
              reach = find_reach()
              if len(reach) == num_rooms:
                  break

              # Pick a random room and door position
              i = self._rand_int(0, self.num_cols)
              j = self._rand_int(0, self.num_rows)
              k = self._rand_int(0, 4)
              room = self.get_room(i, j)

              # Skip walls with no door slot or with a door already there
              if not room.door_pos[k] or room.doors[k]:
                  continue

              neighbor_room = room.neighbors[k]
              assert neighbor_room is not None
              if room.locked or neighbor_room.locked:
                  continue

              color = self._rand_elem(door_colors)
              door, _ = self.add_door(i, j, k, color, False)
              added_doors.append(door)

          return added_doors

      def add_distractors(
          self,
          i: int | None = None,
          j: int | None = None,
          num_distractors: int = 10,
          all_unique: bool = True,
      ) -> list[WorldObj]:
          """
          Add random objects that can potentially distract/confuse the agent.

          Args:
              i, j: room to place the distractors in; a random room is drawn
                  per object for any index that is omitted.
              num_distractors: number of objects to add.
              all_unique: when True, no two objects in the level (existing
                  ones included) share the same (kind, color) pair.

          Returns:
              The objects that were added.

          Raises:
              RecursionError: if ``all_unique`` is set and every (kind, color)
                  pair is already in use, so no further object can be added
                  (the same failure signal ``connect_all`` uses).
          """
          kinds = ["key", "ball", "box"]

          # Collect the (kind, color) pairs of the existing objects
          seen = {
              (obj.type, obj.color)
              for row in self.room_grid
              for room in row
              for obj in room.objs
          }

          # List of distractors added
          dists = []

          while len(dists) < num_distractors:
              color = self._rand_elem(COLOR_NAMES)
              kind = self._rand_elem(kinds)
              obj = (kind, color)

              if all_unique and obj in seen:
                  # Once every pair is taken, no sample can ever be accepted.
                  if all((k, c) in seen for k in kinds for c in COLOR_NAMES):
                      raise RecursionError(
                          "add_distractors failed: no unique object left"
                      )
                  continue

              # Add the object to a random room if no room specified
              room_i = i
              room_j = j
              if room_i is None:
                  room_i = self._rand_int(0, self.num_cols)
              if room_j is None:
                  room_j = self._rand_int(0, self.num_rows)

              dist, _ = self.add_object(room_i, room_j, *obj)

              seen.add(obj)
              dists.append(dist)

          return dists
  344. objs.append(obj)
  345. dists.append(dist)
  346. return dists