simple_envs.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. from gym.envs.registration import register
  2. from gym_minigrid.envs.minigrid_env import *
  3. class EmptyEnv(MiniGridEnv):
  4. """
  5. Empty grid environment, no obstacles, sparse reward
  6. """
  7. def __init__(self, size=8):
  8. super(EmptyEnv, self).__init__(gridSize=size, maxSteps=2 * size)
  9. class EmptyEnv6x6(EmptyEnv):
  10. def __init__(self):
  11. super(EmptyEnv6x6, self).__init__(size=6)
  12. register(
  13. id='MiniGrid-Empty-8x8-v0',
  14. entry_point='gym_minigrid.envs:EmptyEnv',
  15. reward_threshold=1000.0
  16. )
  17. register(
  18. id='MiniGrid-Empty-6x6-v0',
  19. entry_point='gym_minigrid.envs:EmptyEnv6x6',
  20. reward_threshold=1000.0
  21. )
  22. class DoorKeyEnv(MiniGridEnv):
  23. """
  24. Environment with a door and key, sparse reward
  25. """
  26. def __init__(self, size=8):
  27. super(DoorKeyEnv, self).__init__(gridSize=size, maxSteps=4 * size)
  28. def _genGrid(self, width, height):
  29. grid = super(DoorKeyEnv, self)._genGrid(width, height)
  30. assert width == height
  31. gridSz = width
  32. # Create a vertical splitting wall
  33. splitIdx = self._randInt(2, gridSz-3)
  34. for i in range(0, gridSz):
  35. grid.set(splitIdx, i, Wall())
  36. # Place a door in the wall
  37. doorIdx = self._randInt(1, gridSz-2)
  38. grid.set(splitIdx, doorIdx, Door('yellow'))
  39. # Place a key on the left side
  40. #keyIdx = self._randInt(1 + gridSz // 2, gridSz-2)
  41. keyIdx = gridSz-2
  42. grid.set(1, keyIdx, Key('yellow'))
  43. return grid
  44. class DoorKeyEnv16x16(DoorKeyEnv):
  45. def __init__(self):
  46. super(DoorKeyEnv16x16, self).__init__(size=16)
  47. register(
  48. id='-Door-Key-8x8-v0',
  49. entry_point='gym_minigrid.envs:DoorKeyEnv',
  50. reward_threshold=1000.0
  51. )
  52. register(
  53. id='-Door-Key-16x16-v0',
  54. entry_point='gym_minigrid.envs:DoorKeyEnv16x16',
  55. reward_threshold=1000.0
  56. )
  57. class Room:
  58. def __init__(self,
  59. top,
  60. size,
  61. entryDoorPos,
  62. exitDoorPos
  63. ):
  64. self.top = top
  65. self.size = size
  66. self.entryDoorPos = entryDoorPos
  67. self.exitDoorPos = exitDoorPos
  68. class MultiRoomEnv(MiniGridEnv):
  69. """
  70. Environment with multiple rooms (subgoals)
  71. """
  72. def __init__(self,
  73. minNumRooms,
  74. maxNumRooms,
  75. maxRoomSize=10
  76. ):
  77. assert minNumRooms > 0
  78. assert maxNumRooms >= minNumRooms
  79. assert maxRoomSize >= 4
  80. self.minNumRooms = minNumRooms
  81. self.maxNumRooms = maxNumRooms
  82. self.maxRoomSize = maxRoomSize
  83. self.rooms = []
  84. super(MultiRoomEnv, self).__init__(
  85. gridSize=25,
  86. maxSteps=self.maxNumRooms * 20
  87. )
  88. def _genGrid(self, width, height):
  89. roomList = []
  90. # Choose a random number of rooms to generate
  91. numRooms = self._randInt(self.minNumRooms, self.maxNumRooms+1)
  92. while len(roomList) < numRooms:
  93. curRoomList = []
  94. entryDoorPos = (
  95. self._randInt(0, width - 2),
  96. self._randInt(0, width - 2)
  97. )
  98. # Recursively place the rooms
  99. self._placeRoom(
  100. numRooms,
  101. roomList=curRoomList,
  102. minSz=4,
  103. maxSz=self.maxRoomSize,
  104. entryDoorWall=2,
  105. entryDoorPos=entryDoorPos
  106. )
  107. if len(curRoomList) > len(roomList):
  108. roomList = curRoomList
  109. # Store the list of rooms in this environment
  110. assert len(roomList) > 0
  111. self.rooms = roomList
  112. # Randomize the starting agent position and direction
  113. topX, topY = roomList[0].top
  114. sizeX, sizeY = roomList[0].size
  115. self.startPos = (
  116. self._randInt(topX + 1, topX + sizeX - 2),
  117. self._randInt(topY + 1, topY + sizeY - 2)
  118. )
  119. self.startDir = self._randInt(0, 4)
  120. # Create the grid
  121. grid = Grid(width, height)
  122. wall = Wall()
  123. prevDoorColor = None
  124. # For each room
  125. for idx, room in enumerate(roomList):
  126. topX, topY = room.top
  127. sizeX, sizeY = room.size
  128. # Draw the top and bottom walls
  129. for i in range(0, sizeX):
  130. grid.set(topX + i, topY, wall)
  131. grid.set(topX + i, topY + sizeY - 1, wall)
  132. # Draw the left and right walls
  133. for j in range(0, sizeY):
  134. grid.set(topX, topY + j, wall)
  135. grid.set(topX + sizeX - 1, topY + j, wall)
  136. # If this isn't the first room, place the entry door
  137. if idx > 0:
  138. # Pick a door color different from the previous one
  139. doorColors = set(COLORS.keys())
  140. if prevDoorColor:
  141. doorColors.remove(prevDoorColor)
  142. doorColor = self._randElem(doorColors)
  143. entryDoor = Door(doorColor)
  144. grid.set(*room.entryDoorPos, entryDoor)
  145. prevDoorColor = doorColor
  146. prevRoom = roomList[idx-1]
  147. prevRoom.exitDoorPos = room.entryDoorPos
  148. # Place the final goal
  149. while True:
  150. self.goalPos = (
  151. self._randInt(topX + 1, topX + sizeX - 1),
  152. self._randInt(topY + 1, topY + sizeY - 1)
  153. )
  154. # Make sure the goal doesn't overlap with the agent
  155. if self.goalPos != self.startPos:
  156. grid.set(*self.goalPos, Goal())
  157. break
  158. return grid
  159. def _placeRoom(
  160. self,
  161. numLeft,
  162. roomList,
  163. minSz,
  164. maxSz,
  165. entryDoorWall,
  166. entryDoorPos
  167. ):
  168. # Choose the room size randomly
  169. sizeX = self._randInt(minSz, maxSz+1)
  170. sizeY = self._randInt(minSz, maxSz+1)
  171. # The first room will be at the door position
  172. if len(roomList) == 0:
  173. topX, topY = entryDoorPos
  174. # Entry on the right
  175. elif entryDoorWall == 0:
  176. topX = entryDoorPos[0] - sizeX + 1
  177. y = entryDoorPos[1]
  178. topY = self._randInt(y - sizeY + 2, y)
  179. # Entry wall on the south
  180. elif entryDoorWall == 1:
  181. x = entryDoorPos[0]
  182. topX = self._randInt(x - sizeX + 2, x)
  183. topY = entryDoorPos[1] - sizeY + 1
  184. # Entry wall on the left
  185. elif entryDoorWall == 2:
  186. topX = entryDoorPos[0]
  187. y = entryDoorPos[1]
  188. topY = self._randInt(y - sizeY + 2, y)
  189. # Entry wall on the top
  190. elif entryDoorWall == 3:
  191. x = entryDoorPos[0]
  192. topX = self._randInt(x - sizeX + 2, x)
  193. topY = entryDoorPos[1]
  194. else:
  195. assert False, entryDoorWall
  196. # If the room is out of the grid, can't place a room here
  197. if topX < 0 or topY < 0:
  198. return False
  199. if topX + sizeX > self.gridSize or topY + sizeY >= self.gridSize:
  200. return False
  201. # If the room intersects with previous rooms, can't place it here
  202. for room in roomList[:-1]:
  203. nonOverlap = \
  204. topX + sizeX < room.top[0] or \
  205. room.top[0] + room.size[0] <= topX or \
  206. topY + sizeY < room.top[1] or \
  207. room.top[1] + room.size[1] <= topY
  208. if not nonOverlap:
  209. return False
  210. # Add this room to the list
  211. roomList.append(Room(
  212. (topX, topY),
  213. (sizeX, sizeY),
  214. entryDoorPos,
  215. None
  216. ))
  217. # If this was the last room, stop
  218. if numLeft == 1:
  219. return True
  220. # Try placing the next room
  221. for i in range(0, 8):
  222. # Pick which wall to place the out door on
  223. wallSet = set((0, 1, 2, 3))
  224. wallSet.remove(entryDoorWall)
  225. exitDoorWall = self._randElem(wallSet)
  226. nextEntryWall = (exitDoorWall + 2) % 4
  227. # Pick the exit door position
  228. # Exit on right wall
  229. if exitDoorWall == 0:
  230. exitDoorPos = (
  231. topX + sizeX - 1,
  232. topY + self._randInt(1, sizeY - 1)
  233. )
  234. # Exit on south wall
  235. elif exitDoorWall == 1:
  236. exitDoorPos = (
  237. topX + self._randInt(1, sizeX - 1),
  238. topY + sizeY - 1
  239. )
  240. # Exit on left wall
  241. elif exitDoorWall == 2:
  242. exitDoorPos = (
  243. topX,
  244. topY + self._randInt(1, sizeY - 1)
  245. )
  246. # Exit on north wall
  247. elif exitDoorWall == 3:
  248. exitDoorPos = (
  249. topX + self._randInt(1, sizeX - 1),
  250. topY
  251. )
  252. else:
  253. assert False
  254. # Recursively create the other rooms
  255. success = self._placeRoom(
  256. numLeft - 1,
  257. roomList=roomList,
  258. minSz=minSz,
  259. maxSz=maxSz,
  260. entryDoorWall=nextEntryWall,
  261. entryDoorPos=exitDoorPos
  262. )
  263. if success:
  264. break
  265. return True
  266. class MultiRoomEnvN6(MultiRoomEnv):
  267. def __init__(self):
  268. super(MultiRoomEnvN6, self).__init__(
  269. minNumRooms=6,
  270. maxNumRooms=6
  271. )
  272. register(
  273. id='MiniGrid-Multi-Room-N6-v0',
  274. entry_point='gym_minigrid.envs:MultiRoomEnvN6',
  275. reward_threshold=1000.0
  276. )
  277. class FetchEnv(MiniGridEnv):
  278. """
  279. Environment in which the agent has to fetch a random object
  280. named using English text strings
  281. """
  282. def __init__(
  283. self,
  284. size=8,
  285. numObjs=3):
  286. self.numObjs = numObjs
  287. super(FetchEnv, self).__init__(gridSize=size, maxSteps=5*size)
  288. def _genGrid(self, width, height):
  289. assert width == height
  290. gridSz = width
  291. # Create a grid surrounded by walls
  292. grid = Grid(width, height)
  293. for i in range(0, width):
  294. grid.set(i, 0, Wall())
  295. grid.set(i, height-1, Wall())
  296. for j in range(0, height):
  297. grid.set(0, j, Wall())
  298. grid.set(width-1, j, Wall())
  299. types = ['key', 'ball']
  300. colors = list(COLORS.keys())
  301. objs = []
  302. # For each object to be generated
  303. for i in range(0, self.numObjs):
  304. objType = self._randElem(types)
  305. objColor = self._randElem(colors)
  306. if objType == 'key':
  307. obj = Key(objColor)
  308. elif objType == 'ball':
  309. obj = Ball(objColor)
  310. while True:
  311. pos = (
  312. self._randInt(1, gridSz - 1),
  313. self._randInt(1, gridSz - 1)
  314. )
  315. if pos != self.startPos:
  316. grid.set(*pos, obj)
  317. break
  318. objs.append(obj)
  319. # Choose a random object to be picked up
  320. target = objs[self._randInt(0, len(objs))]
  321. self.targetType = target.type
  322. self.targetColor = target.color
  323. descStr = '%s %s' % (self.targetColor, self.targetType)
  324. # Generate the mission string
  325. idx = self._randInt(0, 5)
  326. if idx == 0:
  327. self.mission = 'get a %s' % descStr
  328. elif idx == 1:
  329. self.mission = 'go get a %s' % descStr
  330. elif idx == 2:
  331. self.mission = 'fetch a %s' % descStr
  332. elif idx == 3:
  333. self.mission = 'go fetch a %s' % descStr
  334. elif idx == 4:
  335. self.mission = 'you must fetch a %s' % descStr
  336. assert hasattr(self, 'mission')
  337. return grid
  338. def _reset(self):
  339. obs = MiniGridEnv._reset(self)
  340. obs = {
  341. 'image': obs,
  342. 'mission': self.mission,
  343. 'advice' : ''
  344. }
  345. return obs
  346. def _step(self, action):
  347. obs, reward, done, info = MiniGridEnv._step(self, action)
  348. if self.carrying:
  349. if self.carrying.color == self.targetColor and \
  350. self.carrying.type == self.targetType:
  351. reward = 1000 - self.stepCount
  352. done = True
  353. else:
  354. reward = -1000
  355. done = True
  356. obs = {
  357. 'image': obs,
  358. 'mission': self.mission,
  359. 'advice': ''
  360. }
  361. return obs, reward, done, info
  362. register(
  363. id='MiniGrid-Fetch-8x8-v0',
  364. entry_point='gym_minigrid.envs:FetchEnv',
  365. reward_threshold=900.0
  366. )