distshift.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from gym_minigrid.minigrid import Goal, Grid, Lava, MiniGridEnv
  2. class DistShiftEnv(MiniGridEnv):
  3. """
  4. Distributional shift environment.
  5. """
  6. def __init__(
  7. self,
  8. width=9,
  9. height=7,
  10. agent_start_pos=(1, 1),
  11. agent_start_dir=0,
  12. strip2_row=2,
  13. **kwargs
  14. ):
  15. self.agent_start_pos = agent_start_pos
  16. self.agent_start_dir = agent_start_dir
  17. self.goal_pos = (width - 2, 1)
  18. self.strip2_row = strip2_row
  19. super().__init__(
  20. width=width,
  21. height=height,
  22. max_steps=4 * width * height,
  23. # Set this to True for maximum speed
  24. see_through_walls=True,
  25. **kwargs
  26. )
  27. def _gen_grid(self, width, height):
  28. # Create an empty grid
  29. self.grid = Grid(width, height)
  30. # Generate the surrounding walls
  31. self.grid.wall_rect(0, 0, width, height)
  32. # Place a goal square in the bottom-right corner
  33. self.put_obj(Goal(), *self.goal_pos)
  34. # Place the lava rows
  35. for i in range(self.width - 6):
  36. self.grid.set(3 + i, 1, Lava())
  37. self.grid.set(3 + i, self.strip2_row, Lava())
  38. # Place the agent
  39. if self.agent_start_pos is not None:
  40. self.agent_pos = self.agent_start_pos
  41. self.agent_dir = self.agent_start_dir
  42. else:
  43. self.place_agent()
  44. self.mission = "get to the green goal square"