levelgen.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. """
  2. Copied and adapted from https://github.com/mila-iqia/babyai
  3. """
  4. from __future__ import annotations
  5. from minigrid.core.constants import COLOR_NAMES
  6. from minigrid.core.roomgrid import Room
  7. from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
  8. from minigrid.envs.babyai.core.verifier import (
  9. LOC_NAMES,
  10. OBJ_TYPES,
  11. OBJ_TYPES_NOT_DOOR,
  12. AfterInstr,
  13. AndInstr,
  14. BeforeInstr,
  15. GoToInstr,
  16. ObjDesc,
  17. OpenInstr,
  18. PickupInstr,
  19. PutNextInstr,
  20. )
  21. class LevelGen(RoomGridLevel):
  22. """
  23. Level generator which attempts to produce every possible sentence in
  24. the baby language as an instruction.
  25. """
  26. def __init__(
  27. self,
  28. room_size=8,
  29. num_rows=3,
  30. num_cols=3,
  31. num_dists=18,
  32. locked_room_prob=0.5,
  33. locations=True,
  34. unblocking=True,
  35. implicit_unlock=True,
  36. action_kinds=["goto", "pickup", "open", "putnext"],
  37. instr_kinds=["action", "and", "seq"],
  38. **kwargs,
  39. ):
  40. self.num_dists = num_dists
  41. self.locked_room_prob = locked_room_prob
  42. self.locations = locations
  43. self.unblocking = unblocking
  44. self.implicit_unlock = implicit_unlock
  45. self.action_kinds = action_kinds
  46. self.instr_kinds = instr_kinds
  47. self.locked_room = None
  48. super().__init__(
  49. room_size=room_size, num_rows=num_rows, num_cols=num_cols, **kwargs
  50. )
  51. def gen_mission(self):
  52. if self._rand_float(0, 1) < self.locked_room_prob:
  53. self.add_locked_room()
  54. self.connect_all()
  55. self.add_distractors(num_distractors=self.num_dists, all_unique=False)
  56. # The agent must be placed after all the object to respect constraints
  57. while True:
  58. self.place_agent()
  59. start_room = self.room_from_pos(*self.agent_pos)
  60. # Ensure that we are not placing the agent in the locked room
  61. if start_room is self.locked_room:
  62. continue
  63. break
  64. # If no unblocking required, make sure all objects are
  65. # reachable without unblocking
  66. if not self.unblocking:
  67. self.check_objs_reachable()
  68. # Generate random instructions
  69. self.instrs = self.rand_instr(
  70. action_kinds=self.action_kinds, instr_kinds=self.instr_kinds
  71. )
  72. def add_locked_room(self):
  73. # Until we've successfully added a locked room
  74. while True:
  75. i = self._rand_int(0, self.num_cols)
  76. j = self._rand_int(0, self.num_rows)
  77. door_idx = self._rand_int(0, 4)
  78. self.locked_room = self.get_room(i, j)
  79. # Don't add a locked door in an external wall
  80. if self.locked_room.neighbors[door_idx] is None:
  81. continue
  82. door, _ = self.add_door(i, j, door_idx, locked=True)
  83. # Done adding locked room
  84. break
  85. # Until we find a room to put the key
  86. while True:
  87. i = self._rand_int(0, self.num_cols)
  88. j = self._rand_int(0, self.num_rows)
  89. key_room = self.get_room(i, j)
  90. if key_room is self.locked_room:
  91. continue
  92. self.add_object(i, j, "key", door.color)
  93. break
  94. def rand_obj(self, types=OBJ_TYPES, colors=COLOR_NAMES, max_tries=100):
  95. """
  96. Generate a random object descriptor
  97. """
  98. num_tries = 0
  99. # Keep trying until we find a matching object
  100. while True:
  101. if num_tries > max_tries:
  102. raise RecursionError("failed to find suitable object")
  103. num_tries += 1
  104. color = self._rand_elem([None, *colors])
  105. type = self._rand_elem(types)
  106. loc = None
  107. if self.locations and self._rand_bool():
  108. loc = self._rand_elem(LOC_NAMES)
  109. desc = ObjDesc(type, color, loc)
  110. # Find all objects matching the descriptor
  111. objs, poss = desc.find_matching_objs(self)
  112. # The description must match at least one object
  113. if len(objs) == 0:
  114. continue
  115. # If no implicit unlocking is required
  116. if not self.implicit_unlock and isinstance(self.locked_room, Room):
  117. locked_room = self.locked_room
  118. # Check that at least one object is not in the locked room
  119. pos_not_locked = list(
  120. filter(lambda p: not locked_room.pos_inside(*p), poss)
  121. )
  122. if len(pos_not_locked) == 0:
  123. continue
  124. # Found a valid object description
  125. return desc
  126. def rand_instr(self, action_kinds, instr_kinds, depth=0):
  127. """
  128. Generate random instructions
  129. """
  130. kind = self._rand_elem(instr_kinds)
  131. if kind == "action":
  132. action = self._rand_elem(action_kinds)
  133. if action == "goto":
  134. return GoToInstr(self.rand_obj())
  135. elif action == "pickup":
  136. return PickupInstr(self.rand_obj(types=OBJ_TYPES_NOT_DOOR))
  137. elif action == "open":
  138. return OpenInstr(self.rand_obj(types=["door"]))
  139. elif action == "putnext":
  140. return PutNextInstr(
  141. self.rand_obj(types=OBJ_TYPES_NOT_DOOR), self.rand_obj()
  142. )
  143. assert False
  144. elif kind == "and":
  145. instr_a = self.rand_instr(
  146. action_kinds=action_kinds, instr_kinds=["action"], depth=depth + 1
  147. )
  148. instr_b = self.rand_instr(
  149. action_kinds=action_kinds, instr_kinds=["action"], depth=depth + 1
  150. )
  151. return AndInstr(instr_a, instr_b)
  152. elif kind == "seq":
  153. instr_a = self.rand_instr(
  154. action_kinds=action_kinds,
  155. instr_kinds=["action", "and"],
  156. depth=depth + 1,
  157. )
  158. instr_b = self.rand_instr(
  159. action_kinds=action_kinds,
  160. instr_kinds=["action", "and"],
  161. depth=depth + 1,
  162. )
  163. kind = self._rand_elem(["before", "after"])
  164. if kind == "before":
  165. return BeforeInstr(instr_a, instr_b)
  166. elif kind == "after":
  167. return AfterInstr(instr_a, instr_b)
  168. assert False
  169. assert False