# keycorridor.py

  1. from gym_minigrid.register import register
  2. from gym_minigrid.roomgrid import RoomGrid
  3. class KeyCorridorEnv(RoomGrid):
  4. """
  5. A ball is behind a locked door, the key is placed in a
  6. random room.
  7. """
  8. def __init__(self, num_rows=3, obj_type="ball", room_size=6, **kwargs):
  9. self.obj_type = obj_type
  10. super().__init__(
  11. room_size=room_size,
  12. num_rows=num_rows,
  13. max_steps=30 * room_size**2,
  14. **kwargs,
  15. )
  16. def _gen_grid(self, width, height):
  17. super()._gen_grid(width, height)
  18. # Connect the middle column rooms into a hallway
  19. for j in range(1, self.num_rows):
  20. self.remove_wall(1, j, 3)
  21. # Add a locked door on the bottom right
  22. # Add an object behind the locked door
  23. room_idx = self._rand_int(0, self.num_rows)
  24. door, _ = self.add_door(2, room_idx, 2, locked=True)
  25. obj, _ = self.add_object(2, room_idx, kind=self.obj_type)
  26. # Add a key in a random room on the left side
  27. self.add_object(0, self._rand_int(0, self.num_rows), "key", door.color)
  28. # Place the agent in the middle
  29. self.place_agent(1, self.num_rows // 2)
  30. # Make sure all rooms are accessible
  31. self.connect_all()
  32. self.obj = obj
  33. self.mission = f"pick up the {obj.color} {obj.type}"
  34. def step(self, action):
  35. obs, reward, done, info = super().step(action)
  36. if action == self.actions.pickup:
  37. if self.carrying and self.carrying == self.obj:
  38. reward = self._reward()
  39. done = True
  40. return obs, reward, done, info
  41. register(
  42. id="MiniGrid-KeyCorridorS3R1-v0",
  43. entry_point="gym_minigrid.envs.keycorridor:KeyCorridorEnv",
  44. room_size=3,
  45. num_rows=1,
  46. )
  47. register(
  48. id="MiniGrid-KeyCorridorS3R2-v0",
  49. entry_point="gym_minigrid.envs.keycorridor:KeyCorridorEnv",
  50. room_size=3,
  51. num_rows=2,
  52. )
  53. register(
  54. id="MiniGrid-KeyCorridorS3R3-v0",
  55. entry_point="gym_minigrid.envs.keycorridor:KeyCorridorEnv",
  56. room_size=3,
  57. num_rows=3,
  58. )
  59. register(
  60. id="MiniGrid-KeyCorridorS4R3-v0",
  61. entry_point="gym_minigrid.envs.keycorridor:KeyCorridorEnv",
  62. room_size=4,
  63. num_rows=3,
  64. )
  65. register(
  66. id="MiniGrid-KeyCorridorS5R3-v0",
  67. entry_point="gym_minigrid.envs.keycorridor:KeyCorridorEnv",
  68. room_size=5,
  69. num_rows=3,
  70. )
  71. register(
  72. id="MiniGrid-KeyCorridorS6R3-v0",
  73. entry_point="gym_minigrid.envs.keycorridor:KeyCorridorEnv",
  74. room_size=6,
  75. num_rows=3,
  76. )