minigrid.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323
  1. import hashlib
  2. import math
  3. import string
  4. from abc import abstractmethod
  5. from enum import IntEnum
  6. import gym
  7. import numpy as np
  8. from gym import spaces
  9. # Size in pixels of a tile in the full-scale human view
  10. from gym_minigrid.rendering import (
  11. downsample,
  12. fill_coords,
  13. highlight_img,
  14. point_in_circle,
  15. point_in_line,
  16. point_in_rect,
  17. point_in_triangle,
  18. rotate_fn,
  19. )
  20. from gym_minigrid.window import Window
  21. TILE_PIXELS = 32
  22. # Map of color names to RGB values
  23. COLORS = {
  24. "red": np.array([255, 0, 0]),
  25. "green": np.array([0, 255, 0]),
  26. "blue": np.array([0, 0, 255]),
  27. "purple": np.array([112, 39, 195]),
  28. "yellow": np.array([255, 255, 0]),
  29. "grey": np.array([100, 100, 100]),
  30. }
  31. COLOR_NAMES = sorted(list(COLORS.keys()))
  32. # Used to map colors to integers
  33. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  34. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  35. # Map of object type to integers
  36. OBJECT_TO_IDX = {
  37. "unseen": 0,
  38. "empty": 1,
  39. "wall": 2,
  40. "floor": 3,
  41. "door": 4,
  42. "key": 5,
  43. "ball": 6,
  44. "box": 7,
  45. "goal": 8,
  46. "lava": 9,
  47. "agent": 10,
  48. }
  49. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  50. # Map of state names to integers
  51. STATE_TO_IDX = {
  52. "open": 0,
  53. "closed": 1,
  54. "locked": 2,
  55. }
  56. # Map of agent direction indices to vectors
  57. DIR_TO_VEC = [
  58. # Pointing right (positive X)
  59. np.array((1, 0)),
  60. # Down (positive Y)
  61. np.array((0, 1)),
  62. # Pointing left (negative X)
  63. np.array((-1, 0)),
  64. # Up (negative Y)
  65. np.array((0, -1)),
  66. ]
  67. class WorldObj:
  68. """
  69. Base class for grid world objects
  70. """
  71. def __init__(self, type, color):
  72. assert type in OBJECT_TO_IDX, type
  73. assert color in COLOR_TO_IDX, color
  74. self.type = type
  75. self.color = color
  76. self.contains = None
  77. # Initial position of the object
  78. self.init_pos = None
  79. # Current position of the object
  80. self.cur_pos = None
  81. def can_overlap(self):
  82. """Can the agent overlap with this?"""
  83. return False
  84. def can_pickup(self):
  85. """Can the agent pick this up?"""
  86. return False
  87. def can_contain(self):
  88. """Can this contain another object?"""
  89. return False
  90. def see_behind(self):
  91. """Can the agent see behind this object?"""
  92. return True
  93. def toggle(self, env, pos):
  94. """Method to trigger/toggle an action this object performs"""
  95. return False
  96. def encode(self):
  97. """Encode the a description of this object as a 3-tuple of integers"""
  98. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  99. @staticmethod
  100. def decode(type_idx, color_idx, state):
  101. """Create an object from a 3-tuple state description"""
  102. obj_type = IDX_TO_OBJECT[type_idx]
  103. color = IDX_TO_COLOR[color_idx]
  104. if obj_type == "empty" or obj_type == "unseen":
  105. return None
  106. # State, 0: open, 1: closed, 2: locked
  107. is_open = state == 0
  108. is_locked = state == 2
  109. if obj_type == "wall":
  110. v = Wall(color)
  111. elif obj_type == "floor":
  112. v = Floor(color)
  113. elif obj_type == "ball":
  114. v = Ball(color)
  115. elif obj_type == "key":
  116. v = Key(color)
  117. elif obj_type == "box":
  118. v = Box(color)
  119. elif obj_type == "door":
  120. v = Door(color, is_open, is_locked)
  121. elif obj_type == "goal":
  122. v = Goal()
  123. elif obj_type == "lava":
  124. v = Lava()
  125. else:
  126. assert False, "unknown object type in decode '%s'" % obj_type
  127. return v
  128. def render(self, r):
  129. """Draw this object with the given renderer"""
  130. raise NotImplementedError
  131. class Goal(WorldObj):
  132. def __init__(self):
  133. super().__init__("goal", "green")
  134. def can_overlap(self):
  135. return True
  136. def render(self, img):
  137. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  138. class Floor(WorldObj):
  139. """
  140. Colored floor tile the agent can walk over
  141. """
  142. def __init__(self, color="blue"):
  143. super().__init__("floor", color)
  144. def can_overlap(self):
  145. return True
  146. def render(self, img):
  147. # Give the floor a pale color
  148. color = COLORS[self.color] / 2
  149. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  150. class Lava(WorldObj):
  151. def __init__(self):
  152. super().__init__("lava", "red")
  153. def can_overlap(self):
  154. return True
  155. def render(self, img):
  156. c = (255, 128, 0)
  157. # Background color
  158. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  159. # Little waves
  160. for i in range(3):
  161. ylo = 0.3 + 0.2 * i
  162. yhi = 0.4 + 0.2 * i
  163. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  164. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  165. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  166. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  167. class Wall(WorldObj):
  168. def __init__(self, color="grey"):
  169. super().__init__("wall", color)
  170. def see_behind(self):
  171. return False
  172. def render(self, img):
  173. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  174. class Door(WorldObj):
  175. def __init__(self, color, is_open=False, is_locked=False):
  176. super().__init__("door", color)
  177. self.is_open = is_open
  178. self.is_locked = is_locked
  179. def can_overlap(self):
  180. """The agent can only walk over this cell when the door is open"""
  181. return self.is_open
  182. def see_behind(self):
  183. return self.is_open
  184. def toggle(self, env, pos):
  185. # If the player has the right key to open the door
  186. if self.is_locked:
  187. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  188. self.is_locked = False
  189. self.is_open = True
  190. return True
  191. return False
  192. self.is_open = not self.is_open
  193. return True
  194. def encode(self):
  195. """Encode the a description of this object as a 3-tuple of integers"""
  196. # State, 0: open, 1: closed, 2: locked
  197. if self.is_open:
  198. state = 0
  199. elif self.is_locked:
  200. state = 2
  201. # if door is closed and unlocked
  202. elif not self.is_open:
  203. state = 1
  204. else:
  205. raise ValueError(
  206. "There is no possible state encoding for the state:\n -Door Open: {}\n -Door Closed: {}\n -Door Locked: {}".format(
  207. self.is_open, not self.is_open, self.is_locked
  208. )
  209. )
  210. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  211. def render(self, img):
  212. c = COLORS[self.color]
  213. if self.is_open:
  214. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  215. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  216. return
  217. # Door frame and door
  218. if self.is_locked:
  219. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  220. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  221. # Draw key slot
  222. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  223. else:
  224. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  225. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  226. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  227. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  228. # Draw door handle
  229. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  230. class Key(WorldObj):
  231. def __init__(self, color="blue"):
  232. super().__init__("key", color)
  233. def can_pickup(self):
  234. return True
  235. def render(self, img):
  236. c = COLORS[self.color]
  237. # Vertical quad
  238. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  239. # Teeth
  240. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  241. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  242. # Ring
  243. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  244. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  245. class Ball(WorldObj):
  246. def __init__(self, color="blue"):
  247. super().__init__("ball", color)
  248. def can_pickup(self):
  249. return True
  250. def render(self, img):
  251. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  252. class Box(WorldObj):
  253. def __init__(self, color, contains=None):
  254. super().__init__("box", color)
  255. self.contains = contains
  256. def can_pickup(self):
  257. return True
  258. def render(self, img):
  259. c = COLORS[self.color]
  260. # Outline
  261. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  262. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  263. # Horizontal slit
  264. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  265. def toggle(self, env, pos):
  266. # Replace the box by its contents
  267. env.grid.set(pos[0], pos[1], self.contains)
  268. return True
  269. class Grid:
  270. """
  271. Represent a grid and operations on it
  272. """
  273. # Static cache of pre-renderer tiles
  274. tile_cache = {}
  275. def __init__(self, width, height):
  276. assert width >= 3
  277. assert height >= 3
  278. self.width = width
  279. self.height = height
  280. self.grid = [None] * width * height
  281. def __contains__(self, key):
  282. if isinstance(key, WorldObj):
  283. for e in self.grid:
  284. if e is key:
  285. return True
  286. elif isinstance(key, tuple):
  287. for e in self.grid:
  288. if e is None:
  289. continue
  290. if (e.color, e.type) == key:
  291. return True
  292. if key[0] is None and key[1] == e.type:
  293. return True
  294. return False
  295. def __eq__(self, other):
  296. grid1 = self.encode()
  297. grid2 = other.encode()
  298. return np.array_equal(grid2, grid1)
  299. def __ne__(self, other):
  300. return not self == other
  301. def copy(self):
  302. from copy import deepcopy
  303. return deepcopy(self)
  304. def set(self, i, j, v):
  305. assert i >= 0 and i < self.width
  306. assert j >= 0 and j < self.height
  307. self.grid[j * self.width + i] = v
  308. def get(self, i, j):
  309. assert i >= 0 and i < self.width
  310. assert j >= 0 and j < self.height
  311. return self.grid[j * self.width + i]
  312. def horz_wall(self, x, y, length=None, obj_type=Wall):
  313. if length is None:
  314. length = self.width - x
  315. for i in range(0, length):
  316. self.set(x + i, y, obj_type())
  317. def vert_wall(self, x, y, length=None, obj_type=Wall):
  318. if length is None:
  319. length = self.height - y
  320. for j in range(0, length):
  321. self.set(x, y + j, obj_type())
  322. def wall_rect(self, x, y, w, h):
  323. self.horz_wall(x, y, w)
  324. self.horz_wall(x, y + h - 1, w)
  325. self.vert_wall(x, y, h)
  326. self.vert_wall(x + w - 1, y, h)
  327. def rotate_left(self):
  328. """
  329. Rotate the grid to the left (counter-clockwise)
  330. """
  331. grid = Grid(self.height, self.width)
  332. for i in range(self.width):
  333. for j in range(self.height):
  334. v = self.get(i, j)
  335. grid.set(j, grid.height - 1 - i, v)
  336. return grid
  337. def slice(self, topX, topY, width, height):
  338. """
  339. Get a subset of the grid
  340. """
  341. grid = Grid(width, height)
  342. for j in range(0, height):
  343. for i in range(0, width):
  344. x = topX + i
  345. y = topY + j
  346. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  347. v = self.get(x, y)
  348. else:
  349. v = Wall()
  350. grid.set(i, j, v)
  351. return grid
  352. @classmethod
  353. def render_tile(
  354. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  355. ):
  356. """
  357. Render a tile and cache the result
  358. """
  359. # Hash map lookup key for the cache
  360. key = (agent_dir, highlight, tile_size)
  361. key = obj.encode() + key if obj else key
  362. if key in cls.tile_cache:
  363. return cls.tile_cache[key]
  364. img = np.zeros(
  365. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  366. )
  367. # Draw the grid lines (top and left edges)
  368. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  369. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  370. if obj is not None:
  371. obj.render(img)
  372. # Overlay the agent on top
  373. if agent_dir is not None:
  374. tri_fn = point_in_triangle(
  375. (0.12, 0.19),
  376. (0.87, 0.50),
  377. (0.12, 0.81),
  378. )
  379. # Rotate the agent based on its direction
  380. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  381. fill_coords(img, tri_fn, (255, 0, 0))
  382. # Highlight the cell if needed
  383. if highlight:
  384. highlight_img(img)
  385. # Downsample the image to perform supersampling/anti-aliasing
  386. img = downsample(img, subdivs)
  387. # Cache the rendered tile
  388. cls.tile_cache[key] = img
  389. return img
  390. def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None):
  391. """
  392. Render this grid at a given scale
  393. :param r: target renderer object
  394. :param tile_size: tile size in pixels
  395. """
  396. if highlight_mask is None:
  397. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  398. # Compute the total grid size
  399. width_px = self.width * tile_size
  400. height_px = self.height * tile_size
  401. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  402. # Render the grid
  403. for j in range(0, self.height):
  404. for i in range(0, self.width):
  405. cell = self.get(i, j)
  406. agent_here = np.array_equal(agent_pos, (i, j))
  407. tile_img = Grid.render_tile(
  408. cell,
  409. agent_dir=agent_dir if agent_here else None,
  410. highlight=highlight_mask[i, j],
  411. tile_size=tile_size,
  412. )
  413. ymin = j * tile_size
  414. ymax = (j + 1) * tile_size
  415. xmin = i * tile_size
  416. xmax = (i + 1) * tile_size
  417. img[ymin:ymax, xmin:xmax, :] = tile_img
  418. return img
  419. def encode(self, vis_mask=None):
  420. """
  421. Produce a compact numpy encoding of the grid
  422. """
  423. if vis_mask is None:
  424. vis_mask = np.ones((self.width, self.height), dtype=bool)
  425. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  426. for i in range(self.width):
  427. for j in range(self.height):
  428. if vis_mask[i, j]:
  429. v = self.get(i, j)
  430. if v is None:
  431. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  432. array[i, j, 1] = 0
  433. array[i, j, 2] = 0
  434. else:
  435. array[i, j, :] = v.encode()
  436. return array
  437. @staticmethod
  438. def decode(array):
  439. """
  440. Decode an array grid encoding back into a grid
  441. """
  442. width, height, channels = array.shape
  443. assert channels == 3
  444. vis_mask = np.ones(shape=(width, height), dtype=bool)
  445. grid = Grid(width, height)
  446. for i in range(width):
  447. for j in range(height):
  448. type_idx, color_idx, state = array[i, j]
  449. v = WorldObj.decode(type_idx, color_idx, state)
  450. grid.set(i, j, v)
  451. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  452. return grid, vis_mask
  453. def process_vis(self, agent_pos):
  454. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  455. mask[agent_pos[0], agent_pos[1]] = True
  456. for j in reversed(range(0, self.height)):
  457. for i in range(0, self.width - 1):
  458. if not mask[i, j]:
  459. continue
  460. cell = self.get(i, j)
  461. if cell and not cell.see_behind():
  462. continue
  463. mask[i + 1, j] = True
  464. if j > 0:
  465. mask[i + 1, j - 1] = True
  466. mask[i, j - 1] = True
  467. for i in reversed(range(1, self.width)):
  468. if not mask[i, j]:
  469. continue
  470. cell = self.get(i, j)
  471. if cell and not cell.see_behind():
  472. continue
  473. mask[i - 1, j] = True
  474. if j > 0:
  475. mask[i - 1, j - 1] = True
  476. mask[i, j - 1] = True
  477. for j in range(0, self.height):
  478. for i in range(0, self.width):
  479. if not mask[i, j]:
  480. self.set(i, j, None)
  481. return mask
  482. class MiniGridEnv(gym.Env):
  483. """
  484. 2D grid world game environment
  485. """
  486. metadata = {
  487. # Deprecated: use 'render_modes' instead
  488. "render.modes": ["human", "rgb_array"],
  489. "video.frames_per_second": 10, # Deprecated: use 'render_fps' instead
  490. "render_modes": ["human", "rgb_array"],
  491. "render_fps": 10,
  492. }
  493. # Enumeration of possible actions
  494. class Actions(IntEnum):
  495. # Turn left, turn right, move forward
  496. left = 0
  497. right = 1
  498. forward = 2
  499. # Pick up an object
  500. pickup = 3
  501. # Drop an object
  502. drop = 4
  503. # Toggle/activate an object
  504. toggle = 5
  505. # Done completing task
  506. done = 6
  507. def __init__(
  508. self,
  509. grid_size: int = None,
  510. width: int = None,
  511. height: int = None,
  512. max_steps: int = 100,
  513. see_through_walls: bool = False,
  514. agent_view_size: int = 7,
  515. render_mode: str = None,
  516. **kwargs,
  517. ):
  518. # Can't set both grid_size and width/height
  519. if grid_size:
  520. assert width is None and height is None
  521. width = grid_size
  522. height = grid_size
  523. # Action enumeration for this environment
  524. self.actions = MiniGridEnv.Actions
  525. # Actions are discrete integer values
  526. self.action_space = spaces.Discrete(len(self.actions))
  527. # Number of cells (width and height) in the agent view
  528. assert agent_view_size % 2 == 1
  529. assert agent_view_size >= 3
  530. self.agent_view_size = agent_view_size
  531. # Observations are dictionaries containing an
  532. # encoding of the grid and a textual 'mission' string
  533. self.observation_space = spaces.Box(
  534. low=0,
  535. high=255,
  536. shape=(self.agent_view_size, self.agent_view_size, 3),
  537. dtype="uint8",
  538. )
  539. self.observation_space = spaces.Dict(
  540. {
  541. "image": self.observation_space,
  542. "direction": spaces.Discrete(4),
  543. "mission": spaces.Text(
  544. max_length=200,
  545. charset=string.ascii_letters + string.digits + " .,!-",
  546. ),
  547. }
  548. )
  549. # render mode
  550. self.render_mode = render_mode
  551. # Range of possible rewards
  552. self.reward_range = (0, 1)
  553. self.window: Window = None
  554. # Environment configuration
  555. self.width = width
  556. self.height = height
  557. self.max_steps = max_steps
  558. self.see_through_walls = see_through_walls
  559. # Current position and direction of the agent
  560. self.agent_pos: np.ndarray = None
  561. self.agent_dir: int = None
  562. # Current grid and mission and carryinh
  563. self.grid = Grid(width, height)
  564. self.mission = ""
  565. self.carrying = None
  566. # Initialize the state
  567. self.reset()
  568. def reset(self, *, seed=None, return_info=False, options=None):
  569. super().reset(seed=seed)
  570. # Reinitialize episode-specific variables
  571. self.agent_pos = (-1, -1)
  572. self.agent_dir = -1
  573. # Generate a new random grid at the start of each episode
  574. self._gen_grid(self.width, self.height)
  575. # These fields should be defined by _gen_grid
  576. assert (
  577. self.agent_pos >= (0, 0)
  578. if isinstance(self.agent_pos, tuple)
  579. else all(self.agent_pos >= 0) and self.agent_dir >= 0
  580. )
  581. # Check that the agent doesn't overlap with an object
  582. start_cell = self.grid.get(*self.agent_pos)
  583. assert start_cell is None or start_cell.can_overlap()
  584. # Item picked up, being carried, initially nothing
  585. self.carrying = None
  586. # Step count since episode start
  587. self.step_count = 0
  588. # Return first observation
  589. obs = self.gen_obs()
  590. return obs
  591. def hash(self, size=16):
  592. """Compute a hash that uniquely identifies the current state of the environment.
  593. :param size: Size of the hashing
  594. """
  595. sample_hash = hashlib.sha256()
  596. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  597. for item in to_encode:
  598. sample_hash.update(str(item).encode("utf8"))
  599. return sample_hash.hexdigest()[:size]
  600. @property
  601. def steps_remaining(self):
  602. return self.max_steps - self.step_count
  603. def __str__(self):
  604. """
  605. Produce a pretty string of the environment's grid along with the agent.
  606. A grid cell is represented by 2-character string, the first one for
  607. the object and the second one for the color.
  608. """
  609. # Map of object types to short string
  610. OBJECT_TO_STR = {
  611. "wall": "W",
  612. "floor": "F",
  613. "door": "D",
  614. "key": "K",
  615. "ball": "A",
  616. "box": "B",
  617. "goal": "G",
  618. "lava": "V",
  619. }
  620. # Map agent's direction to short string
  621. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  622. str = ""
  623. for j in range(self.grid.height):
  624. for i in range(self.grid.width):
  625. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  626. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  627. continue
  628. c = self.grid.get(i, j)
  629. if c is None:
  630. str += " "
  631. continue
  632. if c.type == "door":
  633. if c.is_open:
  634. str += "__"
  635. elif c.is_locked:
  636. str += "L" + c.color[0].upper()
  637. else:
  638. str += "D" + c.color[0].upper()
  639. continue
  640. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  641. if j < self.grid.height - 1:
  642. str += "\n"
  643. return str
  644. @abstractmethod
  645. def _gen_grid(self, width, height):
  646. pass
  647. def _reward(self):
  648. """
  649. Compute the reward to be given upon success
  650. """
  651. return 1 - 0.9 * (self.step_count / self.max_steps)
  652. def _rand_int(self, low, high):
  653. """
  654. Generate random integer in [low,high[
  655. """
  656. return self.np_random.integers(low, high)
  657. def _rand_float(self, low, high):
  658. """
  659. Generate random float in [low,high[
  660. """
  661. return self.np_random.uniform(low, high)
  662. def _rand_bool(self):
  663. """
  664. Generate random boolean value
  665. """
  666. return self.np_random.integers(0, 2) == 0
  667. def _rand_elem(self, iterable):
  668. """
  669. Pick a random element in a list
  670. """
  671. lst = list(iterable)
  672. idx = self._rand_int(0, len(lst))
  673. return lst[idx]
  674. def _rand_subset(self, iterable, num_elems):
  675. """
  676. Sample a random subset of distinct elements of a list
  677. """
  678. lst = list(iterable)
  679. assert num_elems <= len(lst)
  680. out = []
  681. while len(out) < num_elems:
  682. elem = self._rand_elem(lst)
  683. lst.remove(elem)
  684. out.append(elem)
  685. return out
  686. def _rand_color(self):
  687. """
  688. Generate a random color name (string)
  689. """
  690. return self._rand_elem(COLOR_NAMES)
  691. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  692. """
  693. Generate a random (x,y) position tuple
  694. """
  695. return (
  696. self.np_random.integers(xLow, xHigh),
  697. self.np_random.integers(yLow, yHigh),
  698. )
  699. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  700. """
  701. Place an object at an empty position in the grid
  702. :param top: top-left position of the rectangle where to place
  703. :param size: size of the rectangle where to place
  704. :param reject_fn: function to filter out potential positions
  705. """
  706. if top is None:
  707. top = (0, 0)
  708. else:
  709. top = (max(top[0], 0), max(top[1], 0))
  710. if size is None:
  711. size = (self.grid.width, self.grid.height)
  712. num_tries = 0
  713. while True:
  714. # This is to handle with rare cases where rejection sampling
  715. # gets stuck in an infinite loop
  716. if num_tries > max_tries:
  717. raise RecursionError("rejection sampling failed in place_obj")
  718. num_tries += 1
  719. pos = np.array(
  720. (
  721. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  722. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  723. )
  724. )
  725. pos = tuple(pos)
  726. # Don't place the object on top of another object
  727. if self.grid.get(*pos) is not None:
  728. continue
  729. # Don't place the object where the agent is
  730. if np.array_equal(pos, self.agent_pos):
  731. continue
  732. # Check if there is a filtering criterion
  733. if reject_fn and reject_fn(self, pos):
  734. continue
  735. break
  736. self.grid.set(pos[0], pos[1], obj)
  737. if obj is not None:
  738. obj.init_pos = pos
  739. obj.cur_pos = pos
  740. return pos
  741. def put_obj(self, obj, i, j):
  742. """
  743. Put an object at a specific position in the grid
  744. """
  745. self.grid.set(i, j, obj)
  746. obj.init_pos = (i, j)
  747. obj.cur_pos = (i, j)
  748. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  749. """
  750. Set the agent's starting point at an empty position in the grid
  751. """
  752. self.agent_pos = (-1, -1)
  753. pos = self.place_obj(None, top, size, max_tries=max_tries)
  754. self.agent_pos = pos
  755. if rand_dir:
  756. self.agent_dir = self._rand_int(0, 4)
  757. return pos
  758. @property
  759. def dir_vec(self):
  760. """
  761. Get the direction vector for the agent, pointing in the direction
  762. of forward movement.
  763. """
  764. assert self.agent_dir >= 0 and self.agent_dir < 4
  765. return DIR_TO_VEC[self.agent_dir]
  766. @property
  767. def right_vec(self):
  768. """
  769. Get the vector pointing to the right of the agent.
  770. """
  771. dx, dy = self.dir_vec
  772. return np.array((-dy, dx))
  773. @property
  774. def front_pos(self):
  775. """
  776. Get the position of the cell that is right in front of the agent
  777. """
  778. return self.agent_pos + self.dir_vec
  779. def get_view_coords(self, i, j):
  780. """
  781. Translate and rotate absolute grid coordinates (i, j) into the
  782. agent's partially observable view (sub-grid). Note that the resulting
  783. coordinates may be negative or outside of the agent's view size.
  784. """
  785. ax, ay = self.agent_pos
  786. dx, dy = self.dir_vec
  787. rx, ry = self.right_vec
  788. # Compute the absolute coordinates of the top-left view corner
  789. sz = self.agent_view_size
  790. hs = self.agent_view_size // 2
  791. tx = ax + (dx * (sz - 1)) - (rx * hs)
  792. ty = ay + (dy * (sz - 1)) - (ry * hs)
  793. lx = i - tx
  794. ly = j - ty
  795. # Project the coordinates of the object relative to the top-left
  796. # corner onto the agent's own coordinate system
  797. vx = rx * lx + ry * ly
  798. vy = -(dx * lx + dy * ly)
  799. return vx, vy
  def get_view_exts(self, agent_view_size=None):
      """
      Extents of the square of tiles visible to the agent.

      agent_view_size: side of the view square; when None (or otherwise
          falsy) self.agent_view_size is used.

      Returns (topX, topY, botX, botY). The bottom extents are exclusive:
      botX = topX + size and botY = topY + size.

      Raises AssertionError("invalid agent direction") when agent_dir is
      not one of 0 (right), 1 (down), 2 (left), 3 (up).
      """
      size = agent_view_size or self.agent_view_size
      half = size // 2
      back = size - 1
      # Offset from the agent to the top-left tile, per heading. The agent
      # sits on the edge of the square opposite its heading, centred along
      # that edge.
      offsets = {
          0: (0, -half),  # facing right
          1: (-half, 0),  # facing down
          2: (-back, -half),  # facing left
          3: (-half, -back),  # facing up
      }
      offset = offsets.get(self.agent_dir)
      if offset is None:
          raise AssertionError("invalid agent direction")
      top_x = self.agent_pos[0] + offset[0]
      top_y = self.agent_pos[1] + offset[1]
      return (top_x, top_y, top_x + size, top_y + size)
  def relative_coords(self, x, y):
      """
      View coordinates of world cell (x, y) if it lies in the agent's view square.

      Returns (vx, vy) from get_view_coords when both lie in
      [0, agent_view_size). Otherwise returns None. Occlusion is not
      considered here; this is a pure geometric test.
      """
      view_x, view_y = self.get_view_coords(x, y)
      size = self.agent_view_size
      outside = view_x < 0 or view_y < 0 or view_x >= size or view_y >= size
      return None if outside else (view_x, view_y)
  def in_view(self, x, y):
      """
      Whether world cell (x, y) falls inside the agent's view square.

      This is a geometric check only (see relative_coords). Walls and
      other occluders are not taken into account.
      """
      coords = self.relative_coords(x, y)
      if coords is None:
          return False
      return True
  def agent_sees(self, x, y):
      """
      Whether the non-empty world cell (x, y) is actually visible to the agent.

      Returns False if (x, y) is outside the view square. Otherwise the
      method regenerates the observation and decodes it. It then returns
      True only if the observed cell at the matching view coordinates
      holds an object of the same type as the world cell.

      When (x, y) is in view, the world cell is asserted to be non-empty.
      """
      view_coords = self.relative_coords(x, y)
      if view_coords is None:
          return False
      view_x, view_y = view_coords

      observed_grid, _ = Grid.decode(self.gen_obs()["image"])
      seen = observed_grid.get(view_x, view_y)
      actual = self.grid.get(x, y)
      assert actual is not None
      # Cells hidden by occlusion decode to None, so they count as unseen.
      if seen is None:
          return False
      return seen.type == actual.type
  def step(self, action):
      """
      Advance the environment by one action.

      Increments step_count and applies `action`, which must be one of
      self.actions: left, right, forward, pickup, drop, toggle or done.
      Actions act on the cell in front of the agent. The method then
      builds a fresh observation.

      Returns (obs, reward, terminated, truncated, info), where info is
      an empty dict.
      - terminated is True when a forward action targets a "goal" or
        "lava" cell. Only "goal" sets reward to self._reward().
      - truncated is True once step_count reaches max_steps.

      Raises ValueError for any other action.
      """
      self.step_count += 1
      reward = 0
      terminated = False
      truncated = False

      ahead = self.front_pos
      target = self.grid.get(*ahead)
      acts = self.actions

      if action == acts.left:
          # Turn counter-clockwise, wrapping 0 -> 3.
          turned = self.agent_dir - 1
          self.agent_dir = turned + 4 if turned < 0 else turned
      elif action == acts.right:
          self.agent_dir = (self.agent_dir + 1) % 4
      elif action == acts.forward:
          if target is None or target.can_overlap():
              self.agent_pos = tuple(ahead)
          # Goal/lava end the episode whether or not the move happened.
          kind = None if target is None else target.type
          if kind == "goal":
              terminated = True
              reward = self._reward()
          if kind == "lava":
              terminated = True
      elif action == acts.pickup:
          if target and target.can_pickup() and self.carrying is None:
              self.carrying = target
              # (-1, -1): the carried object no longer has a grid position.
              target.cur_pos = np.array([-1, -1])
              self.grid.set(ahead[0], ahead[1], None)
      elif action == acts.drop:
          # Only drop into an empty cell.
          if not target and self.carrying:
              self.grid.set(ahead[0], ahead[1], self.carrying)
              self.carrying.cur_pos = ahead
              self.carrying = None
      elif action == acts.toggle:
          if target:
              target.toggle(self, ahead)
      elif action == acts.done:
          # Not used by default.
          pass
      else:
          raise ValueError(f"Unknown action: {action}")

      if self.step_count >= self.max_steps:
          truncated = True
      return self.gen_obs(), reward, terminated, truncated, {}
  def gen_obs_grid(self, agent_view_size=None):
      """
      Build the sub-grid the agent currently observes.

      agent_view_size: side of the square view; when None (or otherwise
          falsy) self.agent_view_size is used.

      Returns (grid, vis_mask):
      - grid is the world slice, rotated left (agent_dir + 1) times. The
        agent sits at (size // 2, size - 1), i.e. the middle of the bottom
        row.
      - vis_mask is the result of grid.process_vis when see_through_walls
        is False. Otherwise it is an all-True bool array of shape
        (grid.width, grid.height).

      The agent's own cell in the returned grid holds the carried object,
      or None when nothing is carried.
      """
      top_x, top_y, _, _ = self.get_view_exts(agent_view_size)
      size = agent_view_size or self.agent_view_size

      view = self.grid.slice(top_x, top_y, size, size)
      # Rotate so the agent's heading points toward the top row.
      for _ in range(self.agent_dir + 1):
          view = view.rotate_left()

      # Occlusion processing is skipped when walls are see-through,
      # since it incurs some performance cost.
      if self.see_through_walls:
          vis_mask = np.ones(shape=(view.width, view.height), dtype=bool)
      else:
          vis_mask = view.process_vis(agent_pos=(size // 2, size - 1))

      # Let the agent "see" what it carries by placing it in its own cell.
      view.set(view.width // 2, view.height - 1, self.carrying or None)
      return view, vis_mask
  def gen_obs(self):
      """
      Generate the agent's partially observable, low-resolution observation.

      Returns a dict with three keys:
      - "image": the gen_obs_grid() view encoded with its visibility mask
      - "direction": self.agent_dir, which acts as a compass
      - "mission": self.mission, the textual instruction
      """
      view_grid, visible = self.gen_obs_grid()
      return {
          "image": view_grid.encode(visible),
          "direction": self.agent_dir,
          "mission": self.mission,
      }
  def get_obs_render(self, obs, tile_size=TILE_PIXELS // 2):
      """
      Render an encoded agent observation for visualization.

      obs: an encoded partial view, i.e. the "image" entry of gen_obs().
      tile_size: pixel size of each rendered tile.

      Returns the image produced by Grid.render for the decoded view. The
      decoded visibility mask is passed as the highlight mask.
      """
      grid, vis_mask = Grid.decode(obs)
      # gen_obs_grid() puts the agent in the middle of the bottom row.
      # Deriving that cell from the decoded grid, rather than from
      # self.agent_view_size, keeps the marker correct for observations
      # generated with a non-default view size. For default observations
      # the two are identical.
      agent_pos = (grid.width // 2, grid.height - 1)
      return grid.render(
          tile_size,
          agent_pos=agent_pos,
          # The view is rotated so the agent faces the top row, which is
          # direction 3 ("up" in get_view_exts).
          agent_dir=3,
          highlight_mask=vis_mask,
      )
  def render(self, mode="human", close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid view.

      mode: "human" also shows the image in a Window, which is created on
          first use. Any other value only returns the image. When
          self.render_mode is set, it overrides this argument.
      close: if True, close the window (if any) and return None.
      highlight: if True, pass a mask of the world cells the agent can
          currently see to Grid.render as highlight_mask.
      tile_size: pixel size of each grid tile.

      Returns the rendered image, or None when close is True.
      """
      if self.render_mode is not None:
          mode = self.render_mode

      if close:
          if self.window:
              self.window.close()
              # Drop the stale reference so a later human render opens a
              # fresh window instead of drawing into the closed one.
              self.window = None
          return

      if mode == "human" and not self.window:
          self.window = Window("gym_minigrid")
          self.window.show(block=False)

      # Visibility of each cell, in the agent's view coordinates.
      _, vis_mask = self.gen_obs_grid()

      # World coordinates of view cell (0, 0): the top-left corner of the
      # agent's view area (far row, leftmost column).
      f_vec = self.dir_vec
      r_vec = self.right_vec
      top_left = (
          self.agent_pos
          + f_vec * (self.agent_view_size - 1)
          - r_vec * (self.agent_view_size // 2)
      )

      # Mask of world cells to highlight.
      highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
      for vis_j in range(0, self.agent_view_size):
          for vis_i in range(0, self.agent_view_size):
              if not vis_mask[vis_i, vis_j]:
                  continue
              # Walk back along the heading for rows, right for columns.
              abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
              # Parts of the view can hang off the edge of the world.
              if abs_i < 0 or abs_i >= self.width:
                  continue
              if abs_j < 0 or abs_j >= self.height:
                  continue
              highlight_mask[abs_i, abs_j] = True

      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask if highlight else None,
      )

      if mode == "human":
          assert self.window is not None
          self.window.set_caption(self.mission)
          self.window.show_img(img)

      return img
  def close(self):
      """
      Close the rendering window, if one is open.

      The window reference is cleared afterwards. This makes a repeated
      close() a no-op, and lets a later render() in "human" mode create a
      new window, since render() only creates one when self.window is
      falsy.
      """
      if self.window:
          self.window.close()
          self.window = None