distshift.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from gym_minigrid.minigrid import Goal, Grid, Lava, MiniGridEnv
  2. from gym_minigrid.register import register
  3. class DistShiftEnv(MiniGridEnv):
  4. """
  5. Distributional shift environment.
  6. """
  7. def __init__(
  8. self, width=9, height=7, agent_start_pos=(1, 1), agent_start_dir=0, strip2_row=2
  9. ):
  10. self.agent_start_pos = agent_start_pos
  11. self.agent_start_dir = agent_start_dir
  12. self.goal_pos = (width - 2, 1)
  13. self.strip2_row = strip2_row
  14. super().__init__(
  15. width=width,
  16. height=height,
  17. max_steps=4 * width * height,
  18. # Set this to True for maximum speed
  19. see_through_walls=True,
  20. )
  21. def _gen_grid(self, width, height):
  22. # Create an empty grid
  23. self.grid = Grid(width, height)
  24. # Generate the surrounding walls
  25. self.grid.wall_rect(0, 0, width, height)
  26. # Place a goal square in the bottom-right corner
  27. self.put_obj(Goal(), *self.goal_pos)
  28. # Place the lava rows
  29. for i in range(self.width - 6):
  30. self.grid.set(3 + i, 1, Lava())
  31. self.grid.set(3 + i, self.strip2_row, Lava())
  32. # Place the agent
  33. if self.agent_start_pos is not None:
  34. self.agent_pos = self.agent_start_pos
  35. self.agent_dir = self.agent_start_dir
  36. else:
  37. self.place_agent()
  38. self.mission = "get to the green goal square"
  39. class DistShift1(DistShiftEnv):
  40. def __init__(self):
  41. super().__init__(strip2_row=2)
  42. class DistShift2(DistShiftEnv):
  43. def __init__(self):
  44. super().__init__(strip2_row=5)
  45. register(id="MiniGrid-DistShift1-v0", entry_point="gym_minigrid.envs:DistShift1")
  46. register(id="MiniGrid-DistShift2-v0", entry_point="gym_minigrid.envs:DistShift2")