from gym_minigrid.minigrid import *
from gym_minigrid.register import register


class DistShiftEnv(MiniGridEnv):
    """
    Distributional shift environment.

    A walled grid with a goal square and two horizontal strips of lava.
    The first lava strip is fixed on row 1 (the first interior row, same
    row as the goal); the row of the second strip is configurable. The
    registered variants below differ only in ``strip2_row``.

    Args:
        width: Grid width, including the one-cell surrounding wall.
        height: Grid height, including the one-cell surrounding wall.
        agent_start_pos: Fixed (x, y) start cell for the agent, or None to
            let ``place_agent()`` choose the start position.
        agent_start_dir: Initial agent direction; only used when
            ``agent_start_pos`` is not None.
        strip2_row: Row index of the second lava strip.
    """

    def __init__(
        self,
        width=9,
        height=7,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        strip2_row=2,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        # Last interior column of the first interior row.
        self.goal_pos = (width - 2, 1)
        self.strip2_row = strip2_row

        super().__init__(
            width=width,
            height=height,
            # Episode step budget scales with the grid area.
            max_steps=4 * width * height,
            # Set this to True for maximum speed
            see_through_walls=True,
        )

    def _gen_grid(self, width, height):
        # Create an empty grid
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Place the goal square in the last interior column of row 1,
        # i.e. at the far end of the row holding the first lava strip.
        self.put_obj(Goal(), *self.goal_pos)

        # Place the two lava strips. Both span columns 3 .. width-4; the
        # first is fixed on row 1, the second sits on ``strip2_row``.
        # Uses the ``width`` argument (like the grid and walls above) so
        # the lava always fits inside the grid that was just created.
        for x in range(3, width - 3):
            self.grid.set(x, 1, Lava())
            self.grid.set(x, self.strip2_row, Lava())

        # Place the agent at the fixed start cell if one was given,
        # otherwise let place_agent() pick a position.
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "get to the green goal square"


class DistShift1(DistShiftEnv):
    """Default layout with the second lava strip on row 2, adjacent to the first."""

    def __init__(self):
        super().__init__(strip2_row=2)


class DistShift2(DistShiftEnv):
    """Default layout with the second lava strip on row 5 (the last interior
    row for the default height of 7)."""

    def __init__(self):
        super().__init__(strip2_row=5)


register(
    id='MiniGrid-DistShift1-v0',
    entry_point='gym_minigrid.envs:DistShift1'
)

register(
    id='MiniGrid-DistShift2-v0',
    entry_point='gym_minigrid.envs:DistShift2'
)