1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- from gym_minigrid.minigrid import *
- from gym_minigrid.register import register
class DistShiftEnv(MiniGridEnv):
    """
    Distributional shift environment.

    The grid is enclosed by walls. The goal sits in the top-right interior
    cell ``(width - 2, 1)``. Two horizontal lava strips cover columns
    ``3`` through ``width - 4`` (inclusive): one on row 1 and one on row
    ``strip2_row``. Variants of this environment differ only in
    ``strip2_row``, which shifts the layout the agent has to cross.

    Args:
        width: Grid width in cells, including the outer walls.
        height: Grid height in cells, including the outer walls.
        agent_start_pos: ``(x, y)`` start cell for the agent, or ``None`` to
            delegate placement to ``self.place_agent()``.
        agent_start_dir: Initial agent direction (encoding defined by
            ``MiniGridEnv``). Only used when ``agent_start_pos`` is not
            ``None``.
        strip2_row: Row index of the second lava strip.
    """

    def __init__(
        self,
        width=9,
        height=7,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        strip2_row=2,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        # Goal is the top-right interior cell (x = width - 2, y = 1).
        self.goal_pos = (width - 2, 1)
        self.strip2_row = strip2_row

        # These attributes are set before calling the base constructor
        # because _gen_grid reads them.
        super().__init__(
            width=width,
            height=height,
            # Step budget handed to MiniGridEnv, proportional to grid area.
            max_steps=4 * width * height,
            # Set this to True for maximum speed
            see_through_walls=True,
        )

    def _gen_grid(self, width, height):
        """Build the grid: outer walls, goal, two lava strips, then the agent."""
        # Create an empty grid and surround it with walls.
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Place the goal square in the top-right interior corner.
        self.put_obj(Goal(), *self.goal_pos)

        # Place the lava strips on row 1 and on row ``strip2_row``.
        # They cover columns 3 .. width - 4, so interior columns 1-2 on the
        # left and width-3 .. width-2 on the right (where the goal is) stay
        # free. For width < 7 the range is empty and no lava is placed.
        for x in range(3, width - 3):
            self.grid.set(x, 1, Lava())
            self.grid.set(x, self.strip2_row, Lava())

        # Place the agent at the fixed start, or let the base class pick a cell.
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "get to the green goal square"
class DistShift1(DistShiftEnv):
    """DistShiftEnv variant with the second lava strip on row 2.

    With the other parameters at their defaults, the two strips are
    adjacent (rows 1 and 2).
    """

    def __init__(self):
        super(DistShift1, self).__init__(strip2_row=2)
class DistShift2(DistShiftEnv):
    """DistShiftEnv variant with the second lava strip on row 5.

    With the default height of 7, row 5 is the last interior row above the
    bottom wall.
    """

    def __init__(self):
        super(DistShift2, self).__init__(strip2_row=5)
# Register both distributional-shift variants under their public IDs.
for _env_id, _entry_point in (
    ('MiniGrid-DistShift1-v0', 'gym_minigrid.envs:DistShift1'),
    ('MiniGrid-DistShift2-v0', 'gym_minigrid.envs:DistShift2'),
):
    register(id=_env_id, entry_point=_entry_point)
# Drop the loop variables so the module namespace is unchanged.
del _env_id, _entry_point
|