minigrid.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308
  1. import hashlib
  2. import math
  3. import string
  4. from enum import IntEnum
  5. import gym
  6. import numpy as np
  7. from gym import spaces
  8. # Size in pixels of a tile in the full-scale human view
  9. from gym_minigrid.rendering import (
  10. downsample,
  11. fill_coords,
  12. highlight_img,
  13. point_in_circle,
  14. point_in_line,
  15. point_in_rect,
  16. point_in_triangle,
  17. rotate_fn,
  18. )
  19. TILE_PIXELS = 32
  20. # Map of color names to RGB values
  21. COLORS = {
  22. "red": np.array([255, 0, 0]),
  23. "green": np.array([0, 255, 0]),
  24. "blue": np.array([0, 0, 255]),
  25. "purple": np.array([112, 39, 195]),
  26. "yellow": np.array([255, 255, 0]),
  27. "grey": np.array([100, 100, 100]),
  28. }
  29. COLOR_NAMES = sorted(list(COLORS.keys()))
  30. # Used to map colors to integers
  31. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5}
  32. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  33. # Map of object type to integers
  34. OBJECT_TO_IDX = {
  35. "unseen": 0,
  36. "empty": 1,
  37. "wall": 2,
  38. "floor": 3,
  39. "door": 4,
  40. "key": 5,
  41. "ball": 6,
  42. "box": 7,
  43. "goal": 8,
  44. "lava": 9,
  45. "agent": 10,
  46. }
  47. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  48. # Map of state names to integers
  49. STATE_TO_IDX = {
  50. "open": 0,
  51. "closed": 1,
  52. "locked": 2,
  53. }
  54. # Map of agent direction indices to vectors
  55. DIR_TO_VEC = [
  56. # Pointing right (positive X)
  57. np.array((1, 0)),
  58. # Down (positive Y)
  59. np.array((0, 1)),
  60. # Pointing left (negative X)
  61. np.array((-1, 0)),
  62. # Up (negative Y)
  63. np.array((0, -1)),
  64. ]
  65. class WorldObj:
  66. """
  67. Base class for grid world objects
  68. """
  69. def __init__(self, type, color):
  70. assert type in OBJECT_TO_IDX, type
  71. assert color in COLOR_TO_IDX, color
  72. self.type = type
  73. self.color = color
  74. self.contains = None
  75. # Initial position of the object
  76. self.init_pos = None
  77. # Current position of the object
  78. self.cur_pos = None
  79. def can_overlap(self):
  80. """Can the agent overlap with this?"""
  81. return False
  82. def can_pickup(self):
  83. """Can the agent pick this up?"""
  84. return False
  85. def can_contain(self):
  86. """Can this contain another object?"""
  87. return False
  88. def see_behind(self):
  89. """Can the agent see behind this object?"""
  90. return True
  91. def toggle(self, env, pos):
  92. """Method to trigger/toggle an action this object performs"""
  93. return False
  94. def encode(self):
  95. """Encode the a description of this object as a 3-tuple of integers"""
  96. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  97. @staticmethod
  98. def decode(type_idx, color_idx, state):
  99. """Create an object from a 3-tuple state description"""
  100. obj_type = IDX_TO_OBJECT[type_idx]
  101. color = IDX_TO_COLOR[color_idx]
  102. if obj_type == "empty" or obj_type == "unseen":
  103. return None
  104. # State, 0: open, 1: closed, 2: locked
  105. is_open = state == 0
  106. is_locked = state == 2
  107. if obj_type == "wall":
  108. v = Wall(color)
  109. elif obj_type == "floor":
  110. v = Floor(color)
  111. elif obj_type == "ball":
  112. v = Ball(color)
  113. elif obj_type == "key":
  114. v = Key(color)
  115. elif obj_type == "box":
  116. v = Box(color)
  117. elif obj_type == "door":
  118. v = Door(color, is_open, is_locked)
  119. elif obj_type == "goal":
  120. v = Goal()
  121. elif obj_type == "lava":
  122. v = Lava()
  123. else:
  124. assert False, "unknown object type in decode '%s'" % obj_type
  125. return v
  126. def render(self, r):
  127. """Draw this object with the given renderer"""
  128. raise NotImplementedError
  129. class Goal(WorldObj):
  130. def __init__(self):
  131. super().__init__("goal", "green")
  132. def can_overlap(self):
  133. return True
  134. def render(self, img):
  135. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  136. class Floor(WorldObj):
  137. """
  138. Colored floor tile the agent can walk over
  139. """
  140. def __init__(self, color="blue"):
  141. super().__init__("floor", color)
  142. def can_overlap(self):
  143. return True
  144. def render(self, img):
  145. # Give the floor a pale color
  146. color = COLORS[self.color] / 2
  147. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  148. class Lava(WorldObj):
  149. def __init__(self):
  150. super().__init__("lava", "red")
  151. def can_overlap(self):
  152. return True
  153. def render(self, img):
  154. c = (255, 128, 0)
  155. # Background color
  156. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  157. # Little waves
  158. for i in range(3):
  159. ylo = 0.3 + 0.2 * i
  160. yhi = 0.4 + 0.2 * i
  161. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  162. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  163. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  164. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  165. class Wall(WorldObj):
  166. def __init__(self, color="grey"):
  167. super().__init__("wall", color)
  168. def see_behind(self):
  169. return False
  170. def render(self, img):
  171. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  172. class Door(WorldObj):
  173. def __init__(self, color, is_open=False, is_locked=False):
  174. super().__init__("door", color)
  175. self.is_open = is_open
  176. self.is_locked = is_locked
  177. def can_overlap(self):
  178. """The agent can only walk over this cell when the door is open"""
  179. return self.is_open
  180. def see_behind(self):
  181. return self.is_open
  182. def toggle(self, env, pos):
  183. # If the player has the right key to open the door
  184. if self.is_locked:
  185. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  186. self.is_locked = False
  187. self.is_open = True
  188. return True
  189. return False
  190. self.is_open = not self.is_open
  191. return True
  192. def encode(self):
  193. """Encode the a description of this object as a 3-tuple of integers"""
  194. # State, 0: open, 1: closed, 2: locked
  195. if self.is_open:
  196. state = 0
  197. elif self.is_locked:
  198. state = 2
  199. # if door is closed and unlocked
  200. elif not self.is_open:
  201. state = 1
  202. else:
  203. raise ValueError(
  204. "There is no possible state encoding for the state:\n -Door Open: {}\n -Door Closed: {}\n -Door Locked: {}".format(
  205. self.is_open, not self.is_open, self.is_locked
  206. )
  207. )
  208. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  209. def render(self, img):
  210. c = COLORS[self.color]
  211. if self.is_open:
  212. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  213. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  214. return
  215. # Door frame and door
  216. if self.is_locked:
  217. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  218. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  219. # Draw key slot
  220. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  221. else:
  222. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  223. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  224. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  225. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  226. # Draw door handle
  227. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  228. class Key(WorldObj):
  229. def __init__(self, color="blue"):
  230. super().__init__("key", color)
  231. def can_pickup(self):
  232. return True
  233. def render(self, img):
  234. c = COLORS[self.color]
  235. # Vertical quad
  236. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  237. # Teeth
  238. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  239. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  240. # Ring
  241. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  242. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  243. class Ball(WorldObj):
  244. def __init__(self, color="blue"):
  245. super().__init__("ball", color)
  246. def can_pickup(self):
  247. return True
  248. def render(self, img):
  249. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  250. class Box(WorldObj):
  251. def __init__(self, color, contains=None):
  252. super().__init__("box", color)
  253. self.contains = contains
  254. def can_pickup(self):
  255. return True
  256. def render(self, img):
  257. c = COLORS[self.color]
  258. # Outline
  259. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  260. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  261. # Horizontal slit
  262. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  263. def toggle(self, env, pos):
  264. # Replace the box by its contents
  265. env.grid.set(*pos, self.contains)
  266. return True
  267. class Grid:
  268. """
  269. Represent a grid and operations on it
  270. """
  271. # Static cache of pre-renderer tiles
  272. tile_cache = {}
  273. def __init__(self, width, height):
  274. assert width >= 3
  275. assert height >= 3
  276. self.width = width
  277. self.height = height
  278. self.grid = [None] * width * height
  279. def __contains__(self, key):
  280. if isinstance(key, WorldObj):
  281. for e in self.grid:
  282. if e is key:
  283. return True
  284. elif isinstance(key, tuple):
  285. for e in self.grid:
  286. if e is None:
  287. continue
  288. if (e.color, e.type) == key:
  289. return True
  290. if key[0] is None and key[1] == e.type:
  291. return True
  292. return False
  293. def __eq__(self, other):
  294. grid1 = self.encode()
  295. grid2 = other.encode()
  296. return np.array_equal(grid2, grid1)
  297. def __ne__(self, other):
  298. return not self == other
  299. def copy(self):
  300. from copy import deepcopy
  301. return deepcopy(self)
  302. def set(self, i, j, v):
  303. assert i >= 0 and i < self.width
  304. assert j >= 0 and j < self.height
  305. self.grid[j * self.width + i] = v
  306. def get(self, i, j):
  307. assert i >= 0 and i < self.width
  308. assert j >= 0 and j < self.height
  309. return self.grid[j * self.width + i]
  310. def horz_wall(self, x, y, length=None, obj_type=Wall):
  311. if length is None:
  312. length = self.width - x
  313. for i in range(0, length):
  314. self.set(x + i, y, obj_type())
  315. def vert_wall(self, x, y, length=None, obj_type=Wall):
  316. if length is None:
  317. length = self.height - y
  318. for j in range(0, length):
  319. self.set(x, y + j, obj_type())
  320. def wall_rect(self, x, y, w, h):
  321. self.horz_wall(x, y, w)
  322. self.horz_wall(x, y + h - 1, w)
  323. self.vert_wall(x, y, h)
  324. self.vert_wall(x + w - 1, y, h)
  325. def rotate_left(self):
  326. """
  327. Rotate the grid to the left (counter-clockwise)
  328. """
  329. grid = Grid(self.height, self.width)
  330. for i in range(self.width):
  331. for j in range(self.height):
  332. v = self.get(i, j)
  333. grid.set(j, grid.height - 1 - i, v)
  334. return grid
  335. def slice(self, topX, topY, width, height):
  336. """
  337. Get a subset of the grid
  338. """
  339. grid = Grid(width, height)
  340. for j in range(0, height):
  341. for i in range(0, width):
  342. x = topX + i
  343. y = topY + j
  344. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  345. v = self.get(x, y)
  346. else:
  347. v = Wall()
  348. grid.set(i, j, v)
  349. return grid
  350. @classmethod
  351. def render_tile(
  352. cls, obj, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3
  353. ):
  354. """
  355. Render a tile and cache the result
  356. """
  357. # Hash map lookup key for the cache
  358. key = (agent_dir, highlight, tile_size)
  359. key = obj.encode() + key if obj else key
  360. if key in cls.tile_cache:
  361. return cls.tile_cache[key]
  362. img = np.zeros(
  363. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  364. )
  365. # Draw the grid lines (top and left edges)
  366. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  367. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  368. if obj is not None:
  369. obj.render(img)
  370. # Overlay the agent on top
  371. if agent_dir is not None:
  372. tri_fn = point_in_triangle(
  373. (0.12, 0.19),
  374. (0.87, 0.50),
  375. (0.12, 0.81),
  376. )
  377. # Rotate the agent based on its direction
  378. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  379. fill_coords(img, tri_fn, (255, 0, 0))
  380. # Highlight the cell if needed
  381. if highlight:
  382. highlight_img(img)
  383. # Downsample the image to perform supersampling/anti-aliasing
  384. img = downsample(img, subdivs)
  385. # Cache the rendered tile
  386. cls.tile_cache[key] = img
  387. return img
  388. def render(self, tile_size, agent_pos=None, agent_dir=None, highlight_mask=None):
  389. """
  390. Render this grid at a given scale
  391. :param r: target renderer object
  392. :param tile_size: tile size in pixels
  393. """
  394. if highlight_mask is None:
  395. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  396. # Compute the total grid size
  397. width_px = self.width * tile_size
  398. height_px = self.height * tile_size
  399. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  400. # Render the grid
  401. for j in range(0, self.height):
  402. for i in range(0, self.width):
  403. cell = self.get(i, j)
  404. agent_here = np.array_equal(agent_pos, (i, j))
  405. tile_img = Grid.render_tile(
  406. cell,
  407. agent_dir=agent_dir if agent_here else None,
  408. highlight=highlight_mask[i, j],
  409. tile_size=tile_size,
  410. )
  411. ymin = j * tile_size
  412. ymax = (j + 1) * tile_size
  413. xmin = i * tile_size
  414. xmax = (i + 1) * tile_size
  415. img[ymin:ymax, xmin:xmax, :] = tile_img
  416. return img
  417. def encode(self, vis_mask=None):
  418. """
  419. Produce a compact numpy encoding of the grid
  420. """
  421. if vis_mask is None:
  422. vis_mask = np.ones((self.width, self.height), dtype=bool)
  423. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  424. for i in range(self.width):
  425. for j in range(self.height):
  426. if vis_mask[i, j]:
  427. v = self.get(i, j)
  428. if v is None:
  429. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  430. array[i, j, 1] = 0
  431. array[i, j, 2] = 0
  432. else:
  433. array[i, j, :] = v.encode()
  434. return array
  435. @staticmethod
  436. def decode(array):
  437. """
  438. Decode an array grid encoding back into a grid
  439. """
  440. width, height, channels = array.shape
  441. assert channels == 3
  442. vis_mask = np.ones(shape=(width, height), dtype=bool)
  443. grid = Grid(width, height)
  444. for i in range(width):
  445. for j in range(height):
  446. type_idx, color_idx, state = array[i, j]
  447. v = WorldObj.decode(type_idx, color_idx, state)
  448. grid.set(i, j, v)
  449. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  450. return grid, vis_mask
  451. def process_vis(self, agent_pos):
  452. mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  453. mask[agent_pos[0], agent_pos[1]] = True
  454. for j in reversed(range(0, self.height)):
  455. for i in range(0, self.width - 1):
  456. if not mask[i, j]:
  457. continue
  458. cell = self.get(i, j)
  459. if cell and not cell.see_behind():
  460. continue
  461. mask[i + 1, j] = True
  462. if j > 0:
  463. mask[i + 1, j - 1] = True
  464. mask[i, j - 1] = True
  465. for i in reversed(range(1, self.width)):
  466. if not mask[i, j]:
  467. continue
  468. cell = self.get(i, j)
  469. if cell and not cell.see_behind():
  470. continue
  471. mask[i - 1, j] = True
  472. if j > 0:
  473. mask[i - 1, j - 1] = True
  474. mask[i, j - 1] = True
  475. for j in range(0, self.height):
  476. for i in range(0, self.width):
  477. if not mask[i, j]:
  478. self.set(i, j, None)
  479. return mask
  480. class MiniGridEnv(gym.Env):
  481. """
  482. 2D grid world game environment
  483. """
  484. metadata = {
  485. # Deprecated: use 'render_modes' instead
  486. "render.modes": ["human", "rgb_array"],
  487. "video.frames_per_second": 10, # Deprecated: use 'render_fps' instead
  488. "render_modes": ["human", "rgb_array"],
  489. "render_fps": 10,
  490. }
  491. # Enumeration of possible actions
  492. class Actions(IntEnum):
  493. # Turn left, turn right, move forward
  494. left = 0
  495. right = 1
  496. forward = 2
  497. # Pick up an object
  498. pickup = 3
  499. # Drop an object
  500. drop = 4
  501. # Toggle/activate an object
  502. toggle = 5
  503. # Done completing task
  504. done = 6
  505. def __init__(
  506. self,
  507. grid_size: int = None,
  508. width: int = None,
  509. height: int = None,
  510. max_steps: int = 100,
  511. see_through_walls: bool = False,
  512. agent_view_size: int = 7,
  513. render_mode: str = None,
  514. **kwargs
  515. ):
  516. # Can't set both grid_size and width/height
  517. if grid_size:
  518. assert width is None and height is None
  519. width = grid_size
  520. height = grid_size
  521. # Action enumeration for this environment
  522. self.actions = MiniGridEnv.Actions
  523. # Actions are discrete integer values
  524. self.action_space = spaces.Discrete(len(self.actions))
  525. # Number of cells (width and height) in the agent view
  526. assert agent_view_size % 2 == 1
  527. assert agent_view_size >= 3
  528. self.agent_view_size = agent_view_size
  529. # Observations are dictionaries containing an
  530. # encoding of the grid and a textual 'mission' string
  531. self.observation_space = spaces.Box(
  532. low=0,
  533. high=255,
  534. shape=(self.agent_view_size, self.agent_view_size, 3),
  535. dtype="uint8",
  536. )
  537. self.observation_space = spaces.Dict(
  538. {
  539. "image": self.observation_space,
  540. "direction": spaces.Discrete(4),
  541. "mission": spaces.Text(
  542. max_length=200,
  543. charset=string.ascii_letters + string.digits + " .,!-",
  544. ),
  545. }
  546. )
  547. # render mode
  548. self.render_mode = render_mode
  549. # Range of possible rewards
  550. self.reward_range = (0, 1)
  551. # Environment configuration
  552. self.width = width
  553. self.height = height
  554. self.max_steps = max_steps
  555. self.see_through_walls = see_through_walls
  556. # Current position and direction of the agent
  557. self.agent_pos: np.ndarray = None
  558. self.agent_dir: int = None
  559. # Initialize the state
  560. self.reset()
  561. def reset(self, *, seed=None, return_info=False, options=None):
  562. super().reset(seed=seed)
  563. # Current position and direction of the agent
  564. self.agent_pos = None
  565. self.agent_dir = None
  566. # Generate a new random grid at the start of each episode
  567. self._gen_grid(self.width, self.height)
  568. # These fields should be defined by _gen_grid
  569. assert self.agent_pos is not None
  570. assert self.agent_dir is not None
  571. # Check that the agent doesn't overlap with an object
  572. start_cell = self.grid.get(*self.agent_pos)
  573. assert start_cell is None or start_cell.can_overlap()
  574. # Item picked up, being carried, initially nothing
  575. self.carrying = None
  576. # Step count since episode start
  577. self.step_count = 0
  578. # Return first observation
  579. obs = self.gen_obs()
  580. return obs
  581. def hash(self, size=16):
  582. """Compute a hash that uniquely identifies the current state of the environment.
  583. :param size: Size of the hashing
  584. """
  585. sample_hash = hashlib.sha256()
  586. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  587. for item in to_encode:
  588. sample_hash.update(str(item).encode("utf8"))
  589. return sample_hash.hexdigest()[:size]
  590. @property
  591. def steps_remaining(self):
  592. return self.max_steps - self.step_count
  593. def __str__(self):
  594. """
  595. Produce a pretty string of the environment's grid along with the agent.
  596. A grid cell is represented by 2-character string, the first one for
  597. the object and the second one for the color.
  598. """
  599. # Map of object types to short string
  600. OBJECT_TO_STR = {
  601. "wall": "W",
  602. "floor": "F",
  603. "door": "D",
  604. "key": "K",
  605. "ball": "A",
  606. "box": "B",
  607. "goal": "G",
  608. "lava": "V",
  609. }
  610. # Map agent's direction to short string
  611. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  612. str = ""
  613. for j in range(self.grid.height):
  614. for i in range(self.grid.width):
  615. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  616. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  617. continue
  618. c = self.grid.get(i, j)
  619. if c is None:
  620. str += " "
  621. continue
  622. if c.type == "door":
  623. if c.is_open:
  624. str += "__"
  625. elif c.is_locked:
  626. str += "L" + c.color[0].upper()
  627. else:
  628. str += "D" + c.color[0].upper()
  629. continue
  630. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  631. if j < self.grid.height - 1:
  632. str += "\n"
  633. return str
  634. def _gen_grid(self, width, height):
  635. assert False, "_gen_grid needs to be implemented by each environment"
  636. def _reward(self):
  637. """
  638. Compute the reward to be given upon success
  639. """
  640. return 1 - 0.9 * (self.step_count / self.max_steps)
  641. def _rand_int(self, low, high):
  642. """
  643. Generate random integer in [low,high[
  644. """
  645. return self.np_random.integers(low, high)
  646. def _rand_float(self, low, high):
  647. """
  648. Generate random float in [low,high[
  649. """
  650. return self.np_random.uniform(low, high)
  651. def _rand_bool(self):
  652. """
  653. Generate random boolean value
  654. """
  655. return self.np_random.integers(0, 2) == 0
  656. def _rand_elem(self, iterable):
  657. """
  658. Pick a random element in a list
  659. """
  660. lst = list(iterable)
  661. idx = self._rand_int(0, len(lst))
  662. return lst[idx]
  663. def _rand_subset(self, iterable, num_elems):
  664. """
  665. Sample a random subset of distinct elements of a list
  666. """
  667. lst = list(iterable)
  668. assert num_elems <= len(lst)
  669. out = []
  670. while len(out) < num_elems:
  671. elem = self._rand_elem(lst)
  672. lst.remove(elem)
  673. out.append(elem)
  674. return out
  675. def _rand_color(self):
  676. """
  677. Generate a random color name (string)
  678. """
  679. return self._rand_elem(COLOR_NAMES)
  680. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  681. """
  682. Generate a random (x,y) position tuple
  683. """
  684. return (
  685. self.np_random.integers(xLow, xHigh),
  686. self.np_random.integers(yLow, yHigh),
  687. )
  688. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  689. """
  690. Place an object at an empty position in the grid
  691. :param top: top-left position of the rectangle where to place
  692. :param size: size of the rectangle where to place
  693. :param reject_fn: function to filter out potential positions
  694. """
  695. if top is None:
  696. top = (0, 0)
  697. else:
  698. top = (max(top[0], 0), max(top[1], 0))
  699. if size is None:
  700. size = (self.grid.width, self.grid.height)
  701. num_tries = 0
  702. while True:
  703. # This is to handle with rare cases where rejection sampling
  704. # gets stuck in an infinite loop
  705. if num_tries > max_tries:
  706. raise RecursionError("rejection sampling failed in place_obj")
  707. num_tries += 1
  708. pos = np.array(
  709. (
  710. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  711. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  712. )
  713. )
  714. # Don't place the object on top of another object
  715. if self.grid.get(*pos) is not None:
  716. continue
  717. # Don't place the object where the agent is
  718. if np.array_equal(pos, self.agent_pos):
  719. continue
  720. # Check if there is a filtering criterion
  721. if reject_fn and reject_fn(self, pos):
  722. continue
  723. break
  724. self.grid.set(*pos, obj)
  725. if obj is not None:
  726. obj.init_pos = pos
  727. obj.cur_pos = pos
  728. return pos
  729. def put_obj(self, obj, i, j):
  730. """
  731. Put an object at a specific position in the grid
  732. """
  733. self.grid.set(i, j, obj)
  734. obj.init_pos = (i, j)
  735. obj.cur_pos = (i, j)
  736. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  737. """
  738. Set the agent's starting point at an empty position in the grid
  739. """
  740. self.agent_pos = None
  741. pos = self.place_obj(None, top, size, max_tries=max_tries)
  742. self.agent_pos = pos
  743. if rand_dir:
  744. self.agent_dir = self._rand_int(0, 4)
  745. return pos
  746. @property
  747. def dir_vec(self):
  748. """
  749. Get the direction vector for the agent, pointing in the direction
  750. of forward movement.
  751. """
  752. assert self.agent_dir >= 0 and self.agent_dir < 4
  753. return DIR_TO_VEC[self.agent_dir]
  754. @property
  755. def right_vec(self):
  756. """
  757. Get the vector pointing to the right of the agent.
  758. """
  759. dx, dy = self.dir_vec
  760. return np.array((-dy, dx))
  761. @property
  762. def front_pos(self):
  763. """
  764. Get the position of the cell that is right in front of the agent
  765. """
  766. return self.agent_pos + self.dir_vec
  767. def get_view_coords(self, i, j):
  768. """
  769. Translate and rotate absolute grid coordinates (i, j) into the
  770. agent's partially observable view (sub-grid). Note that the resulting
  771. coordinates may be negative or outside of the agent's view size.
  772. """
  773. ax, ay = self.agent_pos
  774. dx, dy = self.dir_vec
  775. rx, ry = self.right_vec
  776. # Compute the absolute coordinates of the top-left view corner
  777. sz = self.agent_view_size
  778. hs = self.agent_view_size // 2
  779. tx = ax + (dx * (sz - 1)) - (rx * hs)
  780. ty = ay + (dy * (sz - 1)) - (ry * hs)
  781. lx = i - tx
  782. ly = j - ty
  783. # Project the coordinates of the object relative to the top-left
  784. # corner onto the agent's own coordinate system
  785. vx = rx * lx + ry * ly
  786. vy = -(dx * lx + dy * ly)
  787. return vx, vy
  def get_view_exts(self, agent_view_size=None):
      """
      Compute the extents (topX, topY, botX, botY) of the square set of
      tiles visible to the agent.

      The bottom extents are exclusive: they are the top extents plus the
      view size. When agent_view_size is None (or otherwise falsy),
      self.agent_view_size is used instead.
      """
      size = agent_view_size or self.agent_view_size
      half = size // 2
      heading = self.agent_dir

      # How far the top-left corner sits up/left of the agent, per heading
      if heading == 0:
          # Facing right
          shift = (0, half)
      elif heading == 1:
          # Facing down
          shift = (half, 0)
      elif heading == 2:
          # Facing left
          shift = (size - 1, half)
      elif heading == 3:
          # Facing up
          shift = (half, size - 1)
      else:
          assert False, "invalid agent direction"

      left = self.agent_pos[0] - shift[0]
      top = self.agent_pos[1] - shift[1]
      return (left, top, left + size, top + size)
  def relative_coords(self, x, y):
      """
      Map the grid position (x, y) into the agent's view.

      Returns the view coordinates (vx, vy) when both fall inside
      [0, agent_view_size), otherwise None.
      """
      view_x, view_y = self.get_view_coords(x, y)
      limit = self.agent_view_size

      if 0 <= view_x < limit and 0 <= view_y < limit:
          return view_x, view_y
      return None
  def in_view(self, x, y):
      """
      Tell whether the grid position (x, y) lies inside the agent's view,
      i.e. whether relative_coords() yields coordinates for it.
      """
      located = self.relative_coords(x, y)
      return located is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent.

      Returns True only when (x, y) lies inside the agent's view, the
      world grid holds an object there, and the agent's observation shows
      an object of the same type at the corresponding view cell.
      Returns False otherwise, including when the world cell is empty.
      """
      coordinates = self.relative_coords(x, y)
      if coordinates is None:
          return False

      # An empty world cell has no type to compare against. Without this
      # guard, a non-empty observed cell (e.g. the carried object that
      # gen_obs_grid places at the agent's own view position) paired with
      # an empty world cell raised AttributeError on `world_cell.type`.
      # Checking first also skips building an observation needlessly.
      world_cell = self.grid.get(x, y)
      if world_cell is None:
          return False

      vx, vy = coordinates

      obs = self.gen_obs()
      obs_grid, _ = Grid.decode(obs["image"])
      obs_cell = obs_grid.get(vx, vy)

      return obs_cell is not None and obs_cell.type == world_cell.type
  def step(self, action):
      """
      Apply one agent action and advance the step counter.

      Returns the tuple (obs, reward, done, info), where obs comes from
      gen_obs() and info is always an empty dict. The episode is flagged
      done when a forward move targets a cell of type "goal" (reward taken
      from _reward()) or "lava", or once step_count reaches max_steps.
      """
      self.step_count += 1

      payoff = 0
      finished = False

      # The cell directly ahead of the agent and whatever occupies it
      ahead = self.front_pos
      target = self.grid.get(*ahead)

      acts = self.actions

      if action == acts.left:
          # Rotate left, wrapping a negative direction back by 4
          turned = self.agent_dir - 1
          self.agent_dir = turned + 4 if turned < 0 else turned

      elif action == acts.right:
          # Rotate right
          self.agent_dir = (self.agent_dir + 1) % 4

      elif action == acts.forward:
          # Move only into empty or overlappable cells
          if target is None or target.can_overlap():
              self.agent_pos = ahead
          if target is not None:
              kind = target.type
              if kind == "goal":
                  finished = True
                  payoff = self._reward()
              if kind == "lava":
                  finished = True

      elif action == acts.pickup:
          # Pick up the object ahead, provided the agent's hands are free
          if target and target.can_pickup() and self.carrying is None:
              self.carrying = target
              target.cur_pos = np.array([-1, -1])
              self.grid.set(*ahead, None)

      elif action == acts.drop:
          # Drop the carried object into the (empty) cell ahead
          if not target and self.carrying:
              held = self.carrying
              self.grid.set(*ahead, held)
              held.cur_pos = ahead
              self.carrying = None

      elif action == acts.toggle:
          # Toggle/activate the object ahead
          if target:
              target.toggle(self, ahead)

      elif action == acts.done:
          # Done action (not used by default)
          pass

      else:
          assert False, "unknown action"

      if self.step_count >= self.max_steps:
          finished = True

      return self.gen_obs(), payoff, finished, {}
  def gen_obs_grid(self, agent_view_size=None):
      """
      Build the sub-grid observed by the agent.

      Returns (grid, vis_mask): the sliced and rotated view grid plus a
      visibility mask telling which of its cells the agent can actually
      see. When agent_view_size is None, self.agent_view_size is used.
      """
      left, top, _, _ = self.get_view_exts(agent_view_size)
      size = agent_view_size or self.agent_view_size

      view = self.grid.slice(left, top, size, size)

      # Rotate the slice left (agent_dir + 1) times; the code below treats
      # the agent as sitting at the bottom-centre of the rotated view
      for _ in range(self.agent_dir + 1):
          view = view.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if self.see_through_walls:
          vis_mask = np.ones(shape=(view.width, view.height), dtype=bool)
      else:
          vis_mask = view.process_vis(agent_pos=(size // 2, size - 1))

      # Make the agent see what it is carrying by placing the carried
      # object (or nothing) on the agent's own cell in the view
      own_cell = (view.width // 2, view.height - 1)
      view.set(*own_cell, self.carrying if self.carrying else None)

      return view, vis_mask
  def gen_obs(self):
      """
      Produce the agent's observation (partially observable,
      low-resolution encoding).

      The result is a dictionary holding:
      - "image": the encoded partially observable view
      - "direction": the agent's direction/orientation (a compass)
      - "mission": the textual mission string (agent instructions)
      """
      view, visible = self.gen_obs_grid()

      # Encode the partially observable view into a numpy array
      encoded = view.encode(visible)

      assert hasattr(
          self, "mission"
      ), "environments must define a textual mission string"

      return dict(image=encoded, direction=self.agent_dir, mission=self.mission)
  def get_obs_render(self, obs, tile_size=TILE_PIXELS // 2):
      """
      Render an agent observation for visualization.

      obs is passed straight to Grid.decode, which yields the view grid
      and its visibility mask; the mask is used as the highlight mask.
      """
      view, visible = Grid.decode(obs)

      # In the view the agent is drawn at the bottom-centre cell with
      # direction 3
      size = self.agent_view_size
      agent_cell = (size // 2, size - 1)

      return view.render(
          tile_size,
          agent_pos=agent_cell,
          agent_dir=3,
          highlight_mask=visible,
      )
  def render(self, mode="human", close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view and return the rendered image.

      A non-None self.render_mode overrides the mode argument. With
      close=True the window (if any) is closed and None is returned.
      In "human" mode the image is also shown in a window, created on
      first use. When highlight is true, the cells the agent can see are
      highlighted.
      """
      if self.render_mode is not None:
          mode = self.render_mode

      if close:
          if self.window:
              self.window.close()
          return

      human = mode == "human"

      if human and not self.window:
          import gym_minigrid.window

          self.window = gym_minigrid.window.Window("gym_minigrid")
          self.window.show(block=False)

      # Compute which cells are visible to the agent
      _, vis_mask = self.gen_obs_grid()

      # World coordinates of the view corner that vis_mask[0, 0] maps to
      forward = self.dir_vec
      rightward = self.right_vec
      size = self.agent_view_size
      corner = self.agent_pos + forward * (size - 1) - rightward * (size // 2)

      # Mask of which world cells to highlight
      highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

      for row in range(size):
          for col in range(size):
              # Only visible cells can be highlighted
              if vis_mask[col, row]:
                  # World coordinates of this view cell
                  abs_i, abs_j = corner - (forward * row) + (rightward * col)

                  # Ignore view cells that fall outside the world grid
                  if 0 <= abs_i < self.width and 0 <= abs_j < self.height:
                      highlight_mask[abs_i, abs_j] = True

      # Render the whole grid
      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask if highlight else None,
      )

      if human:
          self.window.set_caption(self.mission)
          self.window.show_img(img)

      return img
  def close(self):
      """
      Close the rendering window if one is set; always returns None.
      """
      window = self.window
      if window:
          window.close()