distshift.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class DistShiftEnv(MiniGridEnv):
  4. """
  5. Distributional shift environment.
  6. """
  7. def __init__(
  8. self,
  9. width=9,
  10. height=7,
  11. agent_start_pos=(1,1),
  12. agent_start_dir=0,
  13. strip2_row=2
  14. ):
  15. self.agent_start_pos = agent_start_pos
  16. self.agent_start_dir = agent_start_dir
  17. self.goal_pos = (width-2, 1)
  18. self.strip2_row = strip2_row
  19. super().__init__(
  20. width=width,
  21. height=height,
  22. max_steps=4*width*height,
  23. # Set this to True for maximum speed
  24. see_through_walls=True
  25. )
  26. def _gen_grid(self, width, height):
  27. # Create an empty grid
  28. self.grid = Grid(width, height)
  29. # Generate the surrounding walls
  30. self.grid.wall_rect(0, 0, width, height)
  31. # Place a goal square in the bottom-right corner
  32. self.put_obj(Goal(), *self.goal_pos)
  33. # Place the lava rows
  34. for i in range(self.width - 6):
  35. self.grid.set(3+i, 1, Lava())
  36. self.grid.set(3+i, self.strip2_row, Lava())
  37. # Place the agent
  38. if self.agent_start_pos is not None:
  39. self.agent_pos = self.agent_start_pos
  40. self.agent_dir = self.agent_start_dir
  41. else:
  42. self.place_agent()
  43. self.mission = "get to the green goal square"
  44. class DistShift1(DistShiftEnv):
  45. def __init__(self):
  46. super().__init__(strip2_row=2)
  47. class DistShift2(DistShiftEnv):
  48. def __init__(self):
  49. super().__init__(strip2_row=5)
  50. register(
  51. id='MiniGrid-DistShift1-v0',
  52. entry_point='gym_minigrid.envs:DistShift1'
  53. )
  54. register(
  55. id='MiniGrid-DistShift2-v0',
  56. entry_point='gym_minigrid.envs:DistShift2'
  57. )