simple_envs.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. from gym_minigrid.minigrid import *
  2. from gym_minigrid.register import register
  3. class EmptyEnv(MiniGridEnv):
  4. """
  5. Empty grid environment, no obstacles, sparse reward
  6. """
  7. def __init__(self, size=8):
  8. super().__init__(gridSize=size, maxSteps=3 * size)
  9. class EmptyEnv6x6(EmptyEnv):
  10. def __init__(self):
  11. super().__init__(size=6)
  12. class EmptyEnv16x16(EmptyEnv):
  13. def __init__(self):
  14. super().__init__(size=16)
  15. register(
  16. id='MiniGrid-Empty-6x6-v0',
  17. entry_point='gym_minigrid.envs:EmptyEnv6x6'
  18. )
  19. register(
  20. id='MiniGrid-Empty-8x8-v0',
  21. entry_point='gym_minigrid.envs:EmptyEnv'
  22. )
  23. register(
  24. id='MiniGrid-Empty-16x16-v0',
  25. entry_point='gym_minigrid.envs:EmptyEnv16x16'
  26. )
  27. class DoorKeyEnv(MiniGridEnv):
  28. """
  29. Environment with a door and key, sparse reward
  30. """
  31. def __init__(self, size=8):
  32. super().__init__(gridSize=size, maxSteps=4 * size)
  33. def _genGrid(self, width, height):
  34. grid = super()._genGrid(width, height)
  35. assert width == height
  36. gridSz = width
  37. # Create a vertical splitting wall
  38. splitIdx = self._randInt(2, gridSz-2)
  39. for i in range(0, gridSz):
  40. grid.set(splitIdx, i, Wall())
  41. # Place a door in the wall
  42. doorIdx = self._randInt(1, gridSz-2)
  43. grid.set(splitIdx, doorIdx, LockedDoor('yellow'))
  44. # Place a key on the left side
  45. #keyIdx = self._randInt(1 + gridSz // 2, gridSz-2)
  46. keyIdx = gridSz-2
  47. grid.set(1, keyIdx, Key('yellow'))
  48. return grid
  49. class DoorKeyEnv5x5(DoorKeyEnv):
  50. def __init__(self):
  51. super().__init__(size=5)
  52. class DoorKeyEnv6x6(DoorKeyEnv):
  53. def __init__(self):
  54. super().__init__(size=6)
  55. class DoorKeyEnv16x16(DoorKeyEnv):
  56. def __init__(self):
  57. super().__init__(size=16)
  58. register(
  59. id='MiniGrid-DoorKey-5x5-v0',
  60. entry_point='gym_minigrid.envs:DoorKeyEnv5x5'
  61. )
  62. register(
  63. id='MiniGrid-DoorKey-6x6-v0',
  64. entry_point='gym_minigrid.envs:DoorKeyEnv6x6'
  65. )
  66. register(
  67. id='MiniGrid-DoorKey-8x8-v0',
  68. entry_point='gym_minigrid.envs:DoorKeyEnv'
  69. )
  70. register(
  71. id='MiniGrid-DoorKey-16x16-v0',
  72. entry_point='gym_minigrid.envs:DoorKeyEnv16x16'
  73. )
  74. class Room:
  75. def __init__(self,
  76. top,
  77. size,
  78. entryDoorPos,
  79. exitDoorPos
  80. ):
  81. self.top = top
  82. self.size = size
  83. self.entryDoorPos = entryDoorPos
  84. self.exitDoorPos = exitDoorPos
  85. class MultiRoomEnv(MiniGridEnv):
  86. """
  87. Environment with multiple rooms (subgoals)
  88. """
  89. def __init__(self,
  90. minNumRooms,
  91. maxNumRooms,
  92. maxRoomSize=10
  93. ):
  94. assert minNumRooms > 0
  95. assert maxNumRooms >= minNumRooms
  96. assert maxRoomSize >= 4
  97. self.minNumRooms = minNumRooms
  98. self.maxNumRooms = maxNumRooms
  99. self.maxRoomSize = maxRoomSize
  100. self.rooms = []
  101. super(MultiRoomEnv, self).__init__(
  102. gridSize=25,
  103. maxSteps=self.maxNumRooms * 20
  104. )
  105. def _genGrid(self, width, height):
  106. roomList = []
  107. # Choose a random number of rooms to generate
  108. numRooms = self._randInt(self.minNumRooms, self.maxNumRooms+1)
  109. while len(roomList) < numRooms:
  110. curRoomList = []
  111. entryDoorPos = (
  112. self._randInt(0, width - 2),
  113. self._randInt(0, width - 2)
  114. )
  115. # Recursively place the rooms
  116. self._placeRoom(
  117. numRooms,
  118. roomList=curRoomList,
  119. minSz=4,
  120. maxSz=self.maxRoomSize,
  121. entryDoorWall=2,
  122. entryDoorPos=entryDoorPos
  123. )
  124. if len(curRoomList) > len(roomList):
  125. roomList = curRoomList
  126. # Store the list of rooms in this environment
  127. assert len(roomList) > 0
  128. self.rooms = roomList
  129. # Randomize the starting agent position and direction
  130. topX, topY = roomList[0].top
  131. sizeX, sizeY = roomList[0].size
  132. self.startPos = (
  133. self._randInt(topX + 1, topX + sizeX - 2),
  134. self._randInt(topY + 1, topY + sizeY - 2)
  135. )
  136. self.startDir = self._randInt(0, 4)
  137. # Create the grid
  138. grid = Grid(width, height)
  139. wall = Wall()
  140. prevDoorColor = None
  141. # For each room
  142. for idx, room in enumerate(roomList):
  143. topX, topY = room.top
  144. sizeX, sizeY = room.size
  145. # Draw the top and bottom walls
  146. for i in range(0, sizeX):
  147. grid.set(topX + i, topY, wall)
  148. grid.set(topX + i, topY + sizeY - 1, wall)
  149. # Draw the left and right walls
  150. for j in range(0, sizeY):
  151. grid.set(topX, topY + j, wall)
  152. grid.set(topX + sizeX - 1, topY + j, wall)
  153. # If this isn't the first room, place the entry door
  154. if idx > 0:
  155. # Pick a door color different from the previous one
  156. doorColors = set(COLORS.keys())
  157. if prevDoorColor:
  158. doorColors.remove(prevDoorColor)
  159. doorColor = self._randElem(doorColors)
  160. entryDoor = Door(doorColor)
  161. grid.set(*room.entryDoorPos, entryDoor)
  162. prevDoorColor = doorColor
  163. prevRoom = roomList[idx-1]
  164. prevRoom.exitDoorPos = room.entryDoorPos
  165. # Place the final goal
  166. while True:
  167. self.goalPos = (
  168. self._randInt(topX + 1, topX + sizeX - 1),
  169. self._randInt(topY + 1, topY + sizeY - 1)
  170. )
  171. # Make sure the goal doesn't overlap with the agent
  172. if self.goalPos != self.startPos:
  173. grid.set(*self.goalPos, Goal())
  174. break
  175. return grid
  176. def _placeRoom(
  177. self,
  178. numLeft,
  179. roomList,
  180. minSz,
  181. maxSz,
  182. entryDoorWall,
  183. entryDoorPos
  184. ):
  185. # Choose the room size randomly
  186. sizeX = self._randInt(minSz, maxSz+1)
  187. sizeY = self._randInt(minSz, maxSz+1)
  188. # The first room will be at the door position
  189. if len(roomList) == 0:
  190. topX, topY = entryDoorPos
  191. # Entry on the right
  192. elif entryDoorWall == 0:
  193. topX = entryDoorPos[0] - sizeX + 1
  194. y = entryDoorPos[1]
  195. topY = self._randInt(y - sizeY + 2, y)
  196. # Entry wall on the south
  197. elif entryDoorWall == 1:
  198. x = entryDoorPos[0]
  199. topX = self._randInt(x - sizeX + 2, x)
  200. topY = entryDoorPos[1] - sizeY + 1
  201. # Entry wall on the left
  202. elif entryDoorWall == 2:
  203. topX = entryDoorPos[0]
  204. y = entryDoorPos[1]
  205. topY = self._randInt(y - sizeY + 2, y)
  206. # Entry wall on the top
  207. elif entryDoorWall == 3:
  208. x = entryDoorPos[0]
  209. topX = self._randInt(x - sizeX + 2, x)
  210. topY = entryDoorPos[1]
  211. else:
  212. assert False, entryDoorWall
  213. # If the room is out of the grid, can't place a room here
  214. if topX < 0 or topY < 0:
  215. return False
  216. if topX + sizeX > self.gridSize or topY + sizeY >= self.gridSize:
  217. return False
  218. # If the room intersects with previous rooms, can't place it here
  219. for room in roomList[:-1]:
  220. nonOverlap = \
  221. topX + sizeX < room.top[0] or \
  222. room.top[0] + room.size[0] <= topX or \
  223. topY + sizeY < room.top[1] or \
  224. room.top[1] + room.size[1] <= topY
  225. if not nonOverlap:
  226. return False
  227. # Add this room to the list
  228. roomList.append(Room(
  229. (topX, topY),
  230. (sizeX, sizeY),
  231. entryDoorPos,
  232. None
  233. ))
  234. # If this was the last room, stop
  235. if numLeft == 1:
  236. return True
  237. # Try placing the next room
  238. for i in range(0, 8):
  239. # Pick which wall to place the out door on
  240. wallSet = set((0, 1, 2, 3))
  241. wallSet.remove(entryDoorWall)
  242. exitDoorWall = self._randElem(wallSet)
  243. nextEntryWall = (exitDoorWall + 2) % 4
  244. # Pick the exit door position
  245. # Exit on right wall
  246. if exitDoorWall == 0:
  247. exitDoorPos = (
  248. topX + sizeX - 1,
  249. topY + self._randInt(1, sizeY - 1)
  250. )
  251. # Exit on south wall
  252. elif exitDoorWall == 1:
  253. exitDoorPos = (
  254. topX + self._randInt(1, sizeX - 1),
  255. topY + sizeY - 1
  256. )
  257. # Exit on left wall
  258. elif exitDoorWall == 2:
  259. exitDoorPos = (
  260. topX,
  261. topY + self._randInt(1, sizeY - 1)
  262. )
  263. # Exit on north wall
  264. elif exitDoorWall == 3:
  265. exitDoorPos = (
  266. topX + self._randInt(1, sizeX - 1),
  267. topY
  268. )
  269. else:
  270. assert False
  271. # Recursively create the other rooms
  272. success = self._placeRoom(
  273. numLeft - 1,
  274. roomList=roomList,
  275. minSz=minSz,
  276. maxSz=maxSz,
  277. entryDoorWall=nextEntryWall,
  278. entryDoorPos=exitDoorPos
  279. )
  280. if success:
  281. break
  282. return True
  283. class MultiRoomEnvN6(MultiRoomEnv):
  284. def __init__(self):
  285. super(MultiRoomEnvN6, self).__init__(
  286. minNumRooms=6,
  287. maxNumRooms=6
  288. )
  289. register(
  290. id='MiniGrid-MultiRoom-N6-v0',
  291. entry_point='gym_minigrid.envs:MultiRoomEnvN6',
  292. reward_threshold=1000.0
  293. )
  294. class FetchEnv(MiniGridEnv):
  295. """
  296. Environment in which the agent has to fetch a random object
  297. named using English text strings
  298. """
  299. def __init__(
  300. self,
  301. size=8,
  302. numObjs=3):
  303. self.numObjs = numObjs
  304. super(FetchEnv, self).__init__(gridSize=size, maxSteps=5*size)
  305. def _genGrid(self, width, height):
  306. assert width == height
  307. gridSz = width
  308. # Create a grid surrounded by walls
  309. grid = Grid(width, height)
  310. for i in range(0, width):
  311. grid.set(i, 0, Wall())
  312. grid.set(i, height-1, Wall())
  313. for j in range(0, height):
  314. grid.set(0, j, Wall())
  315. grid.set(width-1, j, Wall())
  316. types = ['key', 'ball']
  317. colors = list(COLORS.keys())
  318. objs = []
  319. # For each object to be generated
  320. for i in range(0, self.numObjs):
  321. objType = self._randElem(types)
  322. objColor = self._randElem(colors)
  323. if objType == 'key':
  324. obj = Key(objColor)
  325. elif objType == 'ball':
  326. obj = Ball(objColor)
  327. while True:
  328. pos = (
  329. self._randInt(1, gridSz - 1),
  330. self._randInt(1, gridSz - 1)
  331. )
  332. if pos != self.startPos:
  333. grid.set(*pos, obj)
  334. break
  335. objs.append(obj)
  336. # Choose a random object to be picked up
  337. target = objs[self._randInt(0, len(objs))]
  338. self.targetType = target.type
  339. self.targetColor = target.color
  340. descStr = '%s %s' % (self.targetColor, self.targetType)
  341. # Generate the mission string
  342. idx = self._randInt(0, 5)
  343. if idx == 0:
  344. self.mission = 'get a %s' % descStr
  345. elif idx == 1:
  346. self.mission = 'go get a %s' % descStr
  347. elif idx == 2:
  348. self.mission = 'fetch a %s' % descStr
  349. elif idx == 3:
  350. self.mission = 'go fetch a %s' % descStr
  351. elif idx == 4:
  352. self.mission = 'you must fetch a %s' % descStr
  353. assert hasattr(self, 'mission')
  354. return grid
  355. def _reset(self):
  356. obs = MiniGridEnv._reset(self)
  357. obs = {
  358. 'image': obs,
  359. 'mission': self.mission,
  360. 'advice' : ''
  361. }
  362. return obs
  363. def _step(self, action):
  364. obs, reward, done, info = MiniGridEnv._step(self, action)
  365. if self.carrying:
  366. if self.carrying.color == self.targetColor and \
  367. self.carrying.type == self.targetType:
  368. reward = 1000 - self.stepCount
  369. done = True
  370. else:
  371. reward = -1000
  372. done = True
  373. obs = {
  374. 'image': obs,
  375. 'mission': self.mission,
  376. 'advice': ''
  377. }
  378. return obs, reward, done, info
  379. register(
  380. id='MiniGrid-Fetch-8x8-v0',
  381. entry_point='gym_minigrid.envs:FetchEnv'
  382. )