# simple_envs.py (12 KB)

  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class EmptyEnv(MiniGridEnv):
  4. """
  5. Empty grid environment, no obstacles, sparse reward
  6. """
  7. def __init__(self, size=8):
  8. super().__init__(gridSize=size, maxSteps=3 * size)
  9. class EmptyEnv6x6(EmptyEnv):
  10. def __init__(self):
  11. super().__init__(size=6)
  12. class EmptyEnv16x16(EmptyEnv):
  13. def __init__(self):
  14. super().__init__(size=16)
  15. register(
  16. id='MiniGrid-Empty-6x6-v0',
  17. entry_point='gym_minigrid.envs:EmptyEnv6x6'
  18. )
  19. register(
  20. id='MiniGrid-Empty-8x8-v0',
  21. entry_point='gym_minigrid.envs:EmptyEnv'
  22. )
  23. register(
  24. id='MiniGrid-Empty-16x16-v0',
  25. entry_point='gym_minigrid.envs:EmptyEnv16x16'
  26. )
  27. class DoorKeyEnv(MiniGridEnv):
  28. """
  29. Environment with a door and key, sparse reward
  30. """
  31. def __init__(self, size=8):
  32. super().__init__(gridSize=size, maxSteps=4 * size)
  33. def _genGrid(self, width, height):
  34. grid = super()._genGrid(width, height)
  35. assert width == height
  36. gridSz = width
  37. # Create a vertical splitting wall
  38. splitIdx = self._randInt(2, gridSz-3)
  39. for i in range(0, gridSz):
  40. grid.set(splitIdx, i, Wall())
  41. # Place a door in the wall
  42. doorIdx = self._randInt(1, gridSz-2)
  43. grid.set(splitIdx, doorIdx, LockedDoor('yellow'))
  44. # Place a key on the left side
  45. #keyIdx = self._randInt(1 + gridSz // 2, gridSz-2)
  46. keyIdx = gridSz-2
  47. grid.set(1, keyIdx, Key('yellow'))
  48. return grid
  49. class DoorKeyEnv16x16(DoorKeyEnv):
  50. def __init__(self):
  51. super().__init__(size=16)
  52. register(
  53. id='MiniGrid-DoorKey-8x8-v0',
  54. entry_point='gym_minigrid.envs:DoorKeyEnv'
  55. )
  56. register(
  57. id='MiniGrid-DoorKey-16x16-v0',
  58. entry_point='gym_minigrid.envs:DoorKeyEnv16x16'
  59. )
  60. class Room:
  61. def __init__(self,
  62. top,
  63. size,
  64. entryDoorPos,
  65. exitDoorPos
  66. ):
  67. self.top = top
  68. self.size = size
  69. self.entryDoorPos = entryDoorPos
  70. self.exitDoorPos = exitDoorPos
  71. class MultiRoomEnv(MiniGridEnv):
  72. """
  73. Environment with multiple rooms (subgoals)
  74. """
  75. def __init__(self,
  76. minNumRooms,
  77. maxNumRooms,
  78. maxRoomSize=10
  79. ):
  80. assert minNumRooms > 0
  81. assert maxNumRooms >= minNumRooms
  82. assert maxRoomSize >= 4
  83. self.minNumRooms = minNumRooms
  84. self.maxNumRooms = maxNumRooms
  85. self.maxRoomSize = maxRoomSize
  86. self.rooms = []
  87. super(MultiRoomEnv, self).__init__(
  88. gridSize=25,
  89. maxSteps=self.maxNumRooms * 20
  90. )
  91. def _genGrid(self, width, height):
  92. roomList = []
  93. # Choose a random number of rooms to generate
  94. numRooms = self._randInt(self.minNumRooms, self.maxNumRooms+1)
  95. while len(roomList) < numRooms:
  96. curRoomList = []
  97. entryDoorPos = (
  98. self._randInt(0, width - 2),
  99. self._randInt(0, width - 2)
  100. )
  101. # Recursively place the rooms
  102. self._placeRoom(
  103. numRooms,
  104. roomList=curRoomList,
  105. minSz=4,
  106. maxSz=self.maxRoomSize,
  107. entryDoorWall=2,
  108. entryDoorPos=entryDoorPos
  109. )
  110. if len(curRoomList) > len(roomList):
  111. roomList = curRoomList
  112. # Store the list of rooms in this environment
  113. assert len(roomList) > 0
  114. self.rooms = roomList
  115. # Randomize the starting agent position and direction
  116. topX, topY = roomList[0].top
  117. sizeX, sizeY = roomList[0].size
  118. self.startPos = (
  119. self._randInt(topX + 1, topX + sizeX - 2),
  120. self._randInt(topY + 1, topY + sizeY - 2)
  121. )
  122. self.startDir = self._randInt(0, 4)
  123. # Create the grid
  124. grid = Grid(width, height)
  125. wall = Wall()
  126. prevDoorColor = None
  127. # For each room
  128. for idx, room in enumerate(roomList):
  129. topX, topY = room.top
  130. sizeX, sizeY = room.size
  131. # Draw the top and bottom walls
  132. for i in range(0, sizeX):
  133. grid.set(topX + i, topY, wall)
  134. grid.set(topX + i, topY + sizeY - 1, wall)
  135. # Draw the left and right walls
  136. for j in range(0, sizeY):
  137. grid.set(topX, topY + j, wall)
  138. grid.set(topX + sizeX - 1, topY + j, wall)
  139. # If this isn't the first room, place the entry door
  140. if idx > 0:
  141. # Pick a door color different from the previous one
  142. doorColors = set(COLORS.keys())
  143. if prevDoorColor:
  144. doorColors.remove(prevDoorColor)
  145. doorColor = self._randElem(doorColors)
  146. entryDoor = Door(doorColor)
  147. grid.set(*room.entryDoorPos, entryDoor)
  148. prevDoorColor = doorColor
  149. prevRoom = roomList[idx-1]
  150. prevRoom.exitDoorPos = room.entryDoorPos
  151. # Place the final goal
  152. while True:
  153. self.goalPos = (
  154. self._randInt(topX + 1, topX + sizeX - 1),
  155. self._randInt(topY + 1, topY + sizeY - 1)
  156. )
  157. # Make sure the goal doesn't overlap with the agent
  158. if self.goalPos != self.startPos:
  159. grid.set(*self.goalPos, Goal())
  160. break
  161. return grid
  162. def _placeRoom(
  163. self,
  164. numLeft,
  165. roomList,
  166. minSz,
  167. maxSz,
  168. entryDoorWall,
  169. entryDoorPos
  170. ):
  171. # Choose the room size randomly
  172. sizeX = self._randInt(minSz, maxSz+1)
  173. sizeY = self._randInt(minSz, maxSz+1)
  174. # The first room will be at the door position
  175. if len(roomList) == 0:
  176. topX, topY = entryDoorPos
  177. # Entry on the right
  178. elif entryDoorWall == 0:
  179. topX = entryDoorPos[0] - sizeX + 1
  180. y = entryDoorPos[1]
  181. topY = self._randInt(y - sizeY + 2, y)
  182. # Entry wall on the south
  183. elif entryDoorWall == 1:
  184. x = entryDoorPos[0]
  185. topX = self._randInt(x - sizeX + 2, x)
  186. topY = entryDoorPos[1] - sizeY + 1
  187. # Entry wall on the left
  188. elif entryDoorWall == 2:
  189. topX = entryDoorPos[0]
  190. y = entryDoorPos[1]
  191. topY = self._randInt(y - sizeY + 2, y)
  192. # Entry wall on the top
  193. elif entryDoorWall == 3:
  194. x = entryDoorPos[0]
  195. topX = self._randInt(x - sizeX + 2, x)
  196. topY = entryDoorPos[1]
  197. else:
  198. assert False, entryDoorWall
  199. # If the room is out of the grid, can't place a room here
  200. if topX < 0 or topY < 0:
  201. return False
  202. if topX + sizeX > self.gridSize or topY + sizeY >= self.gridSize:
  203. return False
  204. # If the room intersects with previous rooms, can't place it here
  205. for room in roomList[:-1]:
  206. nonOverlap = \
  207. topX + sizeX < room.top[0] or \
  208. room.top[0] + room.size[0] <= topX or \
  209. topY + sizeY < room.top[1] or \
  210. room.top[1] + room.size[1] <= topY
  211. if not nonOverlap:
  212. return False
  213. # Add this room to the list
  214. roomList.append(Room(
  215. (topX, topY),
  216. (sizeX, sizeY),
  217. entryDoorPos,
  218. None
  219. ))
  220. # If this was the last room, stop
  221. if numLeft == 1:
  222. return True
  223. # Try placing the next room
  224. for i in range(0, 8):
  225. # Pick which wall to place the out door on
  226. wallSet = set((0, 1, 2, 3))
  227. wallSet.remove(entryDoorWall)
  228. exitDoorWall = self._randElem(wallSet)
  229. nextEntryWall = (exitDoorWall + 2) % 4
  230. # Pick the exit door position
  231. # Exit on right wall
  232. if exitDoorWall == 0:
  233. exitDoorPos = (
  234. topX + sizeX - 1,
  235. topY + self._randInt(1, sizeY - 1)
  236. )
  237. # Exit on south wall
  238. elif exitDoorWall == 1:
  239. exitDoorPos = (
  240. topX + self._randInt(1, sizeX - 1),
  241. topY + sizeY - 1
  242. )
  243. # Exit on left wall
  244. elif exitDoorWall == 2:
  245. exitDoorPos = (
  246. topX,
  247. topY + self._randInt(1, sizeY - 1)
  248. )
  249. # Exit on north wall
  250. elif exitDoorWall == 3:
  251. exitDoorPos = (
  252. topX + self._randInt(1, sizeX - 1),
  253. topY
  254. )
  255. else:
  256. assert False
  257. # Recursively create the other rooms
  258. success = self._placeRoom(
  259. numLeft - 1,
  260. roomList=roomList,
  261. minSz=minSz,
  262. maxSz=maxSz,
  263. entryDoorWall=nextEntryWall,
  264. entryDoorPos=exitDoorPos
  265. )
  266. if success:
  267. break
  268. return True
  269. class MultiRoomEnvN6(MultiRoomEnv):
  270. def __init__(self):
  271. super(MultiRoomEnvN6, self).__init__(
  272. minNumRooms=6,
  273. maxNumRooms=6
  274. )
