roomgrid_level.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
"""
Base classes for BabyAI levels built on minigrid's RoomGrid: the
RejectSampling exception used during level generation, a permissive
mission space, and the RoomGridLevel base class.

Copied and adapted from https://github.com/mila-iqia/babyai
"""
  4. from minigrid.core.roomgrid import RoomGrid
  5. from minigrid.envs.babyai.core.verifier import (
  6. ActionInstr,
  7. AfterInstr,
  8. AndInstr,
  9. BeforeInstr,
  10. PutNextInstr,
  11. SeqInstr,
  12. )
  13. from minigrid.minigrid_env import MissionSpace
  14. class RejectSampling(Exception):
  15. """
  16. Exception used for rejection sampling
  17. """
  18. pass
  19. class BabyAIMissionSpace(MissionSpace):
  20. """
  21. Class that mimics the behavior required by minigrid.minigrid_env.MissionSpace,
  22. but does not change how missions are generated for BabyAI. It silences
  23. the gymnasium.utils.passive_env_checker given that it considers all strings to be
  24. plausible samples.
  25. """
  26. def __init__(self):
  27. super().__init__(mission_func=lambda: "go")
  28. def contains(self, x: str):
  29. return True
  30. class RoomGridLevel(RoomGrid):
  31. """
  32. Base for levels based on RoomGrid.
  33. A level, generates missions generated from
  34. one or more patterns. Levels should produce a family of missions
  35. of approximately similar difficulty.
  36. """
  37. def __init__(self, room_size=8, **kwargs):
  38. mission_space = BabyAIMissionSpace()
  39. super().__init__(room_size=room_size, mission_space=mission_space, **kwargs)
  40. def reset(self, **kwargs):
  41. obs = super().reset(**kwargs)
  42. # Recreate the verifier
  43. self.instrs.reset_verifier(self)
  44. # Compute the time step limit based on the maze size and instructions
  45. nav_time_room = self.room_size**2
  46. nav_time_maze = nav_time_room * self.num_rows * self.num_cols
  47. num_navs = self.num_navs_needed(self.instrs)
  48. self.max_steps = num_navs * nav_time_maze
  49. return obs
  50. def step(self, action):
  51. obs, reward, terminated, truncated, info = super().step(action)
  52. # If we drop an object, we need to update its position in the environment
  53. if action == self.actions.drop:
  54. self.update_objs_poss()
  55. # If we've successfully completed the mission
  56. status = self.instrs.verify(action)
  57. if status == "success":
  58. terminated = True
  59. reward = self._reward()
  60. elif status == "failure":
  61. terminated = True
  62. reward = 0
  63. return obs, reward, terminated, truncated, info
  64. def update_objs_poss(self, instr=None):
  65. if instr is None:
  66. instr = self.instrs
  67. if (
  68. isinstance(instr, BeforeInstr)
  69. or isinstance(instr, AndInstr)
  70. or isinstance(instr, AfterInstr)
  71. ):
  72. self.update_objs_poss(instr.instr_a)
  73. self.update_objs_poss(instr.instr_b)
  74. else:
  75. instr.update_objs_poss()
  76. def _gen_grid(self, width, height):
  77. # We catch RecursionError to deal with rare cases where
  78. # rejection sampling gets stuck in an infinite loop
  79. while True:
  80. try:
  81. super()._gen_grid(width, height)
  82. # Generate the mission
  83. self.gen_mission()
  84. # Validate the instructions
  85. self.validate_instrs(self.instrs)
  86. except RecursionError as error:
  87. print("Timeout during mission generation:", error)
  88. continue
  89. except RejectSampling as error:
  90. print("Sampling rejected:", error)
  91. continue
  92. break
  93. # Generate the surface form for the instructions
  94. self.surface = self.instrs.surface(self)
  95. self.mission = self.surface
  96. def validate_instrs(self, instr):
  97. """
  98. Perform some validation on the generated instructions
  99. """
  100. # Gather the colors of locked doors
  101. colors_of_locked_doors = []
  102. if hasattr(self, "unblocking") and self.unblocking:
  103. for i in range(self.num_cols):
  104. for j in range(self.num_rows):
  105. room = self.get_room(i, j)
  106. for door in room.doors:
  107. if door and door.is_locked:
  108. colors_of_locked_doors.append(door.color)
  109. if isinstance(instr, PutNextInstr):
  110. # Resolve the objects referenced by the instruction
  111. instr.reset_verifier(self)
  112. # Check that the objects are not already next to each other
  113. if set(instr.desc_move.obj_set).intersection(set(instr.desc_fixed.obj_set)):
  114. raise RejectSampling(
  115. "there are objects that match both lhs and rhs of PutNext"
  116. )
  117. if instr.objs_next():
  118. raise RejectSampling("objs already next to each other")
  119. # Check that we are not asking to move an object next to itself
  120. move = instr.desc_move
  121. fixed = instr.desc_fixed
  122. if len(move.obj_set) == 1 and len(fixed.obj_set) == 1:
  123. if move.obj_set[0] is fixed.obj_set[0]:
  124. raise RejectSampling("cannot move an object next to itself")
  125. if isinstance(instr, ActionInstr):
  126. if not hasattr(self, "unblocking") or not self.unblocking:
  127. return
  128. # TODO: either relax this a bit or make the bot handle this super corner-y scenarios
  129. # Check that the instruction doesn't involve a key that matches the color of a locked door
  130. potential_objects = ("desc", "desc_move", "desc_fixed")
  131. for attr in potential_objects:
  132. if hasattr(instr, attr):
  133. obj = getattr(instr, attr)
  134. if obj.type == "key" and obj.color in colors_of_locked_doors:
  135. raise RejectSampling(
  136. "cannot do anything with/to a key that can be used to open a door"
  137. )
  138. return
  139. if isinstance(instr, SeqInstr):
  140. self.validate_instrs(instr.instr_a)
  141. self.validate_instrs(instr.instr_b)
  142. return
  143. assert False, "unhandled instruction type"
  144. def gen_mission(self):
  145. """
  146. Generate a mission (instructions and matching environment)
  147. Derived level classes should implement this method
  148. """
  149. raise NotImplementedError
  150. @property
  151. def level_name(self):
  152. return self.__class__.level_name
  153. @property
  154. def gym_id(self):
  155. return self.__class__.gym_id
  156. def num_navs_needed(self, instr) -> int:
  157. """
  158. Compute the maximum number of navigations needed to perform
  159. a simple or complex instruction
  160. """
  161. if isinstance(instr, PutNextInstr):
  162. return 2
  163. elif isinstance(instr, ActionInstr):
  164. return 1
  165. elif isinstance(instr, SeqInstr):
  166. na = self.num_navs_needed(instr.instr_a)
  167. nb = self.num_navs_needed(instr.instr_b)
  168. return na + nb
  169. else:
  170. raise NotImplementedError(
  171. "instr needs to be an instance of PutNextInstr, ActionInstr, or SeqInstr"
  172. )
  173. def open_all_doors(self):
  174. """
  175. Open all the doors in the maze
  176. """
  177. for i in range(self.num_cols):
  178. for j in range(self.num_rows):
  179. room = self.get_room(i, j)
  180. for door in room.doors:
  181. if door:
  182. door.is_open = True
  183. def check_objs_reachable(self, raise_exc=True):
  184. """
  185. Check that all objects are reachable from the agent's starting
  186. position without requiring any other object to be moved
  187. (without unblocking)
  188. """
  189. # Reachable positions
  190. reachable = set()
  191. # Work list
  192. stack = [self.agent_pos]
  193. while len(stack) > 0:
  194. i, j = stack.pop()
  195. if i < 0 or i >= self.grid.width or j < 0 or j >= self.grid.height:
  196. continue
  197. if (i, j) in reachable:
  198. continue
  199. # This position is reachable
  200. reachable.add((i, j))
  201. cell = self.grid.get(i, j)
  202. # If there is something other than a door in this cell, it
  203. # blocks reachability
  204. if cell and cell.type != "door":
  205. continue
  206. # Visit the horizontal and vertical neighbors
  207. stack.append((i + 1, j))
  208. stack.append((i - 1, j))
  209. stack.append((i, j + 1))
  210. stack.append((i, j - 1))
  211. # Check that all objects are reachable
  212. for i in range(self.grid.width):
  213. for j in range(self.grid.height):
  214. cell = self.grid.get(i, j)
  215. if not cell or cell.type == "wall":
  216. continue
  217. if (i, j) not in reachable:
  218. if not raise_exc:
  219. return False
  220. raise RejectSampling("unreachable object at " + str((i, j)))
  221. # All objects reachable
  222. return True