keycorridor.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. from gym_minigrid.roomgrid import RoomGrid
  2. from gym_minigrid.register import register
  3. class KeyCorridor(RoomGrid):
  4. """
  5. A ball is behind a locked door, the key is placed in a
  6. random room.
  7. """
  8. def __init__(
  9. self,
  10. num_rows=3,
  11. obj_type="ball",
  12. room_size=6,
  13. seed=None
  14. ):
  15. self.obj_type = obj_type
  16. super().__init__(
  17. room_size=room_size,
  18. num_rows=num_rows,
  19. max_steps=30*room_size**2,
  20. seed=seed,
  21. )
  22. def _gen_grid(self, width, height):
  23. super()._gen_grid(width, height)
  24. # Connect the middle column rooms into a hallway
  25. for j in range(1, self.num_rows):
  26. self.remove_wall(1, j, 3)
  27. # Add a locked door on the bottom right
  28. # Add an object behind the locked door
  29. room_idx = self._rand_int(0, self.num_rows)
  30. door, _ = self.add_door(2, room_idx, 2, locked=True)
  31. obj, _ = self.add_object(2, room_idx, kind=self.obj_type)
  32. # Add a key in a random room on the left side
  33. self.add_object(0, self._rand_int(0, self.num_rows), 'key', door.color)
  34. # Place the agent in the middle
  35. self.place_agent(1, self.num_rows // 2)
  36. # Make sure all rooms are accessible
  37. self.connect_all()
  38. self.obj = obj
  39. self.mission = "pick up the %s %s" % (obj.color, obj.type)
  40. def step(self, action):
  41. obs, reward, done, info = super().step(action)
  42. if action == self.actions.pickup:
  43. if self.carrying and self.carrying == self.obj:
  44. reward = self._reward()
  45. done = True
  46. return obs, reward, done, info
  47. class KeyCorridorS3R1(KeyCorridor):
  48. def __init__(self, seed=None):
  49. super().__init__(
  50. room_size=3,
  51. num_rows=1,
  52. seed=seed
  53. )
  54. class KeyCorridorS3R2(KeyCorridor):
  55. def __init__(self, seed=None):
  56. super().__init__(
  57. room_size=3,
  58. num_rows=2,
  59. seed=seed
  60. )
  61. class KeyCorridorS3R3(KeyCorridor):
  62. def __init__(self, seed=None):
  63. super().__init__(
  64. room_size=3,
  65. num_rows=3,
  66. seed=seed
  67. )
  68. class KeyCorridorS4R3(KeyCorridor):
  69. def __init__(self, seed=None):
  70. super().__init__(
  71. room_size=4,
  72. num_rows=3,
  73. seed=seed
  74. )
  75. class KeyCorridorS5R3(KeyCorridor):
  76. def __init__(self, seed=None):
  77. super().__init__(
  78. room_size=5,
  79. num_rows=3,
  80. seed=seed
  81. )
  82. class KeyCorridorS6R3(KeyCorridor):
  83. def __init__(self, seed=None):
  84. super().__init__(
  85. room_size=6,
  86. num_rows=3,
  87. seed=seed
  88. )
  89. register(
  90. id='MiniGrid-KeyCorridorS3R1-v0',
  91. entry_point='gym_minigrid.envs:KeyCorridorS3R1'
  92. )
  93. register(
  94. id='MiniGrid-KeyCorridorS3R2-v0',
  95. entry_point='gym_minigrid.envs:KeyCorridorS3R2'
  96. )
  97. register(
  98. id='MiniGrid-KeyCorridorS3R3-v0',
  99. entry_point='gym_minigrid.envs:KeyCorridorS3R3'
  100. )
  101. register(
  102. id='MiniGrid-KeyCorridorS4R3-v0',
  103. entry_point='gym_minigrid.envs:KeyCorridorS4R3'
  104. )
  105. register(
  106. id='MiniGrid-KeyCorridorS5R3-v0',
  107. entry_point='gym_minigrid.envs:KeyCorridorS5R3'
  108. )
  109. register(
  110. id='MiniGrid-KeyCorridorS6R3-v0',
  111. entry_point='gym_minigrid.envs:KeyCorridorS6R3'
  112. )