memory.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import numpy as np
  2. from gym_minigrid.minigrid import Ball, Grid, Key, MiniGridEnv, MissionSpace, Wall
  3. class MemoryEnv(MiniGridEnv):
  4. """
  5. ### Description
  6. This environment is a memory test. The agent starts in a small room where it
  7. sees an object. It then has to go through a narrow hallway which ends in a
  8. split. At each end of the split there is an object, one of which is the same
  9. as the object in the starting room. The agent has to remember the initial
  10. object, and go to the matching object at split.
  11. ### Mission Space
  12. "go to the matching object at the end of the hallway"
  13. ### Action Space
  14. | Num | Name | Action |
  15. |-----|--------------|---------------------------|
  16. | 0 | left | Turn left |
  17. | 1 | right | Turn right |
  18. | 2 | forward | Move forward |
  19. | 3 | pickup | Pick up an object |
  20. | 4 | drop | Unused |
  21. | 5 | toggle | Toggle/activate an object |
  22. | 6 | done | Unused |
  23. ### Observation Encoding
  24. - Each tile is encoded as a 3 dimensional tuple:
  25. `(OBJECT_IDX, COLOR_IDX, STATE)`
  26. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  27. [gym_minigrid/minigrid.py](gym_minigrid/minigrid.py)
  28. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
  29. ### Rewards
  30. A reward of '1' is given for success, and '0' for failure.
  31. ### Termination
  32. The episode ends if any one of the following conditions is met:
  33. 1. The agent reaches the correct matching object.
  34. 2. The agent reaches the wrong matching object.
  35. 3. Timeout (see `max_steps`).
  36. ### Registered Configurations
  37. S: size of map SxS.
  38. - `MiniGrid-MemoryS17Random-v0`
  39. - `MiniGrid-MemoryS13Random-v0`
  40. - `MiniGrid-MemoryS13-v0`
  41. - `MiniGrid-MemoryS11-v0`
  42. """
  43. def __init__(self, size=8, random_length=False, **kwargs):
  44. self.size = size
  45. self.random_length = random_length
  46. mission_space = MissionSpace(
  47. mission_func=lambda: "go to the matching object at the end of the hallway"
  48. )
  49. super().__init__(
  50. mission_space=mission_space,
  51. width=size,
  52. height=size,
  53. max_steps=5 * size**2,
  54. # Set this to True for maximum speed
  55. see_through_walls=False,
  56. **kwargs
  57. )
  58. def _gen_grid(self, width, height):
  59. self.grid = Grid(width, height)
  60. # Generate the surrounding walls
  61. self.grid.horz_wall(0, 0)
  62. self.grid.horz_wall(0, height - 1)
  63. self.grid.vert_wall(0, 0)
  64. self.grid.vert_wall(width - 1, 0)
  65. assert height % 2 == 1
  66. upper_room_wall = height // 2 - 2
  67. lower_room_wall = height // 2 + 2
  68. if self.random_length:
  69. hallway_end = self._rand_int(4, width - 2)
  70. else:
  71. hallway_end = width - 3
  72. # Start room
  73. for i in range(1, 5):
  74. self.grid.set(i, upper_room_wall, Wall())
  75. self.grid.set(i, lower_room_wall, Wall())
  76. self.grid.set(4, upper_room_wall + 1, Wall())
  77. self.grid.set(4, lower_room_wall - 1, Wall())
  78. # Horizontal hallway
  79. for i in range(5, hallway_end):
  80. self.grid.set(i, upper_room_wall + 1, Wall())
  81. self.grid.set(i, lower_room_wall - 1, Wall())
  82. # Vertical hallway
  83. for j in range(0, height):
  84. if j != height // 2:
  85. self.grid.set(hallway_end, j, Wall())
  86. self.grid.set(hallway_end + 2, j, Wall())
  87. # Fix the player's start position and orientation
  88. self.agent_pos = np.array((self._rand_int(1, hallway_end + 1), height // 2))
  89. self.agent_dir = 0
  90. # Place objects
  91. start_room_obj = self._rand_elem([Key, Ball])
  92. self.grid.set(1, height // 2 - 1, start_room_obj("green"))
  93. other_objs = self._rand_elem([[Ball, Key], [Key, Ball]])
  94. pos0 = (hallway_end + 1, height // 2 - 2)
  95. pos1 = (hallway_end + 1, height // 2 + 2)
  96. self.grid.set(*pos0, other_objs[0]("green"))
  97. self.grid.set(*pos1, other_objs[1]("green"))
  98. # Choose the target objects
  99. if start_room_obj == other_objs[0]:
  100. self.success_pos = (pos0[0], pos0[1] + 1)
  101. self.failure_pos = (pos1[0], pos1[1] - 1)
  102. else:
  103. self.success_pos = (pos1[0], pos1[1] - 1)
  104. self.failure_pos = (pos0[0], pos0[1] + 1)
  105. self.mission = "go to the matching object at the end of the hallway"
  106. def step(self, action):
  107. if action == self.Actions.pickup:
  108. action = self.Actions.toggle
  109. obs, reward, terminated, truncated, info = super().step(action)
  110. if tuple(self.agent_pos) == self.success_pos:
  111. reward = self._reward()
  112. terminated = True
  113. if tuple(self.agent_pos) == self.failure_pos:
  114. reward = 0
  115. terminated = True
  116. return obs, reward, terminated, truncated, info