crossing.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import itertools as itt
  2. import numpy as np
  3. from gym_minigrid.minigrid import Goal, Grid, Lava, MiniGridEnv, MissionSpace
  4. class CrossingEnv(MiniGridEnv):
  5. """
  6. ### Description
  7. Depending on the `obstacle_type` parameter:
  8. - `Lava` - The agent has to reach the green goal square on the other corner
  9. of the room while avoiding rivers of deadly lava which terminate the
  10. episode in failure. Each lava stream runs across the room either
  11. horizontally or vertically, and has a single crossing point which can be
  12. safely used; Luckily, a path to the goal is guaranteed to exist. This
  13. environment is useful for studying safety and safe exploration.
  14. - otherwise - Similar to the `LavaCrossing` environment, the agent has to
  15. reach the green goal square on the other corner of the room, however
  16. lava is replaced by walls. This MDP is therefore much easier and maybe
  17. useful for quickly testing your algorithms.
  18. ### Mission Space
  19. Depending on the `obstacle_type` parameter:
  20. - `Lava` - "avoid the lava and get to the green goal square"
  21. - otherwise - "find the opening and get to the green goal square"
  22. ### Action Space
  23. | Num | Name | Action |
  24. |-----|--------------|--------------|
  25. | 0 | left | Turn left |
  26. | 1 | right | Turn right |
  27. | 2 | forward | Move forward |
  28. | 3 | pickup | Unused |
  29. | 4 | drop | Unused |
  30. | 5 | toggle | Unused |
  31. | 6 | done | Unused |
  32. ### Observation Encoding
  33. - Each tile is encoded as a 3 dimensional tuple:
  34. `(OBJECT_IDX, COLOR_IDX, STATE)`
  35. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  36. [gym_minigrid/minigrid.py](gym_minigrid/minigrid.py)
  37. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
  38. ### Rewards
  39. A reward of '1' is given for success, and '0' for failure.
  40. ### Termination
  41. The episode ends if any one of the following conditions is met:
  42. 1. The agent reaches the goal.
  43. 2. The agent falls into lava.
  44. 3. Timeout (see `max_steps`).
  45. ### Registered Configurations
  46. S: size of the map SxS.
  47. N: number of valid crossings across lava or walls from the starting position
  48. to the goal
  49. - `Lava` :
  50. - `MiniGrid-LavaCrossingS9N1-v0`
  51. - `MiniGrid-LavaCrossingS9N2-v0`
  52. - `MiniGrid-LavaCrossingS9N3-v0`
  53. - `MiniGrid-LavaCrossingS11N5-v0`
  54. - otherwise :
  55. - `MiniGrid-SimpleCrossingS9N1-v0`
  56. - `MiniGrid-SimpleCrossingS9N2-v0`
  57. - `MiniGrid-SimpleCrossingS9N3-v0`
  58. - `MiniGrid-SimpleCrossingS11N5-v0`
  59. """
  60. def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, **kwargs):
  61. self.num_crossings = num_crossings
  62. self.obstacle_type = obstacle_type
  63. if obstacle_type == Lava:
  64. mission_space = MissionSpace(
  65. mission_func=lambda: "avoid the lava and get to the green goal square"
  66. )
  67. else:
  68. mission_space = MissionSpace(
  69. mission_func=lambda: "find the opening and get to the green goal square"
  70. )
  71. super().__init__(
  72. mission_space=mission_space,
  73. grid_size=size,
  74. max_steps=4 * size * size,
  75. # Set this to True for maximum speed
  76. see_through_walls=False,
  77. **kwargs
  78. )
  79. def _gen_grid(self, width, height):
  80. assert width % 2 == 1 and height % 2 == 1 # odd size
  81. # Create an empty grid
  82. self.grid = Grid(width, height)
  83. # Generate the surrounding walls
  84. self.grid.wall_rect(0, 0, width, height)
  85. # Place the agent in the top-left corner
  86. self.agent_pos = np.array((1, 1))
  87. self.agent_dir = 0
  88. # Place a goal square in the bottom-right corner
  89. self.put_obj(Goal(), width - 2, height - 2)
  90. # Place obstacles (lava or walls)
  91. v, h = object(), object() # singleton `vertical` and `horizontal` objects
  92. # Lava rivers or walls specified by direction and position in grid
  93. rivers = [(v, i) for i in range(2, height - 2, 2)]
  94. rivers += [(h, j) for j in range(2, width - 2, 2)]
  95. self.np_random.shuffle(rivers)
  96. rivers = rivers[: self.num_crossings] # sample random rivers
  97. rivers_v = sorted(pos for direction, pos in rivers if direction is v)
  98. rivers_h = sorted(pos for direction, pos in rivers if direction is h)
  99. obstacle_pos = itt.chain(
  100. itt.product(range(1, width - 1), rivers_h),
  101. itt.product(rivers_v, range(1, height - 1)),
  102. )
  103. for i, j in obstacle_pos:
  104. self.put_obj(self.obstacle_type(), i, j)
  105. # Sample path to goal
  106. path = [h] * len(rivers_v) + [v] * len(rivers_h)
  107. self.np_random.shuffle(path)
  108. # Create openings
  109. limits_v = [0] + rivers_v + [height - 1]
  110. limits_h = [0] + rivers_h + [width - 1]
  111. room_i, room_j = 0, 0
  112. for direction in path:
  113. if direction is h:
  114. i = limits_v[room_i + 1]
  115. j = self.np_random.choice(
  116. range(limits_h[room_j] + 1, limits_h[room_j + 1])
  117. )
  118. room_i += 1
  119. elif direction is v:
  120. i = self.np_random.choice(
  121. range(limits_v[room_i] + 1, limits_v[room_i + 1])
  122. )
  123. j = limits_h[room_j + 1]
  124. room_j += 1
  125. else:
  126. assert False
  127. self.grid.set(i, j, None)
  128. self.mission = (
  129. "avoid the lava and get to the green goal square"
  130. if self.obstacle_type == Lava
  131. else "find the opening and get to the green goal square"
  132. )