memory.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import numpy as np
  2. from minigrid.core.actions import Actions
  3. from minigrid.core.grid import Grid
  4. from minigrid.core.mission import MissionSpace
  5. from minigrid.core.world_object import Ball, Key, Wall
  6. from minigrid.minigrid import MiniGridEnv
  7. class MemoryEnv(MiniGridEnv):
  8. """
  9. ### Description
  10. This environment is a memory test. The agent starts in a small room where it
  11. sees an object. It then has to go through a narrow hallway which ends in a
  12. split. At each end of the split there is an object, one of which is the same
  13. as the object in the starting room. The agent has to remember the initial
  14. object, and go to the matching object at split.
  15. ### Mission Space
  16. "go to the matching object at the end of the hallway"
  17. ### Action Space
  18. | Num | Name | Action |
  19. |-----|--------------|---------------------------|
  20. | 0 | left | Turn left |
  21. | 1 | right | Turn right |
  22. | 2 | forward | Move forward |
  23. | 3 | pickup | Pick up an object |
  24. | 4 | drop | Unused |
  25. | 5 | toggle | Toggle/activate an object |
  26. | 6 | done | Unused |
  27. ### Observation Encoding
  28. - Each tile is encoded as a 3 dimensional tuple:
  29. `(OBJECT_IDX, COLOR_IDX, STATE)`
  30. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  31. [minigrid/minigrid.py](minigrid/minigrid.py)
  32. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
  33. ### Rewards
  34. A reward of '1' is given for success, and '0' for failure.
  35. ### Termination
  36. The episode ends if any one of the following conditions is met:
  37. 1. The agent reaches the correct matching object.
  38. 2. The agent reaches the wrong matching object.
  39. 3. Timeout (see `max_steps`).
  40. ### Registered Configurations
  41. S: size of map SxS.
  42. - `MiniGrid-MemoryS17Random-v0`
  43. - `MiniGrid-MemoryS13Random-v0`
  44. - `MiniGrid-MemoryS13-v0`
  45. - `MiniGrid-MemoryS11-v0`
  46. """
  47. def __init__(self, size=8, random_length=False, **kwargs):
  48. self.size = size
  49. self.random_length = random_length
  50. mission_space = MissionSpace(
  51. mission_func=lambda: "go to the matching object at the end of the hallway"
  52. )
  53. super().__init__(
  54. mission_space=mission_space,
  55. width=size,
  56. height=size,
  57. max_steps=5 * size**2,
  58. # Set this to True for maximum speed
  59. see_through_walls=False,
  60. **kwargs
  61. )
  62. def _gen_grid(self, width, height):
  63. self.grid = Grid(width, height)
  64. # Generate the surrounding walls
  65. self.grid.horz_wall(0, 0)
  66. self.grid.horz_wall(0, height - 1)
  67. self.grid.vert_wall(0, 0)
  68. self.grid.vert_wall(width - 1, 0)
  69. assert height % 2 == 1
  70. upper_room_wall = height // 2 - 2
  71. lower_room_wall = height // 2 + 2
  72. if self.random_length:
  73. hallway_end = self._rand_int(4, width - 2)
  74. else:
  75. hallway_end = width - 3
  76. # Start room
  77. for i in range(1, 5):
  78. self.grid.set(i, upper_room_wall, Wall())
  79. self.grid.set(i, lower_room_wall, Wall())
  80. self.grid.set(4, upper_room_wall + 1, Wall())
  81. self.grid.set(4, lower_room_wall - 1, Wall())
  82. # Horizontal hallway
  83. for i in range(5, hallway_end):
  84. self.grid.set(i, upper_room_wall + 1, Wall())
  85. self.grid.set(i, lower_room_wall - 1, Wall())
  86. # Vertical hallway
  87. for j in range(0, height):
  88. if j != height // 2:
  89. self.grid.set(hallway_end, j, Wall())
  90. self.grid.set(hallway_end + 2, j, Wall())
  91. # Fix the player's start position and orientation
  92. self.agent_pos = np.array((self._rand_int(1, hallway_end + 1), height // 2))
  93. self.agent_dir = 0
  94. # Place objects
  95. start_room_obj = self._rand_elem([Key, Ball])
  96. self.grid.set(1, height // 2 - 1, start_room_obj("green"))
  97. other_objs = self._rand_elem([[Ball, Key], [Key, Ball]])
  98. pos0 = (hallway_end + 1, height // 2 - 2)
  99. pos1 = (hallway_end + 1, height // 2 + 2)
  100. self.grid.set(*pos0, other_objs[0]("green"))
  101. self.grid.set(*pos1, other_objs[1]("green"))
  102. # Choose the target objects
  103. if start_room_obj == other_objs[0]:
  104. self.success_pos = (pos0[0], pos0[1] + 1)
  105. self.failure_pos = (pos1[0], pos1[1] - 1)
  106. else:
  107. self.success_pos = (pos1[0], pos1[1] - 1)
  108. self.failure_pos = (pos0[0], pos0[1] + 1)
  109. self.mission = "go to the matching object at the end of the hallway"
  110. def step(self, action):
  111. if action == Actions.pickup:
  112. action = Actions.toggle
  113. obs, reward, terminated, truncated, info = super().step(action)
  114. if tuple(self.agent_pos) == self.success_pos:
  115. reward = self._reward()
  116. terminated = True
  117. if tuple(self.agent_pos) == self.failure_pos:
  118. reward = 0
  119. terminated = True
  120. return obs, reward, terminated, truncated, info