keycorridor.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from gym_minigrid.register import register
  2. from gym_minigrid.roomgrid import RoomGrid
  3. class KeyCorridor(RoomGrid):
  4. """
  5. A ball is behind a locked door, the key is placed in a
  6. random room.
  7. """
  8. def __init__(self, num_rows=3, obj_type="ball", room_size=6, seed=None):
  9. self.obj_type = obj_type
  10. super().__init__(
  11. room_size=room_size,
  12. num_rows=num_rows,
  13. max_steps=30 * room_size**2,
  14. seed=seed,
  15. )
  16. def _gen_grid(self, width, height):
  17. super()._gen_grid(width, height)
  18. # Connect the middle column rooms into a hallway
  19. for j in range(1, self.num_rows):
  20. self.remove_wall(1, j, 3)
  21. # Add a locked door on the bottom right
  22. # Add an object behind the locked door
  23. room_idx = self._rand_int(0, self.num_rows)
  24. door, _ = self.add_door(2, room_idx, 2, locked=True)
  25. obj, _ = self.add_object(2, room_idx, kind=self.obj_type)
  26. # Add a key in a random room on the left side
  27. self.add_object(0, self._rand_int(0, self.num_rows), "key", door.color)
  28. # Place the agent in the middle
  29. self.place_agent(1, self.num_rows // 2)
  30. # Make sure all rooms are accessible
  31. self.connect_all()
  32. self.obj = obj
  33. self.mission = f"pick up the {obj.color} {obj.type}"
  34. def step(self, action):
  35. obs, reward, done, info = super().step(action)
  36. if action == self.actions.pickup:
  37. if self.carrying and self.carrying == self.obj:
  38. reward = self._reward()
  39. done = True
  40. return obs, reward, done, info
  41. class KeyCorridorS3R1(KeyCorridor):
  42. def __init__(self, seed=None):
  43. super().__init__(room_size=3, num_rows=1, seed=seed)
  44. class KeyCorridorS3R2(KeyCorridor):
  45. def __init__(self, seed=None):
  46. super().__init__(room_size=3, num_rows=2, seed=seed)
  47. class KeyCorridorS3R3(KeyCorridor):
  48. def __init__(self, seed=None):
  49. super().__init__(room_size=3, num_rows=3, seed=seed)
  50. class KeyCorridorS4R3(KeyCorridor):
  51. def __init__(self, seed=None):
  52. super().__init__(room_size=4, num_rows=3, seed=seed)
  53. class KeyCorridorS5R3(KeyCorridor):
  54. def __init__(self, seed=None):
  55. super().__init__(room_size=5, num_rows=3, seed=seed)
  56. class KeyCorridorS6R3(KeyCorridor):
  57. def __init__(self, seed=None):
  58. super().__init__(room_size=6, num_rows=3, seed=seed)
  59. register(
  60. id="MiniGrid-KeyCorridorS3R1-v0", entry_point="gym_minigrid.envs:KeyCorridorS3R1"
  61. )
  62. register(
  63. id="MiniGrid-KeyCorridorS3R2-v0", entry_point="gym_minigrid.envs:KeyCorridorS3R2"
  64. )
  65. register(
  66. id="MiniGrid-KeyCorridorS3R3-v0", entry_point="gym_minigrid.envs:KeyCorridorS3R3"
  67. )
  68. register(
  69. id="MiniGrid-KeyCorridorS4R3-v0", entry_point="gym_minigrid.envs:KeyCorridorS4R3"
  70. )
  71. register(
  72. id="MiniGrid-KeyCorridorS5R3-v0", entry_point="gym_minigrid.envs:KeyCorridorS5R3"
  73. )
  74. register(
  75. id="MiniGrid-KeyCorridorS6R3-v0", entry_point="gym_minigrid.envs:KeyCorridorS6R3"
  76. )