crossing.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import itertools as itt
  2. from typing import Optional
  3. import numpy as np
  4. from minigrid.core.grid import Grid
  5. from minigrid.core.mission import MissionSpace
  6. from minigrid.core.world_object import Goal, Lava
  7. from minigrid.minigrid_env import MiniGridEnv
  8. class CrossingEnv(MiniGridEnv):
  9. """
  10. <p>
  11. <img style="float:left" src="https://raw.githubusercontent.com/Farama-Foundation/Minigrid/master/figures/LavaCrossingS9N1.png" alt="LavaCrossingS9N1" width="200px"/>
  12. <img style="float:left" src="https://raw.githubusercontent.com/Farama-Foundation/Minigrid/master/figures/LavaCrossingS9N2.png" alt="LavaCrossingS9N2" width="200px"/>
  13. <img style="float:left" src="https://raw.githubusercontent.com/Farama-Foundation/Minigrid/master/figures/LavaCrossingS9N3.png" alt="LavaCrossingS9N3" width="200px"/>
  14. <img style="float:left" src="https://raw.githubusercontent.com/Farama-Foundation/Minigrid/master/figures/LavaCrossingS11N5.png" alt="LavaCrossingS11N5" width="200px"/>
  15. </p>
  16. <p>
  17. <img style="float:left" src="https://raw.githubusercontent.com/Farama-Foundation/Minigrid/master/figures/SimpleCrossingS9N1.png" alt="SimpleCrossingS9N1" width="200px"/>
  18. <img style="float:left" src="https://raw.githubusercontent.com/Farama-Foundation/Minigrid/master/figures/SimpleCrossingS9N2.png" alt="SimpleCrossingS9N2" width="200px"/>
  19. <img style="float:left" src="https://raw.githubusercontent.com/Farama-Foundation/Minigrid/master/figures/SimpleCrossingS9N3.png" alt="SimpleCrossingS9N3" width="200px"/>
  20. <img src="https://raw.githubusercontent.com/Farama-Foundation/Minigrid/master/figures/SimpleCrossingS11N5.png" alt="SimpleCrossingS11N5" width="200px"/>
  21. </p>
  22. ### Description
  23. Depending on the `obstacle_type` parameter:
  24. - `Lava` - The agent has to reach the green goal square on the other corner
  25. of the room while avoiding rivers of deadly lava which terminate the
  26. episode in failure. Each lava stream runs across the room either
  27. horizontally or vertically, and has a single crossing point which can be
  28. safely used; Luckily, a path to the goal is guaranteed to exist. This
  29. environment is useful for studying safety and safe exploration.
  30. - otherwise - Similar to the `LavaCrossing` environment, the agent has to
  31. reach the green goal square on the other corner of the room, however
  32. lava is replaced by walls. This MDP is therefore much easier and maybe
  33. useful for quickly testing your algorithms.
  34. ### Mission Space
  35. Depending on the `obstacle_type` parameter:
  36. - `Lava` - "avoid the lava and get to the green goal square"
  37. - otherwise - "find the opening and get to the green goal square"
  38. ### Action Space
  39. | Num | Name | Action |
  40. |-----|--------------|--------------|
  41. | 0 | left | Turn left |
  42. | 1 | right | Turn right |
  43. | 2 | forward | Move forward |
  44. | 3 | pickup | Unused |
  45. | 4 | drop | Unused |
  46. | 5 | toggle | Unused |
  47. | 6 | done | Unused |
  48. ### Observation Encoding
  49. - Each tile is encoded as a 3 dimensional tuple:
  50. `(OBJECT_IDX, COLOR_IDX, STATE)`
  51. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  52. [minigrid/minigrid.py](minigrid/minigrid.py)
  53. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
  54. ### Rewards
  55. A reward of '1' is given for success, and '0' for failure.
  56. ### Termination
  57. The episode ends if any one of the following conditions is met:
  58. 1. The agent reaches the goal.
  59. 2. The agent falls into lava.
  60. 3. Timeout (see `max_steps`).
  61. ### Registered Configurations
  62. S: size of the map SxS.
  63. N: number of valid crossings across lava or walls from the starting position
  64. to the goal
  65. - `Lava` :
  66. - `MiniGrid-LavaCrossingS9N1-v0`
  67. - `MiniGrid-LavaCrossingS9N2-v0`
  68. - `MiniGrid-LavaCrossingS9N3-v0`
  69. - `MiniGrid-LavaCrossingS11N5-v0`
  70. - otherwise :
  71. - `MiniGrid-SimpleCrossingS9N1-v0`
  72. - `MiniGrid-SimpleCrossingS9N2-v0`
  73. - `MiniGrid-SimpleCrossingS9N3-v0`
  74. - `MiniGrid-SimpleCrossingS11N5-v0`
  75. """
  76. def __init__(
  77. self,
  78. size=9,
  79. num_crossings=1,
  80. obstacle_type=Lava,
  81. max_steps: Optional[int] = None,
  82. **kwargs
  83. ):
  84. self.num_crossings = num_crossings
  85. self.obstacle_type = obstacle_type
  86. if obstacle_type == Lava:
  87. mission_space = MissionSpace(mission_func=self._gen_mission_lava)
  88. else:
  89. mission_space = MissionSpace(mission_func=self._gen_mission)
  90. if max_steps is None:
  91. max_steps = 4 * size**2
  92. super().__init__(
  93. mission_space=mission_space,
  94. grid_size=size,
  95. see_through_walls=False, # Set this to True for maximum speed
  96. max_steps=max_steps,
  97. **kwargs
  98. )
  99. @staticmethod
  100. def _gen_mission_lava():
  101. return "avoid the lava and get to the green goal square"
  102. @staticmethod
  103. def _gen_mission():
  104. return "find the opening and get to the green goal square"
  105. def _gen_grid(self, width, height):
  106. assert width % 2 == 1 and height % 2 == 1 # odd size
  107. # Create an empty grid
  108. self.grid = Grid(width, height)
  109. # Generate the surrounding walls
  110. self.grid.wall_rect(0, 0, width, height)
  111. # Place the agent in the top-left corner
  112. self.agent_pos = np.array((1, 1))
  113. self.agent_dir = 0
  114. # Place a goal square in the bottom-right corner
  115. self.put_obj(Goal(), width - 2, height - 2)
  116. # Place obstacles (lava or walls)
  117. v, h = object(), object() # singleton `vertical` and `horizontal` objects
  118. # Lava rivers or walls specified by direction and position in grid
  119. rivers = [(v, i) for i in range(2, height - 2, 2)]
  120. rivers += [(h, j) for j in range(2, width - 2, 2)]
  121. self.np_random.shuffle(rivers)
  122. rivers = rivers[: self.num_crossings] # sample random rivers
  123. rivers_v = sorted(pos for direction, pos in rivers if direction is v)
  124. rivers_h = sorted(pos for direction, pos in rivers if direction is h)
  125. obstacle_pos = itt.chain(
  126. itt.product(range(1, width - 1), rivers_h),
  127. itt.product(rivers_v, range(1, height - 1)),
  128. )
  129. for i, j in obstacle_pos:
  130. self.put_obj(self.obstacle_type(), i, j)
  131. # Sample path to goal
  132. path = [h] * len(rivers_v) + [v] * len(rivers_h)
  133. self.np_random.shuffle(path)
  134. # Create openings
  135. limits_v = [0] + rivers_v + [height - 1]
  136. limits_h = [0] + rivers_h + [width - 1]
  137. room_i, room_j = 0, 0
  138. for direction in path:
  139. if direction is h:
  140. i = limits_v[room_i + 1]
  141. j = self.np_random.choice(
  142. range(limits_h[room_j] + 1, limits_h[room_j + 1])
  143. )
  144. room_i += 1
  145. elif direction is v:
  146. i = self.np_random.choice(
  147. range(limits_v[room_i] + 1, limits_v[room_i + 1])
  148. )
  149. j = limits_h[room_j + 1]
  150. room_j += 1
  151. else:
  152. assert False
  153. self.grid.set(i, j, None)
  154. self.mission = (
  155. "avoid the lava and get to the green goal square"
  156. if self.obstacle_type == Lava
  157. else "find the opening and get to the green goal square"
  158. )