# simple_envs.py
  1. from gym_minigrid.envs.minigrid import *
  2. from gym_minigrid.register import register
  3. class EmptyEnv(MiniGridEnv):
  4. """
  5. Empty grid environment, no obstacles, sparse reward
  6. """
  7. def __init__(self, size=8):
  8. super(EmptyEnv, self).__init__(gridSize=size, maxSteps=2 * size)
  9. class EmptyEnv6x6(EmptyEnv):
  10. def __init__(self):
  11. super(EmptyEnv6x6, self).__init__(size=6)
  12. register(
  13. id='MiniGrid-Empty-8x8-v0',
  14. entry_point='gym_minigrid.envs:EmptyEnv'
  15. )
  16. register(
  17. id='MiniGrid-Empty-6x6-v0',
  18. entry_point='gym_minigrid.envs:EmptyEnv6x6'
  19. )
  20. class DoorKeyEnv(MiniGridEnv):
  21. """
  22. Environment with a door and key, sparse reward
  23. """
  24. def __init__(self, size=8):
  25. super(DoorKeyEnv, self).__init__(gridSize=size, maxSteps=4 * size)
  26. def _genGrid(self, width, height):
  27. grid = super(DoorKeyEnv, self)._genGrid(width, height)
  28. assert width == height
  29. gridSz = width
  30. # Create a vertical splitting wall
  31. splitIdx = self._randInt(2, gridSz-3)
  32. for i in range(0, gridSz):
  33. grid.set(splitIdx, i, Wall())
  34. # Place a door in the wall
  35. doorIdx = self._randInt(1, gridSz-2)
  36. grid.set(splitIdx, doorIdx, Door('yellow'))
  37. # Place a key on the left side
  38. #keyIdx = self._randInt(1 + gridSz // 2, gridSz-2)
  39. keyIdx = gridSz-2
  40. grid.set(1, keyIdx, Key('yellow'))
  41. return grid
  42. class DoorKeyEnv16x16(DoorKeyEnv):
  43. def __init__(self):
  44. super(DoorKeyEnv16x16, self).__init__(size=16)
  45. register(
  46. id='MiniGrid-Door-Key-8x8-v0',
  47. entry_point='gym_minigrid.envs:DoorKeyEnv'
  48. )
  49. register(
  50. id='MiniGrid-Door-Key-16x16-v0',
  51. entry_point='gym_minigrid.envs:DoorKeyEnv16x16'
  52. )
  53. class Room:
  54. def __init__(self,
  55. top,
  56. size,
  57. entryDoorPos,
  58. exitDoorPos
  59. ):
  60. self.top = top
  61. self.size = size
  62. self.entryDoorPos = entryDoorPos
  63. self.exitDoorPos = exitDoorPos
  64. class MultiRoomEnv(MiniGridEnv):
  65. """
  66. Environment with multiple rooms (subgoals)
  67. """
  68. def __init__(self,
  69. minNumRooms,
  70. maxNumRooms,
  71. maxRoomSize=10
  72. ):
  73. assert minNumRooms > 0
  74. assert maxNumRooms >= minNumRooms
  75. assert maxRoomSize >= 4
  76. self.minNumRooms = minNumRooms
  77. self.maxNumRooms = maxNumRooms
  78. self.maxRoomSize = maxRoomSize
  79. self.rooms = []
  80. super(MultiRoomEnv, self).__init__(
  81. gridSize=25,
  82. maxSteps=self.maxNumRooms * 20
  83. )
  84. def _genGrid(self, width, height):
  85. roomList = []
  86. # Choose a random number of rooms to generate
  87. numRooms = self._randInt(self.minNumRooms, self.maxNumRooms+1)
  88. while len(roomList) < numRooms:
  89. curRoomList = []
  90. entryDoorPos = (
  91. self._randInt(0, width - 2),
  92. self._randInt(0, width - 2)
  93. )
  94. # Recursively place the rooms
  95. self._placeRoom(
  96. numRooms,
  97. roomList=curRoomList,
  98. minSz=4,
  99. maxSz=self.maxRoomSize,
  100. entryDoorWall=2,
  101. entryDoorPos=entryDoorPos
  102. )
  103. if len(curRoomList) > len(roomList):
  104. roomList = curRoomList
  105. # Store the list of rooms in this environment
  106. assert len(roomList) > 0
  107. self.rooms = roomList
  108. # Randomize the starting agent position and direction
  109. topX, topY = roomList[0].top
  110. sizeX, sizeY = roomList[0].size
  111. self.startPos = (
  112. self._randInt(topX + 1, topX + sizeX - 2),
  113. self._randInt(topY + 1, topY + sizeY - 2)
  114. )
  115. self.startDir = self._randInt(0, 4)
  116. # Create the grid
  117. grid = Grid(width, height)
  118. wall = Wall()
  119. prevDoorColor = None
  120. # For each room
  121. for idx, room in enumerate(roomList):
  122. topX, topY = room.top
  123. sizeX, sizeY = room.size
  124. # Draw the top and bottom walls
  125. for i in range(0, sizeX):
  126. grid.set(topX + i, topY, wall)
  127. grid.set(topX + i, topY + sizeY - 1, wall)
  128. # Draw the left and right walls
  129. for j in range(0, sizeY):
  130. grid.set(topX, topY + j, wall)
  131. grid.set(topX + sizeX - 1, topY + j, wall)
  132. # If this isn't the first room, place the entry door
  133. if idx > 0:
  134. # Pick a door color different from the previous one
  135. doorColors = set(COLORS.keys())
  136. if prevDoorColor:
  137. doorColors.remove(prevDoorColor)
  138. doorColor = self._randElem(doorColors)
  139. entryDoor = Door(doorColor)
  140. grid.set(*room.entryDoorPos, entryDoor)
  141. prevDoorColor = doorColor
  142. prevRoom = roomList[idx-1]
  143. prevRoom.exitDoorPos = room.entryDoorPos
  144. # Place the final goal
  145. while True:
  146. self.goalPos = (
  147. self._randInt(topX + 1, topX + sizeX - 1),
  148. self._randInt(topY + 1, topY + sizeY - 1)
  149. )
  150. # Make sure the goal doesn't overlap with the agent
  151. if self.goalPos != self.startPos:
  152. grid.set(*self.goalPos, Goal())
  153. break
  154. return grid
  155. def _placeRoom(
  156. self,
  157. numLeft,
  158. roomList,
  159. minSz,
  160. maxSz,
  161. entryDoorWall,
  162. entryDoorPos
  163. ):
  164. # Choose the room size randomly
  165. sizeX = self._randInt(minSz, maxSz+1)
  166. sizeY = self._randInt(minSz, maxSz+1)
  167. # The first room will be at the door position
  168. if len(roomList) == 0:
  169. topX, topY = entryDoorPos
  170. # Entry on the right
  171. elif entryDoorWall == 0:
  172. topX = entryDoorPos[0] - sizeX + 1
  173. y = entryDoorPos[1]
  174. topY = self._randInt(y - sizeY + 2, y)
  175. # Entry wall on the south
  176. elif entryDoorWall == 1:
  177. x = entryDoorPos[0]
  178. topX = self._randInt(x - sizeX + 2, x)
  179. topY = entryDoorPos[1] - sizeY + 1
  180. # Entry wall on the left
  181. elif entryDoorWall == 2:
  182. topX = entryDoorPos[0]
  183. y = entryDoorPos[1]
  184. topY = self._randInt(y - sizeY + 2, y)
  185. # Entry wall on the top
  186. elif entryDoorWall == 3:
  187. x = entryDoorPos[0]
  188. topX = self._randInt(x - sizeX + 2, x)
  189. topY = entryDoorPos[1]
  190. else:
  191. assert False, entryDoorWall
  192. # If the room is out of the grid, can't place a room here
  193. if topX < 0 or topY < 0:
  194. return False
  195. if topX + sizeX > self.gridSize or topY + sizeY >= self.gridSize:
  196. return False
  197. # If the room intersects with previous rooms, can't place it here
  198. for room in roomList[:-1]:
  199. nonOverlap = \
  200. topX + sizeX < room.top[0] or \
  201. room.top[0] + room.size[0] <= topX or \
  202. topY + sizeY < room.top[1] or \
  203. room.top[1] + room.size[1] <= topY
  204. if not nonOverlap:
  205. return False
  206. # Add this room to the list
  207. roomList.append(Room(
  208. (topX, topY),
  209. (sizeX, sizeY),
  210. entryDoorPos,
  211. None
  212. ))
  213. # If this was the last room, stop
  214. if numLeft == 1:
  215. return True
  216. # Try placing the next room
  217. for i in range(0, 8):
  218. # Pick which wall to place the out door on
  219. wallSet = set((0, 1, 2, 3))
  220. wallSet.remove(entryDoorWall)
  221. exitDoorWall = self._randElem(wallSet)
  222. nextEntryWall = (exitDoorWall + 2) % 4
  223. # Pick the exit door position
  224. # Exit on right wall
  225. if exitDoorWall == 0:
  226. exitDoorPos = (
  227. topX + sizeX - 1,
  228. topY + self._randInt(1, sizeY - 1)
  229. )
  230. # Exit on south wall
  231. elif exitDoorWall == 1:
  232. exitDoorPos = (
  233. topX + self._randInt(1, sizeX - 1),
  234. topY + sizeY - 1
  235. )
  236. # Exit on left wall
  237. elif exitDoorWall == 2:
  238. exitDoorPos = (
  239. topX,
  240. topY + self._randInt(1, sizeY - 1)
  241. )
  242. # Exit on north wall
  243. elif exitDoorWall == 3:
  244. exitDoorPos = (
  245. topX + self._randInt(1, sizeX - 1),
  246. topY
  247. )
  248. else:
  249. assert False
  250. # Recursively create the other rooms
  251. success = self._placeRoom(
  252. numLeft - 1,
  253. roomList=roomList,
  254. minSz=minSz,
  255. maxSz=maxSz,
  256. entryDoorWall=nextEntryWall,
  257. entryDoorPos=exitDoorPos
  258. )
  259. if success:
  260. break
  261. return True
  262. class MultiRoomEnvN6(MultiRoomEnv):
  263. def __init__(self):
  264. super(MultiRoomEnvN6, self).__init__(
  265. minNumRooms=6,
  266. maxNumRooms=6
  267. )
  268. register(
  269. id='MiniGrid-Multi-Room-N6-v0',
  270. entry_point='gym_minigrid.envs:MultiRoomEnvN6',
  271. reward_threshold=1000.0
  272. )
  273. class FetchEnv(MiniGridEnv):
  274. """
  275. Environment in which the agent has to fetch a random object
  276. named using English text strings
  277. """
  278. def __init__(
  279. self,
  280. size=8,
  281. numObjs=3):
  282. self.numObjs = numObjs
  283. super(FetchEnv, self).__init__(gridSize=size, maxSteps=5*size)
  284. def _genGrid(self, width, height):
  285. assert width == height
  286. gridSz = width
  287. # Create a grid surrounded by walls
  288. grid = Grid(width, height)
  289. for i in range(0, width):
  290. grid.set(i, 0, Wall())
  291. grid.set(i, height-1, Wall())
  292. for j in range(0, height):
  293. grid.set(0, j, Wall())
  294. grid.set(width-1, j, Wall())
  295. types = ['key', 'ball']
  296. colors = list(COLORS.keys())
  297. objs = []
  298. # For each object to be generated
  299. for i in range(0, self.numObjs):
  300. objType = self._randElem(types)
  301. objColor = self._randElem(colors)
  302. if objType == 'key':
  303. obj = Key(objColor)
  304. elif objType == 'ball':
  305. obj = Ball(objColor)
  306. while True:
  307. pos = (
  308. self._randInt(1, gridSz - 1),
  309. self._randInt(1, gridSz - 1)
  310. )
  311. if pos != self.startPos:
  312. grid.set(*pos, obj)
  313. break
  314. objs.append(obj)
  315. # Choose a random object to be picked up
  316. target = objs[self._randInt(0, len(objs))]
  317. self.targetType = target.type
  318. self.targetColor = target.color
  319. descStr = '%s %s' % (self.targetColor, self.targetType)
  320. # Generate the mission string
  321. idx = self._randInt(0, 5)
  322. if idx == 0:
  323. self.mission = 'get a %s' % descStr
  324. elif idx == 1:
  325. self.mission = 'go get a %s' % descStr
  326. elif idx == 2:
  327. self.mission = 'fetch a %s' % descStr
  328. elif idx == 3:
  329. self.mission = 'go fetch a %s' % descStr
  330. elif idx == 4:
  331. self.mission = 'you must fetch a %s' % descStr
  332. assert hasattr(self, 'mission')
  333. return grid
  334. def _reset(self):
  335. obs = MiniGridEnv._reset(self)
  336. obs = {
  337. 'image': obs,
  338. 'mission': self.mission,
  339. 'advice' : ''
  340. }
  341. return obs
  342. def _step(self, action):
  343. obs, reward, done, info = MiniGridEnv._step(self, action)
  344. if self.carrying:
  345. if self.carrying.color == self.targetColor and \
  346. self.carrying.type == self.targetType:
  347. reward = 1000 - self.stepCount
  348. done = True
  349. else:
  350. reward = -1000
  351. done = True
  352. obs = {
  353. 'image': obs,
  354. 'mission': self.mission,
  355. 'advice': ''
  356. }
  357. return obs, reward, done, info
  358. register(
  359. id='MiniGrid-Fetch-8x8-v0',
  360. entry_point='gym_minigrid.envs:FetchEnv'
  361. )