levelgen.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
"""
Random BabyAI level generator (``LevelGen``).

Copied and adapted from https://github.com/mila-iqia/babyai
"""
  4. from minigrid.core.constants import COLOR_NAMES
  5. from minigrid.core.roomgrid import Room
  6. from minigrid.envs.babyai.core.roomgrid_level import RoomGridLevel
  7. from minigrid.envs.babyai.core.verifier import (
  8. LOC_NAMES,
  9. OBJ_TYPES,
  10. OBJ_TYPES_NOT_DOOR,
  11. AfterInstr,
  12. AndInstr,
  13. BeforeInstr,
  14. GoToInstr,
  15. ObjDesc,
  16. OpenInstr,
  17. PickupInstr,
  18. PutNextInstr,
  19. )
  20. class LevelGen(RoomGridLevel):
  21. """
  22. Level generator which attempts to produce every possible sentence in
  23. the baby language as an instruction.
  24. """
  25. def __init__(
  26. self,
  27. room_size=8,
  28. num_rows=3,
  29. num_cols=3,
  30. num_dists=18,
  31. locked_room_prob=0.5,
  32. locations=True,
  33. unblocking=True,
  34. implicit_unlock=True,
  35. action_kinds=["goto", "pickup", "open", "putnext"],
  36. instr_kinds=["action", "and", "seq"],
  37. **kwargs
  38. ):
  39. self.num_dists = num_dists
  40. self.locked_room_prob = locked_room_prob
  41. self.locations = locations
  42. self.unblocking = unblocking
  43. self.implicit_unlock = implicit_unlock
  44. self.action_kinds = action_kinds
  45. self.instr_kinds = instr_kinds
  46. self.locked_room = None
  47. super().__init__(
  48. room_size=room_size, num_rows=num_rows, num_cols=num_cols, **kwargs
  49. )
  50. def gen_mission(self):
  51. if self._rand_float(0, 1) < self.locked_room_prob:
  52. self.add_locked_room()
  53. self.connect_all()
  54. self.add_distractors(num_distractors=self.num_dists, all_unique=False)
  55. # The agent must be placed after all the object to respect constraints
  56. while True:
  57. self.place_agent()
  58. start_room = self.room_from_pos(*self.agent_pos)
  59. # Ensure that we are not placing the agent in the locked room
  60. if start_room is self.locked_room:
  61. continue
  62. break
  63. # If no unblocking required, make sure all objects are
  64. # reachable without unblocking
  65. if not self.unblocking:
  66. self.check_objs_reachable()
  67. # Generate random instructions
  68. self.instrs = self.rand_instr(
  69. action_kinds=self.action_kinds, instr_kinds=self.instr_kinds
  70. )
  71. def add_locked_room(self):
  72. # Until we've successfully added a locked room
  73. while True:
  74. i = self._rand_int(0, self.num_cols)
  75. j = self._rand_int(0, self.num_rows)
  76. door_idx = self._rand_int(0, 4)
  77. self.locked_room = self.get_room(i, j)
  78. # Don't add a locked door in an external wall
  79. if self.locked_room.neighbors[door_idx] is None:
  80. continue
  81. door, _ = self.add_door(i, j, door_idx, locked=True)
  82. # Done adding locked room
  83. break
  84. # Until we find a room to put the key
  85. while True:
  86. i = self._rand_int(0, self.num_cols)
  87. j = self._rand_int(0, self.num_rows)
  88. key_room = self.get_room(i, j)
  89. if key_room is self.locked_room:
  90. continue
  91. self.add_object(i, j, "key", door.color)
  92. break
  93. def rand_obj(self, types=OBJ_TYPES, colors=COLOR_NAMES, max_tries=100):
  94. """
  95. Generate a random object descriptor
  96. """
  97. num_tries = 0
  98. # Keep trying until we find a matching object
  99. while True:
  100. if num_tries > max_tries:
  101. raise RecursionError("failed to find suitable object")
  102. num_tries += 1
  103. color = self._rand_elem([None, *colors])
  104. type = self._rand_elem(types)
  105. loc = None
  106. if self.locations and self._rand_bool():
  107. loc = self._rand_elem(LOC_NAMES)
  108. desc = ObjDesc(type, color, loc)
  109. # Find all objects matching the descriptor
  110. objs, poss = desc.find_matching_objs(self)
  111. # The description must match at least one object
  112. if len(objs) == 0:
  113. continue
  114. # If no implicit unlocking is required
  115. if not self.implicit_unlock and isinstance(self.locked_room, Room):
  116. locked_room = self.locked_room
  117. # Check that at least one object is not in the locked room
  118. pos_not_locked = list(
  119. filter(lambda p: not locked_room.pos_inside(*p), poss)
  120. )
  121. if len(pos_not_locked) == 0:
  122. continue
  123. # Found a valid object description
  124. return desc
  125. def rand_instr(self, action_kinds, instr_kinds, depth=0):
  126. """
  127. Generate random instructions
  128. """
  129. kind = self._rand_elem(instr_kinds)
  130. if kind == "action":
  131. action = self._rand_elem(action_kinds)
  132. if action == "goto":
  133. return GoToInstr(self.rand_obj())
  134. elif action == "pickup":
  135. return PickupInstr(self.rand_obj(types=OBJ_TYPES_NOT_DOOR))
  136. elif action == "open":
  137. return OpenInstr(self.rand_obj(types=["door"]))
  138. elif action == "putnext":
  139. return PutNextInstr(
  140. self.rand_obj(types=OBJ_TYPES_NOT_DOOR), self.rand_obj()
  141. )
  142. assert False
  143. elif kind == "and":
  144. instr_a = self.rand_instr(
  145. action_kinds=action_kinds, instr_kinds=["action"], depth=depth + 1
  146. )
  147. instr_b = self.rand_instr(
  148. action_kinds=action_kinds, instr_kinds=["action"], depth=depth + 1
  149. )
  150. return AndInstr(instr_a, instr_b)
  151. elif kind == "seq":
  152. instr_a = self.rand_instr(
  153. action_kinds=action_kinds,
  154. instr_kinds=["action", "and"],
  155. depth=depth + 1,
  156. )
  157. instr_b = self.rand_instr(
  158. action_kinds=action_kinds,
  159. instr_kinds=["action", "and"],
  160. depth=depth + 1,
  161. )
  162. kind = self._rand_elem(["before", "after"])
  163. if kind == "before":
  164. return BeforeInstr(instr_a, instr_b)
  165. elif kind == "after":
  166. return AfterInstr(instr_a, instr_b)
  167. assert False
  168. assert False