# keycorridor.py (2.5 KB)

  1. from gym_minigrid.roomgrid import RoomGrid
  2. from gym_minigrid.register import register
  3. class KeyCorridorEnv(RoomGrid):
  4. """
  5. A ball is behind a locked door, the key is placed in a
  6. random room.
  7. """
  8. def __init__(
  9. self,
  10. num_rows=3,
  11. obj_type="ball",
  12. room_size=6,
  13. seed=None
  14. ):
  15. self.obj_type = obj_type
  16. super().__init__(
  17. room_size=room_size,
  18. num_rows=num_rows,
  19. max_steps=30*room_size**2,
  20. seed=seed,
  21. )
  22. def _gen_grid(self, width, height):
  23. super()._gen_grid(width, height)
  24. # Connect the middle column rooms into a hallway
  25. for j in range(1, self.num_rows):
  26. self.remove_wall(1, j, 3)
  27. # Add a locked door on the bottom right
  28. # Add an object behind the locked door
  29. room_idx = self._rand_int(0, self.num_rows)
  30. door, _ = self.add_door(2, room_idx, 2, locked=True)
  31. obj, _ = self.add_object(2, room_idx, kind=self.obj_type)
  32. # Add a key in a random room on the left side
  33. self.add_object(0, self._rand_int(0, self.num_rows), 'key', door.color)
  34. # Place the agent in the middle
  35. self.place_agent(1, self.num_rows // 2)
  36. # Make sure all rooms are accessible
  37. self.connect_all()
  38. self.obj = obj
  39. self.mission = "pick up the %s %s" % (obj.color, obj.type)
  40. def step(self, action):
  41. obs, reward, done, info = super().step(action)
  42. if action == self.actions.pickup:
  43. if self.carrying and self.carrying == self.obj:
  44. reward = self._reward()
  45. done = True
  46. return obs, reward, done, info
  47. register(
  48. id='MiniGrid-KeyCorridorS3R1-v0',
  49. entry_point='gym_minigrid.envs.keycorridor:KeyCorridorEnv',
  50. room_size=3,
  51. num_rows=1,
  52. )
  53. register(
  54. id='MiniGrid-KeyCorridorS3R2-v0',
  55. entry_point='gym_minigrid.envs.keycorridor:KeyCorridorEnv',
  56. room_size=3,
  57. num_rows=2,
  58. )
  59. register(
  60. id='MiniGrid-KeyCorridorS3R3-v0',
  61. entry_point='gym_minigrid.envs.keycorridor:KeyCorridorEnv',
  62. room_size=3,
  63. num_rows=3,
  64. )
  65. register(
  66. id='MiniGrid-KeyCorridorS4R3-v0',
  67. entry_point='gym_minigrid.envs.keycorridor:KeyCorridorEnv',
  68. room_size=4,
  69. num_rows=3,
  70. )
  71. register(
  72. id='MiniGrid-KeyCorridorS5R3-v0',
  73. entry_point='gym_minigrid.envs.keycorridor:KeyCorridorEnv',
  74. room_size=5,
  75. num_rows=3,
  76. )
  77. register(
  78. id='MiniGrid-KeyCorridorS6R3-v0',
  79. entry_point='gym_minigrid.envs.keycorridor:KeyCorridorEnv',
  80. room_size=6,
  81. num_rows=3,
  82. )