keycorridor.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from gym_minigrid.roomgrid import RoomGrid
  2. from gym_minigrid.register import register
  3. class KeyCorridor(RoomGrid):
  4. """
  5. A ball is behind a locked door, the key is placed in a
  6. random room.
  7. """
  8. def __init__(
  9. self,
  10. num_rows=3,
  11. obj_type="ball",
  12. room_size=6,
  13. seed=None,
  14. **kwargs
  15. ):
  16. self.obj_type = obj_type
  17. super().__init__(
  18. room_size=room_size,
  19. num_rows=num_rows,
  20. max_steps=30*room_size**2,
  21. seed=seed,
  22. **kwargs
  23. )
  24. def _gen_grid(self, width, height):
  25. super()._gen_grid(width, height)
  26. # Connect the middle column rooms into a hallway
  27. for j in range(1, self.num_rows):
  28. self.remove_wall(1, j, 3)
  29. # Add a locked door on the bottom right
  30. # Add an object behind the locked door
  31. room_idx = self._rand_int(0, self.num_rows)
  32. door, _ = self.add_door(2, room_idx, 2, locked=True)
  33. obj, _ = self.add_object(2, room_idx, kind=self.obj_type)
  34. # Add a key in a random room on the left side
  35. self.add_object(0, self._rand_int(0, self.num_rows), 'key', door.color)
  36. # Place the agent in the middle
  37. self.place_agent(1, self.num_rows // 2)
  38. # Make sure all rooms are accessible
  39. self.connect_all()
  40. self.obj = obj
  41. self.mission = "pick up the %s %s" % (obj.color, obj.type)
  42. def step(self, action):
  43. obs, reward, done, info = super().step(action)
  44. if action == self.actions.pickup:
  45. if self.carrying and self.carrying == self.obj:
  46. reward = self._reward()
  47. done = True
  48. return obs, reward, done, info
  49. class KeyCorridorS3R1(KeyCorridor):
  50. def __init__(self, seed=None, **kwargs):
  51. super().__init__(
  52. room_size=3,
  53. num_rows=1,
  54. seed=seed,
  55. **kwargs
  56. )
  57. class KeyCorridorS3R2(KeyCorridor):
  58. def __init__(self, seed=None, **kwargs):
  59. super().__init__(
  60. room_size=3,
  61. num_rows=2,
  62. seed=seed,
  63. **kwargs
  64. )
  65. class KeyCorridorS3R3(KeyCorridor):
  66. def __init__(self, seed=None, **kwargs):
  67. super().__init__(
  68. room_size=3,
  69. num_rows=3,
  70. seed=seed,
  71. **kwargs
  72. )
  73. class KeyCorridorS4R3(KeyCorridor):
  74. def __init__(self, seed=None, **kwargs):
  75. super().__init__(
  76. room_size=4,
  77. num_rows=3,
  78. seed=seed,
  79. **kwargs
  80. )
  81. class KeyCorridorS5R3(KeyCorridor):
  82. def __init__(self, seed=None, **kwargs):
  83. super().__init__(
  84. room_size=5,
  85. num_rows=3,
  86. seed=seed,
  87. **kwargs
  88. )
  89. class KeyCorridorS6R3(KeyCorridor):
  90. def __init__(self, seed=None, **kwargs):
  91. super().__init__(
  92. room_size=6,
  93. num_rows=3,
  94. seed=seed,
  95. **kwargs
  96. )
  97. register(
  98. id='MiniGrid-KeyCorridorS3R1-v0',
  99. entry_point='gym_minigrid.envs:KeyCorridorS3R1'
  100. )
  101. register(
  102. id='MiniGrid-KeyCorridorS3R2-v0',
  103. entry_point='gym_minigrid.envs:KeyCorridorS3R2'
  104. )
  105. register(
  106. id='MiniGrid-KeyCorridorS3R3-v0',
  107. entry_point='gym_minigrid.envs:KeyCorridorS3R3'
  108. )
  109. register(
  110. id='MiniGrid-KeyCorridorS4R3-v0',
  111. entry_point='gym_minigrid.envs:KeyCorridorS4R3'
  112. )
  113. register(
  114. id='MiniGrid-KeyCorridorS5R3-v0',
  115. entry_point='gym_minigrid.envs:KeyCorridorS5R3'
  116. )
  117. register(
  118. id='MiniGrid-KeyCorridorS6R3-v0',
  119. entry_point='gym_minigrid.envs:KeyCorridorS6R3'
  120. )