# memory.py (4.8 KB)
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class MemoryEnv(MiniGridEnv):
  4. """
  5. This environment is a memory test. The agent starts in a small room
  6. where it sees an object. It then has to go through a narrow hallway
  7. which ends in a split. At each end of the split there is an object,
  8. one of which is the same as the object in the starting room. The
  9. agent has to remember the initial object, and go to the matching
  10. object at split.
  11. """
  12. def __init__(
  13. self,
  14. seed,
  15. size=8,
  16. random_length=False,
  17. **kwargs
  18. ):
  19. self.random_length = random_length
  20. super().__init__(
  21. seed=seed,
  22. grid_size=size,
  23. max_steps=5*size**2,
  24. # Set this to True for maximum speed
  25. see_through_walls=False,
  26. **kwargs
  27. )
  28. def _gen_grid(self, width, height):
  29. self.grid = Grid(width, height)
  30. # Generate the surrounding walls
  31. self.grid.horz_wall(0, 0)
  32. self.grid.horz_wall(0, height-1)
  33. self.grid.vert_wall(0, 0)
  34. self.grid.vert_wall(width - 1, 0)
  35. assert height % 2 == 1
  36. upper_room_wall = height // 2 - 2
  37. lower_room_wall = height // 2 + 2
  38. if self.random_length:
  39. hallway_end = self._rand_int(4, width - 2)
  40. else:
  41. hallway_end = width - 3
  42. # Start room
  43. for i in range(1, 5):
  44. self.grid.set(i, upper_room_wall, Wall())
  45. self.grid.set(i, lower_room_wall, Wall())
  46. self.grid.set(4, upper_room_wall + 1, Wall())
  47. self.grid.set(4, lower_room_wall - 1, Wall())
  48. # Horizontal hallway
  49. for i in range(5, hallway_end):
  50. self.grid.set(i, upper_room_wall + 1, Wall())
  51. self.grid.set(i, lower_room_wall - 1, Wall())
  52. # Vertical hallway
  53. for j in range(0, height):
  54. if j != height // 2:
  55. self.grid.set(hallway_end, j, Wall())
  56. self.grid.set(hallway_end + 2, j, Wall())
  57. # Fix the player's start position and orientation
  58. self.agent_pos = (self._rand_int(1, hallway_end + 1), height // 2)
  59. self.agent_dir = 0
  60. # Place objects
  61. start_room_obj = self._rand_elem([Key, Ball])
  62. self.grid.set(1, height // 2 - 1, start_room_obj('green'))
  63. other_objs = self._rand_elem([[Ball, Key], [Key, Ball]])
  64. pos0 = (hallway_end + 1, height // 2 - 2)
  65. pos1 = (hallway_end + 1, height // 2 + 2)
  66. self.grid.set(*pos0, other_objs[0]('green'))
  67. self.grid.set(*pos1, other_objs[1]('green'))
  68. # Choose the target objects
  69. if start_room_obj == other_objs[0]:
  70. self.success_pos = (pos0[0], pos0[1] + 1)
  71. self.failure_pos = (pos1[0], pos1[1] - 1)
  72. else:
  73. self.success_pos = (pos1[0], pos1[1] - 1)
  74. self.failure_pos = (pos0[0], pos0[1] + 1)
  75. self.mission = 'go to the matching object at the end of the hallway'
  76. def step(self, action):
  77. if action == MiniGridEnv.Actions.pickup:
  78. action = MiniGridEnv.Actions.toggle
  79. obs, reward, done, info = MiniGridEnv.step(self, action)
  80. if tuple(self.agent_pos) == self.success_pos:
  81. reward = self._reward()
  82. done = True
  83. if tuple(self.agent_pos) == self.failure_pos:
  84. reward = 0
  85. done = True
  86. return obs, reward, done, info
  87. class MemoryS17Random(MemoryEnv):
  88. def __init__(self, seed=None, **kwargs):
  89. super().__init__(seed=seed, size=17, random_length=True, **kwargs)
  90. register(
  91. id='MiniGrid-MemoryS17Random-v0',
  92. entry_point='gym_minigrid.envs:MemoryS17Random',
  93. )
  94. class MemoryS13Random(MemoryEnv):
  95. def __init__(self, seed=None, **kwargs):
  96. super().__init__(seed=seed, size=13, random_length=True, **kwargs)
  97. register(
  98. id='MiniGrid-MemoryS13Random-v0',
  99. entry_point='gym_minigrid.envs:MemoryS13Random',
  100. )
  101. class MemoryS13(MemoryEnv):
  102. def __init__(self, seed=None, **kwargs):
  103. super().__init__(seed=seed, size=13, **kwargs)
  104. register(
  105. id='MiniGrid-MemoryS13-v0',
  106. entry_point='gym_minigrid.envs:MemoryS13',
  107. )
  108. class MemoryS11(MemoryEnv):
  109. def __init__(self, seed=None, **kwargs):
  110. super().__init__(seed=seed, size=11, **kwargs)
  111. register(
  112. id='MiniGrid-MemoryS11-v0',
  113. entry_point='gym_minigrid.envs:MemoryS11',
  114. )
  115. class MemoryS9(MemoryEnv):
  116. def __init__(self, seed=None, **kwargs):
  117. super().__init__(seed=seed, size=9, **kwargs)
  118. register(
  119. id='MiniGrid-MemoryS9-v0',
  120. entry_point='gym_minigrid.envs:MemoryS9',
  121. )
  122. class MemoryS7(MemoryEnv):
  123. def __init__(self, seed=None, **kwargs):
  124. super().__init__(seed=seed, size=7, **kwargs)
  125. register(
  126. id='MiniGrid-MemoryS7-v0',
  127. entry_point='gym_minigrid.envs:MemoryS7',
  128. )