roomgrid_level.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. """
  2. Copied and adapted from https://github.com/mila-iqia/babyai
  3. """
  4. from __future__ import annotations
  5. from minigrid.core.roomgrid import RoomGrid
  6. from minigrid.envs.babyai.core.verifier import (
  7. ActionInstr,
  8. AfterInstr,
  9. AndInstr,
  10. BeforeInstr,
  11. PutNextInstr,
  12. SeqInstr,
  13. )
  14. from minigrid.minigrid_env import MissionSpace
  15. class RejectSampling(Exception):
  16. """
  17. Exception used for rejection sampling
  18. """
  19. pass
  20. class BabyAIMissionSpace(MissionSpace):
  21. """
  22. Class that mimics the behavior required by minigrid.minigrid_env.MissionSpace,
  23. but does not change how missions are generated for BabyAI. It silences
  24. the gymnasium.utils.passive_env_checker given that it considers all strings to be
  25. plausible samples.
  26. """
  27. def __init__(self):
  28. super().__init__(mission_func=self._gen_mission)
  29. @staticmethod
  30. def _gen_mission():
  31. return "go"
  32. def contains(self, x: str):
  33. return True
  34. class RoomGridLevel(RoomGrid):
  35. """
  36. Base for levels based on RoomGrid.
  37. A level, generates missions generated from
  38. one or more patterns. Levels should produce a family of missions
  39. of approximately similar difficulty.
  40. """
  41. def __init__(self, room_size=8, max_steps: int | None = None, **kwargs):
  42. mission_space = BabyAIMissionSpace()
  43. # If `max_steps` arg is passed it will be fixed for every episode,
  44. # if not it will vary after reset depending on the maze size.
  45. self.fixed_max_steps = False
  46. if max_steps is not None:
  47. self.fixed_max_steps = True
  48. else:
  49. max_steps = 0 # only for initialization
  50. super().__init__(
  51. room_size=room_size,
  52. mission_space=mission_space,
  53. max_steps=max_steps,
  54. **kwargs,
  55. )
  56. def reset(self, **kwargs):
  57. obs = super().reset(**kwargs)
  58. # Recreate the verifier
  59. self.instrs.reset_verifier(self)
  60. # Compute the time step limit based on the maze size and instructions
  61. nav_time_room = self.room_size**2
  62. nav_time_maze = nav_time_room * self.num_rows * self.num_cols
  63. num_navs = self.num_navs_needed(self.instrs)
  64. if not self.fixed_max_steps:
  65. self.max_steps = num_navs * nav_time_maze
  66. return obs
  67. def step(self, action):
  68. obs, reward, terminated, truncated, info = super().step(action)
  69. # If we drop an object, we need to update its position in the environment
  70. if action == self.actions.drop:
  71. self.update_objs_poss()
  72. # If we've successfully completed the mission
  73. status = self.instrs.verify(action)
  74. if status == "success":
  75. terminated = True
  76. reward = self._reward()
  77. elif status == "failure":
  78. terminated = True
  79. reward = 0
  80. return obs, reward, terminated, truncated, info
  81. def update_objs_poss(self, instr=None):
  82. if instr is None:
  83. instr = self.instrs
  84. if (
  85. isinstance(instr, BeforeInstr)
  86. or isinstance(instr, AndInstr)
  87. or isinstance(instr, AfterInstr)
  88. ):
  89. self.update_objs_poss(instr.instr_a)
  90. self.update_objs_poss(instr.instr_b)
  91. else:
  92. instr.update_objs_poss()
  93. def _gen_grid(self, width, height):
  94. # We catch RecursionError to deal with rare cases where
  95. # rejection sampling gets stuck in an infinite loop
  96. while True:
  97. try:
  98. super()._gen_grid(width, height)
  99. # Generate the mission
  100. self.gen_mission()
  101. # Validate the instructions
  102. self.validate_instrs(self.instrs)
  103. except RecursionError as error:
  104. print("Timeout during mission generation:", error)
  105. continue
  106. except RejectSampling as error:
  107. print("Sampling rejected:", error)
  108. continue
  109. break
  110. # Generate the surface form for the instructions
  111. self.surface = self.instrs.surface(self)
  112. self.mission = self.surface
  113. def validate_instrs(self, instr):
  114. """
  115. Perform some validation on the generated instructions
  116. """
  117. # Gather the colors of locked doors
  118. colors_of_locked_doors = []
  119. if hasattr(self, "unblocking") and self.unblocking:
  120. for i in range(self.num_cols):
  121. for j in range(self.num_rows):
  122. room = self.get_room(i, j)
  123. for door in room.doors:
  124. if door and door.is_locked:
  125. colors_of_locked_doors.append(door.color)
  126. if isinstance(instr, PutNextInstr):
  127. # Resolve the objects referenced by the instruction
  128. instr.reset_verifier(self)
  129. # Check that the objects are not already next to each other
  130. if set(instr.desc_move.obj_set).intersection(set(instr.desc_fixed.obj_set)):
  131. raise RejectSampling(
  132. "there are objects that match both lhs and rhs of PutNext"
  133. )
  134. if instr.objs_next():
  135. raise RejectSampling("objs already next to each other")
  136. # Check that we are not asking to move an object next to itself
  137. move = instr.desc_move
  138. fixed = instr.desc_fixed
  139. if len(move.obj_set) == 1 and len(fixed.obj_set) == 1:
  140. if move.obj_set[0] is fixed.obj_set[0]:
  141. raise RejectSampling("cannot move an object next to itself")
  142. if isinstance(instr, ActionInstr):
  143. if not hasattr(self, "unblocking") or not self.unblocking:
  144. return
  145. # TODO: either relax this a bit or make the bot handle this super corner-y scenarios
  146. # Check that the instruction doesn't involve a key that matches the color of a locked door
  147. potential_objects = ("desc", "desc_move", "desc_fixed")
  148. for attr in potential_objects:
  149. if hasattr(instr, attr):
  150. obj = getattr(instr, attr)
  151. if obj.type == "key" and obj.color in colors_of_locked_doors:
  152. raise RejectSampling(
  153. "cannot do anything with/to a key that can be used to open a door"
  154. )
  155. return
  156. if isinstance(instr, SeqInstr):
  157. self.validate_instrs(instr.instr_a)
  158. self.validate_instrs(instr.instr_b)
  159. return
  160. assert False, "unhandled instruction type"
  161. def gen_mission(self):
  162. """
  163. Generate a mission (instructions and matching environment)
  164. Derived level classes should implement this method
  165. """
  166. raise NotImplementedError
  167. @property
  168. def level_name(self):
  169. return self.__class__.level_name
  170. @property
  171. def gym_id(self):
  172. return self.__class__.gym_id
  173. def num_navs_needed(self, instr) -> int:
  174. """
  175. Compute the maximum number of navigations needed to perform
  176. a simple or complex instruction
  177. """
  178. if isinstance(instr, PutNextInstr):
  179. return 2
  180. elif isinstance(instr, ActionInstr):
  181. return 1
  182. elif isinstance(instr, SeqInstr):
  183. na = self.num_navs_needed(instr.instr_a)
  184. nb = self.num_navs_needed(instr.instr_b)
  185. return na + nb
  186. else:
  187. raise NotImplementedError(
  188. "instr needs to be an instance of PutNextInstr, ActionInstr, or SeqInstr"
  189. )
  190. def open_all_doors(self):
  191. """
  192. Open all the doors in the maze
  193. """
  194. for i in range(self.num_cols):
  195. for j in range(self.num_rows):
  196. room = self.get_room(i, j)
  197. for door in room.doors:
  198. if door:
  199. door.is_open = True
  200. def check_objs_reachable(self, raise_exc=True):
  201. """
  202. Check that all objects are reachable from the agent's starting
  203. position without requiring any other object to be moved
  204. (without unblocking)
  205. """
  206. # Reachable positions
  207. reachable = set()
  208. # Work list
  209. stack = [self.agent_pos]
  210. while len(stack) > 0:
  211. i, j = stack.pop()
  212. if i < 0 or i >= self.grid.width or j < 0 or j >= self.grid.height:
  213. continue
  214. if (i, j) in reachable:
  215. continue
  216. # This position is reachable
  217. reachable.add((i, j))
  218. cell = self.grid.get(i, j)
  219. # If there is something other than a door in this cell, it
  220. # blocks reachability
  221. if cell and cell.type != "door":
  222. continue
  223. # Visit the horizontal and vertical neighbors
  224. stack.append((i + 1, j))
  225. stack.append((i - 1, j))
  226. stack.append((i, j + 1))
  227. stack.append((i, j - 1))
  228. # Check that all objects are reachable
  229. for i in range(self.grid.width):
  230. for j in range(self.grid.height):
  231. cell = self.grid.get(i, j)
  232. if not cell or cell.type == "wall":
  233. continue
  234. if (i, j) not in reachable:
  235. if not raise_exc:
  236. return False
  237. raise RejectSampling("unreachable object at " + str((i, j)))
  238. # All objects reachable
  239. return True