simple_envs.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. from gym.envs.registration import register
  2. from gym_minigrid.envs.minigrid_env import *
  3. class EmptyEnv(MiniGridEnv):
  4. """
  5. Empty grid environment, no obstacles, sparse reward
  6. """
  7. def __init__(self, size=8):
  8. super(EmptyEnv, self).__init__(gridSize=size, maxSteps=2 * size)
  9. class EmptyEnv6x6(EmptyEnv):
  10. def __init__(self):
  11. super(EmptyEnv6x6, self).__init__(size=6)
  12. register(
  13. id='MiniGrid-Empty-8x8-v0',
  14. entry_point='gym_minigrid.envs:EmptyEnv',
  15. reward_threshold=1000.0
  16. )
  17. register(
  18. id='-Empty-6x6-v0',
  19. entry_point='gym_minigrid.envs:EmptyEnv6x6',
  20. reward_threshold=1000.0
  21. )
  22. class DoorKeyEnv(MiniGridEnv):
  23. """
  24. Environment with a door and key, sparse reward
  25. """
  26. def __init__(self, size=8):
  27. super(DoorKeyEnv, self).__init__(gridSize=size, maxSteps=4 * size)
  28. def _genGrid(self, width, height):
  29. grid = super(DoorKeyEnv, self)._genGrid(width, height)
  30. assert width == height
  31. gridSz = width
  32. # Create a vertical splitting wall
  33. splitIdx = self._randInt(2, gridSz-3)
  34. for i in range(0, gridSz):
  35. grid.set(splitIdx, i, Wall())
  36. # Place a door in the wall
  37. doorIdx = self._randInt(1, gridSz-2)
  38. grid.set(splitIdx, doorIdx, Door('yellow'))
  39. # Place a key on the left side
  40. #keyIdx = self._randInt(1 + gridSz // 2, gridSz-2)
  41. keyIdx = gridSz-2
  42. grid.set(1, keyIdx, Key('yellow'))
  43. return grid
  44. class DoorKeyEnv16x16(DoorKeyEnv):
  45. def __init__(self):
  46. super(DoorKeyEnv16x16, self).__init__(size=16)
  47. register(
  48. id='-Door-Key-8x8-v0',
  49. entry_point='gym_minigrid.envs:DoorKeyEnv',
  50. reward_threshold=1000.0
  51. )
  52. register(
  53. id='-Door-Key-16x16-v0',
  54. entry_point='gym_minigrid.envs:DoorKeyEnv16x16',
  55. reward_threshold=1000.0
  56. )
  57. class Room:
  58. def __init__(self,
  59. top,
  60. size,
  61. entryDoorPos,
  62. exitDoorPos
  63. ):
  64. self.top = top
  65. self.size = size
  66. self.entryDoorPos = entryDoorPos
  67. self.exitDoorPos = exitDoorPos
  68. class MultiRoomEnv(MiniGridEnv):
  69. """
  70. Environment with multiple rooms (subgoals)
  71. """
  72. def __init__(self,
  73. minNumRooms,
  74. maxNumRooms,
  75. maxRoomSize=10
  76. ):
  77. assert minNumRooms > 0
  78. assert maxNumRooms >= minNumRooms
  79. assert maxRoomSize >= 4
  80. self.minNumRooms = minNumRooms
  81. self.maxNumRooms = maxNumRooms
  82. self.maxRoomSize = maxRoomSize
  83. self.rooms = []
  84. super(MultiRoomEnv, self).__init__(
  85. gridSize=25,
  86. maxSteps=self.maxNumRooms * 20
  87. )
  88. def _genGrid(self, width, height):
  89. roomList = []
  90. # Choose a random number of rooms to generate
  91. numRooms = self._randInt(self.minNumRooms, self.maxNumRooms+1)
  92. while len(roomList) < numRooms:
  93. curRoomList = []
  94. entryDoorPos = (
  95. self._randInt(0, width - 2),
  96. self._randInt(0, width - 2)
  97. )
  98. # Recursively place the rooms
  99. self._placeRoom(
  100. numRooms,
  101. roomList=curRoomList,
  102. minSz=4,
  103. maxSz=self.maxRoomSize,
  104. entryDoorWall=2,
  105. entryDoorPos=entryDoorPos
  106. )
  107. if len(curRoomList) > len(roomList):
  108. roomList = curRoomList
  109. # Store the list of rooms in this environment
  110. assert len(roomList) > 0
  111. self.rooms = roomList
  112. # Randomize the starting agent position and direction
  113. topX, topY = roomList[0].top
  114. sizeX, sizeY = roomList[0].size
  115. self.startPos = (
  116. self._randInt(topX + 1, topX + sizeX - 2),
  117. self._randInt(topY + 1, topY + sizeY - 2)
  118. )
  119. self.startDir = self._randInt(0, 4)
  120. # Create the grid
  121. grid = Grid(width, height)
  122. wall = Wall()
  123. prevDoorColor = None
  124. # For each room
  125. for idx, room in enumerate(roomList):
  126. topX, topY = room.top
  127. sizeX, sizeY = room.size
  128. # Draw the top and bottom walls
  129. for i in range(0, sizeX):
  130. grid.set(topX + i, topY, wall)
  131. grid.set(topX + i, topY + sizeY - 1, wall)
  132. # Draw the left and right walls
  133. for j in range(0, sizeY):
  134. grid.set(topX, topY + j, wall)
  135. grid.set(topX + sizeX - 1, topY + j, wall)
  136. # If this isn't the first room, place the entry door
  137. if idx > 0:
  138. # Pick a door color different from the previous one
  139. doorColors = set(COLORS.keys())
  140. if prevDoorColor:
  141. doorColors.remove(prevDoorColor)
  142. doorColor = self._randElem(doorColors)
  143. entryDoor = Door(doorColor)
  144. grid.set(*room.entryDoorPos, entryDoor)
  145. prevDoorColor = doorColor
  146. prevRoom = roomList[idx-1]
  147. prevRoom.exitDoorPos = entryDoorPos
  148. # Place the final goal
  149. while True:
  150. goalX = self._randInt(topX + 1, topX + sizeX - 1)
  151. goalY = self._randInt(topY + 1, topY + sizeY - 1)
  152. # Make sure the goal doesn't overlap with the agent
  153. if (goalX, goalY) != self.startPos:
  154. grid.set(goalX, goalY, Goal())
  155. break
  156. return grid
  157. def _placeRoom(
  158. self,
  159. numLeft,
  160. roomList,
  161. minSz,
  162. maxSz,
  163. entryDoorWall,
  164. entryDoorPos
  165. ):
  166. # Choose the room size randomly
  167. sizeX = self._randInt(minSz, maxSz+1)
  168. sizeY = self._randInt(minSz, maxSz+1)
  169. # The first room will be at the door position
  170. if len(roomList) == 0:
  171. topX, topY = entryDoorPos
  172. # Entry on the right
  173. elif entryDoorWall == 0:
  174. topX = entryDoorPos[0] - sizeX + 1
  175. y = entryDoorPos[1]
  176. topY = self._randInt(y - sizeY + 2, y)
  177. # Entry wall on the south
  178. elif entryDoorWall == 1:
  179. x = entryDoorPos[0]
  180. topX = self._randInt(x - sizeX + 2, x)
  181. topY = entryDoorPos[1] - sizeY + 1
  182. # Entry wall on the left
  183. elif entryDoorWall == 2:
  184. topX = entryDoorPos[0]
  185. y = entryDoorPos[1]
  186. topY = self._randInt(y - sizeY + 2, y)
  187. # Entry wall on the top
  188. elif entryDoorWall == 3:
  189. x = entryDoorPos[0]
  190. topX = self._randInt(x - sizeX + 2, x)
  191. topY = entryDoorPos[1]
  192. else:
  193. assert False, entryDoorWall
  194. # If the room is out of the grid, can't place a room here
  195. if topX < 0 or topY < 0:
  196. return False
  197. if topX + sizeX > self.gridSize or topY + sizeY >= self.gridSize:
  198. return False
  199. # If the room intersects with previous rooms, can't place it here
  200. for room in roomList[:-1]:
  201. nonOverlap = \
  202. topX + sizeX < room.top[0] or \
  203. room.top[0] + room.size[0] <= topX or \
  204. topY + sizeY < room.top[1] or \
  205. room.top[1] + room.size[1] <= topY
  206. if not nonOverlap:
  207. return False
  208. # Add this room to the list
  209. roomList.append(Room(
  210. (topX, topY),
  211. (sizeX, sizeY),
  212. entryDoorPos,
  213. None
  214. ))
  215. # If this was the last room, stop
  216. if numLeft == 1:
  217. return True
  218. # Try placing the next room
  219. for i in range(0, 8):
  220. # Pick which wall to place the out door on
  221. wallSet = set((0, 1, 2, 3))
  222. wallSet.remove(entryDoorWall)
  223. exitDoorWall = self._randElem(wallSet)
  224. nextEntryWall = (exitDoorWall + 2) % 4
  225. # Pick the exit door position
  226. # Exit on right wall
  227. if exitDoorWall == 0:
  228. exitDoorPos = (
  229. topX + sizeX - 1,
  230. topY + self._randInt(1, sizeY - 1)
  231. )
  232. # Exit on south wall
  233. elif exitDoorWall == 1:
  234. exitDoorPos = (
  235. topX + self._randInt(1, sizeX - 1),
  236. topY + sizeY - 1
  237. )
  238. # Exit on left wall
  239. elif exitDoorWall == 2:
  240. exitDoorPos = (
  241. topX,
  242. topY + self._randInt(1, sizeY - 1)
  243. )
  244. # Exit on north wall
  245. elif exitDoorWall == 3:
  246. exitDoorPos = (
  247. topX + self._randInt(1, sizeX - 1),
  248. topY
  249. )
  250. else:
  251. assert False
  252. # Recursively create the other rooms
  253. success = self._placeRoom(
  254. numLeft - 1,
  255. roomList=roomList,
  256. minSz=minSz,
  257. maxSz=maxSz,
  258. entryDoorWall=nextEntryWall,
  259. entryDoorPos=exitDoorPos
  260. )
  261. if success:
  262. break
  263. return True
  264. class MultiRoomEnvN6(MultiRoomEnv):
  265. def __init__(self):
  266. super(MultiRoomEnvN6, self).__init__(
  267. minNumRooms=6,
  268. maxNumRooms=6
  269. )
  270. register(
  271. id='MiniGrid-Multi-Room-N6-v0',
  272. entry_point='gym_minigrid.envs:MultiRoomEnvN6',
  273. reward_threshold=1000.0
  274. )
  275. class FetchEnv(MiniGridEnv):
  276. """
  277. Environment in which the agent has to fetch a random object
  278. named using English text strings
  279. """
  280. def __init__(
  281. self,
  282. size=8,
  283. numObjs=3):
  284. self.numObjs = numObjs
  285. super(FetchEnv, self).__init__(gridSize=size, maxSteps=5*size)
  286. def _genGrid(self, width, height):
  287. assert width == height
  288. gridSz = width
  289. # Create a grid surrounded by walls
  290. grid = Grid(width, height)
  291. for i in range(0, width):
  292. grid.set(i, 0, Wall())
  293. grid.set(i, height-1, Wall())
  294. for j in range(0, height):
  295. grid.set(0, j, Wall())
  296. grid.set(width-1, j, Wall())
  297. types = ['key', 'ball']
  298. colors = list(COLORS.keys())
  299. objs = []
  300. # For each object to be generated
  301. for i in range(0, self.numObjs):
  302. objType = self._randElem(types)
  303. objColor = self._randElem(colors)
  304. if objType == 'key':
  305. obj = Key(objColor)
  306. elif objType == 'ball':
  307. obj = Ball(objColor)
  308. while True:
  309. pos = (
  310. self._randInt(1, gridSz - 1),
  311. self._randInt(1, gridSz - 1)
  312. )
  313. if pos != self.startPos:
  314. grid.set(*pos, obj)
  315. break
  316. objs.append(obj)
  317. # Choose a random object to be picked up
  318. target = objs[self._randInt(0, len(objs))]
  319. self.targetType = target.type
  320. self.targetColor = target.color
  321. descStr = '%s %s' % (self.targetColor, self.targetType)
  322. # Generate the mission string
  323. idx = self._randInt(0, 5)
  324. if idx == 0:
  325. self.mission = 'get a %s' % descStr
  326. elif idx == 1:
  327. self.mission = 'go get a %s' % descStr
  328. elif idx == 2:
  329. self.mission = 'fetch a %s' % descStr
  330. elif idx == 3:
  331. self.mission = 'go fetch a %s' % descStr
  332. elif idx == 4:
  333. self.mission = 'you must fetch a %s' % descStr
  334. assert hasattr(self, 'mission')
  335. return grid
  336. def _reset(self):
  337. obs = MiniGridEnv._reset(self)
  338. obs = {
  339. 'image': obs,
  340. 'mission': self.mission,
  341. 'advice' : ''
  342. }
  343. return obs
  344. def _step(self, action):
  345. obs, reward, done, info = MiniGridEnv._step(self, action)
  346. if self.carrying:
  347. if self.carrying.color == self.targetColor and \
  348. self.carrying.type == self.targetType:
  349. reward = 1000 - self.stepCount
  350. done = True
  351. else:
  352. reward = -1000
  353. done = True
  354. obs = {
  355. 'image': obs,
  356. 'mission': self.mission,
  357. 'advice': ''
  358. }
  359. return obs, reward, done, info
  360. register(
  361. id='MiniGrid-Fetch-8x8-v0',
  362. entry_point='gym_minigrid.envs:FetchEnv',
  363. reward_threshold=900.0
  364. )