"""
Copied and adapted from https://github.com/mila-iqia/babyai
"""
- from __future__ import annotations
- from minigrid.core.roomgrid import RoomGrid
- from minigrid.envs.babyai.core.verifier import (
- ActionInstr,
- AfterInstr,
- AndInstr,
- BeforeInstr,
- PutNextInstr,
- SeqInstr,
- )
- from minigrid.minigrid_env import MissionSpace
class RejectSampling(Exception):
    """
    Exception used for rejection sampling.

    Raised while validating a freshly generated level or mission to signal
    that the sample must be discarded and generation retried.
    """
class BabyAIMissionSpace(MissionSpace):
    """
    Class that mimics the behavior required by minigrid.minigrid_env.MissionSpace,
    but does not change how missions are generated for BabyAI. It silences
    the gymnasium.utils.passive_env_checker given that it considers all strings to be
    plausible samples.
    """

    def __init__(self):
        # The mission function is only a placeholder handed to the parent
        # constructor; it always yields the same fixed string.
        super().__init__(mission_func=self._gen_mission)

    @staticmethod
    def _gen_mission():
        """Return the fixed placeholder mission string."""
        return "go"

    def contains(self, x: str) -> bool:
        """Accept every candidate: this space treats any value as a member."""
        return True
class RoomGridLevel(RoomGrid):
    """
    Base for levels based on RoomGrid.

    A level generates missions from one or more patterns. Levels should
    produce a family of missions of approximately similar difficulty.
    """

    def __init__(self, room_size=8, max_steps: int | None = None, **kwargs):
        """
        Build the level.

        Args:
            room_size: forwarded to the parent constructor as ``room_size``.
            max_steps: if given, the step limit is fixed for every episode;
                if ``None``, it is recomputed on each ``reset``.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        mission_space = BabyAIMissionSpace()

        # If `max_steps` arg is passed it will be fixed for every episode,
        # if not it will vary after reset depending on the maze size.
        self.fixed_max_steps = max_steps is not None
        if max_steps is None:
            max_steps = 0  # only for initialization

        super().__init__(
            room_size=room_size,
            mission_space=mission_space,
            max_steps=max_steps,
            **kwargs,
        )

    def reset(self, **kwargs):
        """
        Reset the level, rebuild the instruction verifier and, unless a
        fixed ``max_steps`` was given at construction, recompute the step
        limit.

        Returns:
            Whatever the parent ``reset`` returns, unchanged.
        """
        obs = super().reset(**kwargs)

        # Recreate the verifier
        self.instrs.reset_verifier(self)

        # Compute the time step limit based on the maze size and instructions
        nav_time_room = self.room_size**2
        nav_time_maze = nav_time_room * self.num_rows * self.num_cols
        num_navs = self.num_navs_needed(self.instrs)

        if not self.fixed_max_steps:
            self.max_steps = num_navs * nav_time_maze

        return obs

    def step(self, action):
        """
        Advance one step and check the instructions against the action.

        When the verifier reports ``"success"`` the episode terminates with
        the reward from ``self._reward()``; when it reports ``"failure"``
        the episode terminates with reward 0. Any other status leaves the
        parent's results untouched.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple.
        """
        obs, reward, terminated, truncated, info = super().step(action)

        # If we drop an object, we need to update its position in the environment
        if action == self.actions.drop:
            self.update_objs_poss()

        # If we've successfully completed the mission
        status = self.instrs.verify(action)

        if status == "success":
            terminated = True
            reward = self._reward()
        elif status == "failure":
            terminated = True
            reward = 0

        return obs, reward, terminated, truncated, info

    def update_objs_poss(self, instr=None):
        """
        Ask every leaf instruction to refresh its object positions.

        Args:
            instr: instruction to walk; defaults to ``self.instrs``.
                Before/And/After instructions are descended into through
                ``instr_a`` and ``instr_b``; anything else is treated as a
                leaf and has its own ``update_objs_poss()`` called.
        """
        if instr is None:
            instr = self.instrs

        if isinstance(instr, (BeforeInstr, AndInstr, AfterInstr)):
            self.update_objs_poss(instr.instr_a)
            self.update_objs_poss(instr.instr_b)
        else:
            instr.update_objs_poss()

    def _gen_grid(self, width, height):
        """
        Generate the grid and a valid mission, retrying until one is accepted.

        After a sample is accepted, the surface form of the instructions is
        stored in ``self.surface`` and ``self.mission``.
        """
        # We catch RecursionError to deal with rare cases where
        # rejection sampling gets stuck in an infinite loop
        while True:
            try:
                super()._gen_grid(width, height)

                # Generate the mission
                self.gen_mission()

                # Validate the instructions
                self.validate_instrs(self.instrs)

            except RecursionError as error:
                print("Timeout during mission generation:", error)
                continue

            except RejectSampling as error:
                print("Sampling rejected:", error)
                continue

            break

        # Generate the surface form for the instructions
        self.surface = self.instrs.surface(self)
        self.mission = self.surface

    def _locked_door_colors(self):
        """Collect the colors of all locked doors across every room."""
        colors = set()
        for i in range(self.num_cols):
            for j in range(self.num_rows):
                room = self.get_room(i, j)
                for door in room.doors:
                    if door and door.is_locked:
                        colors.add(door.color)
        return colors

    def validate_instrs(self, instr):
        """
        Perform some validation on the generated instructions.

        Args:
            instr: a PutNextInstr, another ActionInstr, or a SeqInstr
                (validated recursively through ``instr_a`` / ``instr_b``).

        Raises:
            RejectSampling: if the instruction should be discarded.
            AssertionError: if ``instr`` is none of the handled types.
        """
        if isinstance(instr, PutNextInstr):
            # Resolve the objects referenced by the instruction
            instr.reset_verifier(self)

            move = instr.desc_move
            fixed = instr.desc_fixed

            # Check that the objects are not already next to each other
            if set(move.obj_set).intersection(set(fixed.obj_set)):
                raise RejectSampling(
                    "there are objects that match both lhs and rhs of PutNext"
                )
            if instr.objs_next():
                raise RejectSampling("objs already next to each other")

            # Check that we are not asking to move an object next to itself
            if len(move.obj_set) == 1 and len(fixed.obj_set) == 1:
                if move.obj_set[0] is fixed.obj_set[0]:
                    raise RejectSampling("cannot move an object next to itself")

        # No early return above: a PutNextInstr that is also an ActionInstr
        # continues into the key/locked-door check below.
        if isinstance(instr, ActionInstr):
            if not getattr(self, "unblocking", False):
                return

            # Only gathered here, on the single path that actually needs it.
            colors_of_locked_doors = self._locked_door_colors()

            # TODO: either relax this a bit or make the bot handle this super corner-y scenarios
            # Check that the instruction doesn't involve a key that matches the color of a locked door
            for attr in ("desc", "desc_move", "desc_fixed"):
                if hasattr(instr, attr):
                    obj = getattr(instr, attr)
                    if obj.type == "key" and obj.color in colors_of_locked_doors:
                        raise RejectSampling(
                            "cannot do anything with/to a key that can be used to open a door"
                        )
            return

        if isinstance(instr, SeqInstr):
            self.validate_instrs(instr.instr_a)
            self.validate_instrs(instr.instr_b)
            return

        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError("unhandled instruction type")

    def gen_mission(self):
        """
        Generate a mission (instructions and matching environment)
        Derived level classes should implement this method
        """
        raise NotImplementedError

    @property
    def level_name(self):
        """Return the ``level_name`` attribute looked up on the class."""
        return self.__class__.level_name

    @property
    def gym_id(self):
        """Return the ``gym_id`` attribute looked up on the class."""
        return self.__class__.gym_id

    def num_navs_needed(self, instr) -> int:
        """
        Compute the maximum number of navigations needed to perform
        a simple or complex instruction

        Returns 2 for a PutNextInstr, 1 for any other ActionInstr, and the
        sum over both sub-instructions for a SeqInstr.

        Raises:
            NotImplementedError: if ``instr`` is none of those types.
        """
        # PutNextInstr is tested first so it is not caught by the
        # ActionInstr branch.
        if isinstance(instr, PutNextInstr):
            return 2

        if isinstance(instr, ActionInstr):
            return 1

        if isinstance(instr, SeqInstr):
            na = self.num_navs_needed(instr.instr_a)
            nb = self.num_navs_needed(instr.instr_b)
            return na + nb

        raise NotImplementedError(
            "instr needs to be an instance of PutNextInstr, ActionInstr, or SeqInstr"
        )

    def open_all_doors(self):
        """
        Open all the doors in the maze
        """
        for i in range(self.num_cols):
            for j in range(self.num_rows):
                room = self.get_room(i, j)
                for door in room.doors:
                    if door:
                        door.is_open = True

    def check_objs_reachable(self, raise_exc=True):
        """
        Check that all objects are reachable from the agent's starting
        position without requiring any other object to be moved
        (without unblocking)

        Args:
            raise_exc: if true, raise on the first unreachable object;
                otherwise return ``False`` instead.

        Returns:
            ``True`` when every non-wall cell content is reachable.

        Raises:
            RejectSampling: if an object is unreachable and ``raise_exc``
                is true.
        """
        # Reachable positions
        reachable = set()

        # Work list
        stack = [self.agent_pos]

        while stack:
            i, j = stack.pop()

            if i < 0 or i >= self.grid.width or j < 0 or j >= self.grid.height:
                continue

            if (i, j) in reachable:
                continue

            # This position is reachable
            reachable.add((i, j))

            cell = self.grid.get(i, j)

            # If there is something other than a door in this cell, it
            # blocks reachability
            if cell and cell.type != "door":
                continue

            # Visit the horizontal and vertical neighbors
            stack.append((i + 1, j))
            stack.append((i - 1, j))
            stack.append((i, j + 1))
            stack.append((i, j - 1))

        # Check that all objects are reachable
        for i in range(self.grid.width):
            for j in range(self.grid.height):
                cell = self.grid.get(i, j)

                if not cell or cell.type == "wall":
                    continue

                if (i, j) not in reachable:
                    if not raise_exc:
                        return False
                    raise RejectSampling("unreachable object at " + str((i, j)))

        # All objects reachable
        return True
|