memory.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import numpy as np
  2. from gym_minigrid.minigrid import Ball, Grid, Key, MiniGridEnv, Wall
  3. from gym_minigrid.register import register
  4. class MemoryEnv(MiniGridEnv):
  5. """
  6. This environment is a memory test. The agent starts in a small room
  7. where it sees an object. It then has to go through a narrow hallway
  8. which ends in a split. At each end of the split there is an object,
  9. one of which is the same as the object in the starting room. The
  10. agent has to remember the initial object, and go to the matching
  11. object at split.
  12. """
  13. def __init__(self, size=8, random_length=False, **kwargs):
  14. self.random_length = random_length
  15. super().__init__(
  16. grid_size=size,
  17. max_steps=5 * size**2,
  18. # Set this to True for maximum speed
  19. see_through_walls=False,
  20. **kwargs
  21. )
  22. def _gen_grid(self, width, height):
  23. self.grid = Grid(width, height)
  24. # Generate the surrounding walls
  25. self.grid.horz_wall(0, 0)
  26. self.grid.horz_wall(0, height - 1)
  27. self.grid.vert_wall(0, 0)
  28. self.grid.vert_wall(width - 1, 0)
  29. assert height % 2 == 1
  30. upper_room_wall = height // 2 - 2
  31. lower_room_wall = height // 2 + 2
  32. if self.random_length:
  33. hallway_end = self._rand_int(4, width - 2)
  34. else:
  35. hallway_end = width - 3
  36. # Start room
  37. for i in range(1, 5):
  38. self.grid.set(i, upper_room_wall, Wall())
  39. self.grid.set(i, lower_room_wall, Wall())
  40. self.grid.set(4, upper_room_wall + 1, Wall())
  41. self.grid.set(4, lower_room_wall - 1, Wall())
  42. # Horizontal hallway
  43. for i in range(5, hallway_end):
  44. self.grid.set(i, upper_room_wall + 1, Wall())
  45. self.grid.set(i, lower_room_wall - 1, Wall())
  46. # Vertical hallway
  47. for j in range(0, height):
  48. if j != height // 2:
  49. self.grid.set(hallway_end, j, Wall())
  50. self.grid.set(hallway_end + 2, j, Wall())
  51. # Fix the player's start position and orientation
  52. self.agent_pos = np.array((self._rand_int(1, hallway_end + 1), height // 2))
  53. self.agent_dir = 0
  54. # Place objects
  55. start_room_obj = self._rand_elem([Key, Ball])
  56. self.grid.set(1, height // 2 - 1, start_room_obj("green"))
  57. other_objs = self._rand_elem([[Ball, Key], [Key, Ball]])
  58. pos0 = (hallway_end + 1, height // 2 - 2)
  59. pos1 = (hallway_end + 1, height // 2 + 2)
  60. self.grid.set(*pos0, other_objs[0]("green"))
  61. self.grid.set(*pos1, other_objs[1]("green"))
  62. # Choose the target objects
  63. if start_room_obj == other_objs[0]:
  64. self.success_pos = (pos0[0], pos0[1] + 1)
  65. self.failure_pos = (pos1[0], pos1[1] - 1)
  66. else:
  67. self.success_pos = (pos1[0], pos1[1] - 1)
  68. self.failure_pos = (pos0[0], pos0[1] + 1)
  69. self.mission = "go to the matching object at the end of the hallway"
  70. def step(self, action):
  71. if action == MiniGridEnv.Actions.pickup:
  72. action = MiniGridEnv.Actions.toggle
  73. obs, reward, done, info = MiniGridEnv.step(self, action)
  74. if tuple(self.agent_pos) == self.success_pos:
  75. reward = self._reward()
  76. done = True
  77. if tuple(self.agent_pos) == self.failure_pos:
  78. reward = 0
  79. done = True
  80. return obs, reward, done, info
  81. register(
  82. id="MiniGrid-MemoryS17Random-v0",
  83. entry_point="gym_minigrid.envs.memory:MemoryEnv",
  84. size=17,
  85. random_length=True,
  86. )
  87. register(
  88. id="MiniGrid-MemoryS13Random-v0",
  89. entry_point="gym_minigrid.envs.memory:MemoryEnv",
  90. size=13,
  91. random_length=True,
  92. )
  93. register(
  94. id="MiniGrid-MemoryS13-v0",
  95. entry_point="gym_minigrid.envs.memory:MemoryEnv",
  96. size=13,
  97. )
  98. register(
  99. id="MiniGrid-MemoryS11-v0",
  100. entry_point="gym_minigrid.envs.memory:MemoryEnv",
  101. size=11,
  102. )
  103. register(
  104. id="MiniGrid-MemoryS9-v0", entry_point="gym_minigrid.envs.memory:MemoryEnv", size=9
  105. )
  106. register(
  107. id="MiniGrid-MemoryS7-v0", entry_point="gym_minigrid.envs.memory:MemoryEnv", size=7
  108. )