# memory.py (3.6 KB)

import numpy as np

from gym_minigrid.minigrid import Ball, Grid, Key, MiniGridEnv, MissionSpace, Wall
  3. class MemoryEnv(MiniGridEnv):
  4. """
  5. This environment is a memory test. The agent starts in a small room
  6. where it sees an object. It then has to go through a narrow hallway
  7. which ends in a split. At each end of the split there is an object,
  8. one of which is the same as the object in the starting room. The
  9. agent has to remember the initial object, and go to the matching
  10. object at split.
  11. """
  12. def __init__(self, size=8, random_length=False, **kwargs):
  13. self.size = size
  14. self.random_length = random_length
  15. mission_space = MissionSpace(
  16. mission_func=lambda: "go to the matching object at the end of the hallway"
  17. )
  18. super().__init__(
  19. mission_space=mission_space,
  20. width=size,
  21. height=size,
  22. max_steps=5 * size**2,
  23. # Set this to True for maximum speed
  24. see_through_walls=False,
  25. **kwargs
  26. )
  27. def _gen_grid(self, width, height):
  28. self.grid = Grid(width, height)
  29. # Generate the surrounding walls
  30. self.grid.horz_wall(0, 0)
  31. self.grid.horz_wall(0, height - 1)
  32. self.grid.vert_wall(0, 0)
  33. self.grid.vert_wall(width - 1, 0)
  34. assert height % 2 == 1
  35. upper_room_wall = height // 2 - 2
  36. lower_room_wall = height // 2 + 2
  37. if self.random_length:
  38. hallway_end = self._rand_int(4, width - 2)
  39. else:
  40. hallway_end = width - 3
  41. # Start room
  42. for i in range(1, 5):
  43. self.grid.set(i, upper_room_wall, Wall())
  44. self.grid.set(i, lower_room_wall, Wall())
  45. self.grid.set(4, upper_room_wall + 1, Wall())
  46. self.grid.set(4, lower_room_wall - 1, Wall())
  47. # Horizontal hallway
  48. for i in range(5, hallway_end):
  49. self.grid.set(i, upper_room_wall + 1, Wall())
  50. self.grid.set(i, lower_room_wall - 1, Wall())
  51. # Vertical hallway
  52. for j in range(0, height):
  53. if j != height // 2:
  54. self.grid.set(hallway_end, j, Wall())
  55. self.grid.set(hallway_end + 2, j, Wall())
  56. # Fix the player's start position and orientation
  57. self.agent_pos = np.array((self._rand_int(1, hallway_end + 1), height // 2))
  58. self.agent_dir = 0
  59. # Place objects
  60. start_room_obj = self._rand_elem([Key, Ball])
  61. self.grid.set(1, height // 2 - 1, start_room_obj("green"))
  62. other_objs = self._rand_elem([[Ball, Key], [Key, Ball]])
  63. pos0 = (hallway_end + 1, height // 2 - 2)
  64. pos1 = (hallway_end + 1, height // 2 + 2)
  65. self.grid.set(*pos0, other_objs[0]("green"))
  66. self.grid.set(*pos1, other_objs[1]("green"))
  67. # Choose the target objects
  68. if start_room_obj == other_objs[0]:
  69. self.success_pos = (pos0[0], pos0[1] + 1)
  70. self.failure_pos = (pos1[0], pos1[1] - 1)
  71. else:
  72. self.success_pos = (pos1[0], pos1[1] - 1)
  73. self.failure_pos = (pos0[0], pos0[1] + 1)
  74. self.mission = "go to the matching object at the end of the hallway"
  75. def step(self, action):
  76. if action == self.Actions.pickup:
  77. action = self.Actions.toggle
  78. obs, reward, terminated, truncated, info = super().step(action)
  79. if tuple(self.agent_pos) == self.success_pos:
  80. reward = self._reward()
  81. terminated = True
  82. if tuple(self.agent_pos) == self.failure_pos:
  83. reward = 0
  84. terminated = True
  85. return obs, reward, terminated, truncated, info