# Register the six-room environment; reward_threshold is forwarded
# unchanged to register()
register(
    id='MiniGrid-MultiRoom-N6-v0',
    entry_point='gym_minigrid.envs:MultiRoomEnvN6',
    reward_threshold=1000.0
)
  280. class FetchEnv(MiniGridEnv):
  281. """
  282. Environment in which the agent has to fetch a random object
  283. named using English text strings
  284. """
  285. def __init__(
  286. self,
  287. size=8,
  288. numObjs=3):
  289. self.numObjs = numObjs
  290. super(FetchEnv, self).__init__(gridSize=size, maxSteps=5*size)
  291. def _genGrid(self, width, height):
  292. assert width == height
  293. gridSz = width
  294. # Create a grid surrounded by walls
  295. grid = Grid(width, height)
  296. for i in range(0, width):
  297. grid.set(i, 0, Wall())
  298. grid.set(i, height-1, Wall())
  299. for j in range(0, height):
  300. grid.set(0, j, Wall())
  301. grid.set(width-1, j, Wall())
  302. types = ['key', 'ball']
  303. colors = list(COLORS.keys())
  304. objs = []
  305. # For each object to be generated
  306. for i in range(0, self.numObjs):
  307. objType = self._randElem(types)
  308. objColor = self._randElem(colors)
  309. if objType == 'key':
  310. obj = Key(objColor)
  311. elif objType == 'ball':
  312. obj = Ball(objColor)
  313. while True:
  314. pos = (
  315. self._randInt(1, gridSz - 1),
  316. self._randInt(1, gridSz - 1)
  317. )
  318. if pos != self.startPos:
  319. grid.set(*pos, obj)
  320. break
  321. objs.append(obj)
  322. # Choose a random object to be picked up
  323. target = objs[self._randInt(0, len(objs))]
  324. self.targetType = target.type
  325. self.targetColor = target.color
  326. descStr = '%s %s' % (self.targetColor, self.targetType)
  327. # Generate the mission string
  328. idx = self._randInt(0, 5)
  329. if idx == 0:
  330. self.mission = 'get a %s' % descStr
  331. elif idx == 1:
  332. self.mission = 'go get a %s' % descStr
  333. elif idx == 2:
  334. self.mission = 'fetch a %s' % descStr
  335. elif idx == 3:
  336. self.mission = 'go fetch a %s' % descStr
  337. elif idx == 4:
  338. self.mission = 'you must fetch a %s' % descStr
  339. assert hasattr(self, 'mission')
  340. return grid
  341. def _reset(self):
  342. obs = MiniGridEnv._reset(self)
  343. obs = {
  344. 'image': obs,
  345. 'mission': self.mission,
  346. 'advice' : ''
  347. }
  348. return obs
  349. def _step(self, action):
  350. obs, reward, done, info = MiniGridEnv._step(self, action)
  351. if self.carrying:
  352. if self.carrying.color == self.targetColor and \
  353. self.carrying.type == self.targetType:
  354. reward = 1000 - self.stepCount
  355. done = True
  356. else:
  357. reward = -1000
  358. done = True
  359. obs = {
  360. 'image': obs,
  361. 'mission': self.mission,
  362. 'advice': ''
  363. }
  364. return obs, reward, done, info
# Register the fetch environment under its gym id (default 8x8 grid)
register(
    id='MiniGrid-Fetch-8x8-v0',
    entry_point='gym_minigrid.envs:FetchEnv'
)