# lavacrossing.py (3.4 KB)
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. import itertools as itt
  4. class LavaCrossingEnv(MiniGridEnv):
  5. """
  6. Environment with lava obstacles, sparse reward
  7. """
  8. def __init__(self, size=9, num_crossings=1, seed=None):
  9. self.num_crossings = num_crossings
  10. super().__init__(
  11. grid_size=size,
  12. max_steps=4*size*size,
  13. # Set this to True for maximum speed
  14. see_through_walls=False,
  15. seed=None
  16. )
  17. def _gen_grid(self, width, height):
  18. assert width % 2 == 1 and height % 2 == 1 # odd size
  19. # Create an empty grid
  20. self.grid = Grid(width, height)
  21. # Generate the surrounding walls
  22. self.grid.wall_rect(0, 0, width, height)
  23. # Place the agent in the top-left corner
  24. self.start_pos = (1, 1)
  25. self.start_dir = 0
  26. # Place a goal square in the bottom-right corner
  27. self.grid.set(width - 2, height - 2, Goal())
  28. # Place lava tiles
  29. v, h = object(), object() # singleton `vertical` and `horizontal` objects
  30. # Lava river specified by direction and position in grid
  31. rivers = [(v, i) for i in range(2, height - 2, 2)]
  32. rivers += [(h, j) for j in range(2, width - 2, 2)]
  33. self.np_random.shuffle(rivers)
  34. rivers = rivers[:self.num_crossings] # sample random rivers
  35. rivers_v = sorted([pos for direction, pos in rivers if direction is v])
  36. rivers_h = sorted([pos for direction, pos in rivers if direction is h])
  37. lava_pos = itt.chain(
  38. itt.product(range(1, width - 1), rivers_h),
  39. itt.product(rivers_v, range(1, height - 1)),
  40. )
  41. for i, j in lava_pos:
  42. self.grid.set(i, j, Lava())
  43. # Sample path to goal
  44. path = [h] * len(rivers_v) + [v] * len(rivers_h)
  45. self.np_random.shuffle(path)
  46. # Create openings in lava rivers
  47. limits_v = [0] + rivers_v + [height - 1]
  48. limits_h = [0] + rivers_h + [width - 1]
  49. room_i, room_j = 0, 0
  50. for direction in path:
  51. if direction is h:
  52. i = limits_v[room_i + 1]
  53. j = self.np_random.choice(
  54. range(limits_h[room_j] + 1, limits_h[room_j + 1]))
  55. room_i += 1
  56. elif direction is v:
  57. i = self.np_random.choice(
  58. range(limits_v[room_i] + 1, limits_v[room_i + 1]))
  59. j = limits_h[room_j + 1]
  60. room_j += 1
  61. else:
  62. assert False
  63. self.grid.set(i, j, None)
  64. self.mission = "avoid the lava and get to the green goal square"
  65. class LavaCrossingS9N2Env(LavaCrossingEnv):
  66. def __init__(self):
  67. super().__init__(size=9, num_crossings=2)
  68. class LavaCrossingS9N3Env(LavaCrossingEnv):
  69. def __init__(self):
  70. super().__init__(size=9, num_crossings=3)
  71. class LavaCrossingS11N5Env(LavaCrossingEnv):
  72. def __init__(self):
  73. super().__init__(size=11, num_crossings=5)
  74. register(
  75. id='MiniGrid-LavaCrossingS9N1-v0',
  76. entry_point='gym_minigrid.envs:LavaCrossingEnv'
  77. )
  78. register(
  79. id='MiniGrid-LavaCrossingS9N2-v0',
  80. entry_point='gym_minigrid.envs:LavaCrossingS9N2Env'
  81. )
  82. register(
  83. id='MiniGrid-LavaCrossingS9N3-v0',
  84. entry_point='gym_minigrid.envs:LavaCrossingS9N3Env'
  85. )
  86. register(
  87. id='MiniGrid-LavaCrossingS11N5-v0',
  88. entry_point='gym_minigrid.envs:LavaCrossingS11N5Env'
  89. )