fourrooms.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from gym_minigrid.minigrid import Goal, Grid, MiniGridEnv, MissionSpace
  2. class FourRoomsEnv(MiniGridEnv):
  3. """
  4. Classic 4 rooms gridworld environment.
  5. Can specify agent and goal position, if not it set at random.
  6. """
  7. def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
  8. self._agent_default_pos = agent_pos
  9. self._goal_default_pos = goal_pos
  10. self.size = 19
  11. mission_space = MissionSpace(mission_func=lambda: "reach the goal")
  12. super().__init__(
  13. mission_space=mission_space,
  14. width=self.size,
  15. height=self.size,
  16. max_steps=100,
  17. **kwargs
  18. )
  19. def _gen_grid(self, width, height):
  20. # Create the grid
  21. self.grid = Grid(width, height)
  22. # Generate the surrounding walls
  23. self.grid.horz_wall(0, 0)
  24. self.grid.horz_wall(0, height - 1)
  25. self.grid.vert_wall(0, 0)
  26. self.grid.vert_wall(width - 1, 0)
  27. room_w = width // 2
  28. room_h = height // 2
  29. # For each row of rooms
  30. for j in range(0, 2):
  31. # For each column
  32. for i in range(0, 2):
  33. xL = i * room_w
  34. yT = j * room_h
  35. xR = xL + room_w
  36. yB = yT + room_h
  37. # Bottom wall and door
  38. if i + 1 < 2:
  39. self.grid.vert_wall(xR, yT, room_h)
  40. pos = (xR, self._rand_int(yT + 1, yB))
  41. self.grid.set(*pos, None)
  42. # Bottom wall and door
  43. if j + 1 < 2:
  44. self.grid.horz_wall(xL, yB, room_w)
  45. pos = (self._rand_int(xL + 1, xR), yB)
  46. self.grid.set(*pos, None)
  47. # Randomize the player start position and orientation
  48. if self._agent_default_pos is not None:
  49. self.agent_pos = self._agent_default_pos
  50. self.grid.set(*self._agent_default_pos, None)
  51. # assuming random start direction
  52. self.agent_dir = self._rand_int(0, 4)
  53. else:
  54. self.place_agent()
  55. if self._goal_default_pos is not None:
  56. goal = Goal()
  57. self.put_obj(goal, *self._goal_default_pos)
  58. goal.init_pos, goal.cur_pos = self._goal_default_pos
  59. else:
  60. self.place_obj(Goal())