roomgrid.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class Room:
  4. def __init__(self,
  5. top,
  6. size
  7. ):
  8. # Top-left corner and size (tuples)
  9. self.top = top
  10. self.size = size
  11. # List of door objects and door positions
  12. self.doors = []
  13. self.doorPos = []
  14. # Indicates if this room is locked
  15. self.locked = False
  16. # TODO: connectivity?
  17. # List of objects contained
  18. self.objs = []
  19. def randPos(self, env):
  20. topX, topY = self.top
  21. sizeX, sizeY = self.size
  22. return env._randPos(
  23. topX + 1, topX + sizeX - 1,
  24. topY + 1, topY + sizeY - 1
  25. )
  26. class RoomGrid(MiniGridEnv):
  27. """
  28. Environment with multiple rooms and random objects.
  29. This is meant to serve as a base class for other environments.
  30. """
  31. def __init__(
  32. self,
  33. roomSize=6,
  34. numCols=4,
  35. maxObsPerRoom=3,
  36. lockedRooms=False
  37. ):
  38. assert roomSize > 0
  39. assert roomSize >= 4
  40. assert numCols > 0
  41. self.roomSize = roomSize
  42. self.numCols = numCols
  43. self.numRows = numCols
  44. self.maxObsPerRoom = maxObsPerRoom
  45. self.lockedRooms = False
  46. gridSize = (roomSize - 1) * numCols + 1
  47. super().__init__(gridSize=gridSize, maxSteps=6*gridSize)
  48. self.reward_range = (0, 1)
  49. def _genGrid(self, width, height):
  50. # Create the grid
  51. self.grid = Grid(width, height)
  52. # Generate the surrounding walls
  53. self.grid.horzWall(0, 0)
  54. self.grid.horzWall(0, height-1)
  55. self.grid.vertWall(0, 0)
  56. self.grid.vertWall(width-1, 0)
  57. roomW = self.roomSize
  58. roomH = self.roomSize
  59. self.rooms = []
  60. # Generate the list of rooms
  61. for j in range(0, self.numRows):
  62. for i in range(0, self.numCols):
  63. room = Room(
  64. (i * (self.roomSize-1), j * (self.roomSize-1)),
  65. (self.roomSize, self.roomSize)
  66. )
  67. self.rooms.append(room)
  68. # TODO: generate walls
  69. # May want to add function to Grid class, wallRect(i, j, w, h, color)
  70. # Randomize the player start position and orientation
  71. self.placeAgent()
  72. # TODO: respect maxObsPerRoom
  73. # Place random objects in the world
  74. types = ['key', 'ball', 'box']
  75. for i in range(0, 12):
  76. objType = self._randElem(types)
  77. objColor = self._randElem(COLOR_NAMES)
  78. if objType == 'key':
  79. obj = Key(objColor)
  80. elif objType == 'ball':
  81. obj = Ball(objColor)
  82. elif objType == 'box':
  83. obj = Box(objColor)
  84. self.placeObj(obj)
  85. # TODO: curriculum generation
  86. self.mission = ''
  87. def step(self, action):
  88. obs, reward, done, info = super().step(action)
  89. return obs, reward, done, info
  90. register(
  91. id='MiniGrid-RoomGrid-v0',
  92. entry_point='gym_minigrid.envs:RoomGrid'
  93. )