minigrid.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332
  1. import math
  2. import hashlib
  3. import string
  4. import gym
  5. from enum import IntEnum
  6. import numpy as np
  7. from gym import error, spaces, utils
  8. from .rendering import *
  9. # Size in pixels of a tile in the full-scale human view
  10. TILE_PIXELS = 32
  11. # Map of color names to RGB values
  12. COLORS = {
  13. 'red': np.array([255, 0, 0]),
  14. 'green': np.array([0, 255, 0]),
  15. 'blue': np.array([0, 0, 255]),
  16. 'purple': np.array([112, 39, 195]),
  17. 'yellow': np.array([255, 255, 0]),
  18. 'grey': np.array([100, 100, 100])
  19. }
  20. COLOR_NAMES = sorted(list(COLORS.keys()))
  21. # Used to map colors to integers
  22. COLOR_TO_IDX = {
  23. 'red': 0,
  24. 'green': 1,
  25. 'blue': 2,
  26. 'purple': 3,
  27. 'yellow': 4,
  28. 'grey': 5
  29. }
  30. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  31. # Map of object type to integers
  32. OBJECT_TO_IDX = {
  33. 'unseen': 0,
  34. 'empty': 1,
  35. 'wall': 2,
  36. 'floor': 3,
  37. 'door': 4,
  38. 'key': 5,
  39. 'ball': 6,
  40. 'box': 7,
  41. 'goal': 8,
  42. 'lava': 9,
  43. 'agent': 10,
  44. }
  45. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  46. # Map of state names to integers
  47. STATE_TO_IDX = {
  48. 'open': 0,
  49. 'closed': 1,
  50. 'locked': 2,
  51. }
  52. # Map of agent direction indices to vectors
  53. DIR_TO_VEC = [
  54. # Pointing right (positive X)
  55. np.array((1, 0)),
  56. # Down (positive Y)
  57. np.array((0, 1)),
  58. # Pointing left (negative X)
  59. np.array((-1, 0)),
  60. # Up (negative Y)
  61. np.array((0, -1)),
  62. ]
  63. class WorldObj:
  64. """
  65. Base class for grid world objects
  66. """
  67. def __init__(self, type, color):
  68. assert type in OBJECT_TO_IDX, type
  69. assert color in COLOR_TO_IDX, color
  70. self.type = type
  71. self.color = color
  72. self.contains = None
  73. # Initial position of the object
  74. self.init_pos = None
  75. # Current position of the object
  76. self.cur_pos = None
  77. def can_overlap(self):
  78. """Can the agent overlap with this?"""
  79. return False
  80. def can_pickup(self):
  81. """Can the agent pick this up?"""
  82. return False
  83. def can_contain(self):
  84. """Can this contain another object?"""
  85. return False
  86. def see_behind(self):
  87. """Can the agent see behind this object?"""
  88. return True
  89. def toggle(self, env, pos):
  90. """Method to trigger/toggle an action this object performs"""
  91. return False
  92. def encode(self):
  93. """Encode the a description of this object as a 3-tuple of integers"""
  94. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  95. @staticmethod
  96. def decode(type_idx, color_idx, state):
  97. """Create an object from a 3-tuple state description"""
  98. obj_type = IDX_TO_OBJECT[type_idx]
  99. color = IDX_TO_COLOR[color_idx]
  100. if obj_type == 'empty' or obj_type == 'unseen':
  101. return None
  102. # State, 0: open, 1: closed, 2: locked
  103. is_open = state == 0
  104. is_locked = state == 2
  105. if obj_type == 'wall':
  106. v = Wall(color)
  107. elif obj_type == 'floor':
  108. v = Floor(color)
  109. elif obj_type == 'ball':
  110. v = Ball(color)
  111. elif obj_type == 'key':
  112. v = Key(color)
  113. elif obj_type == 'box':
  114. v = Box(color)
  115. elif obj_type == 'door':
  116. v = Door(color, is_open, is_locked)
  117. elif obj_type == 'goal':
  118. v = Goal()
  119. elif obj_type == 'lava':
  120. v = Lava()
  121. else:
  122. assert False, "unknown object type in decode '%s'" % obj_type
  123. return v
  124. def render(self, r):
  125. """Draw this object with the given renderer"""
  126. raise NotImplementedError
  127. class Goal(WorldObj):
  128. def __init__(self):
  129. super().__init__('goal', 'green')
  130. def can_overlap(self):
  131. return True
  132. def render(self, img):
  133. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  134. class Floor(WorldObj):
  135. """
  136. Colored floor tile the agent can walk over
  137. """
  138. def __init__(self, color='blue'):
  139. super().__init__('floor', color)
  140. def can_overlap(self):
  141. return True
  142. def render(self, img):
  143. # Give the floor a pale color
  144. color = COLORS[self.color] / 2
  145. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  146. class Lava(WorldObj):
  147. def __init__(self):
  148. super().__init__('lava', 'red')
  149. def can_overlap(self):
  150. return True
  151. def render(self, img):
  152. c = (255, 128, 0)
  153. # Background color
  154. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  155. # Little waves
  156. for i in range(3):
  157. ylo = 0.3 + 0.2 * i
  158. yhi = 0.4 + 0.2 * i
  159. fill_coords(img, point_in_line(
  160. 0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  161. fill_coords(img, point_in_line(
  162. 0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  163. fill_coords(img, point_in_line(
  164. 0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  165. fill_coords(img, point_in_line(
  166. 0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  167. class Wall(WorldObj):
  168. def __init__(self, color='grey'):
  169. super().__init__('wall', color)
  170. def see_behind(self):
  171. return False
  172. def render(self, img):
  173. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  174. class Door(WorldObj):
  175. def __init__(self, color, is_open=False, is_locked=False):
  176. super().__init__('door', color)
  177. self.is_open = is_open
  178. self.is_locked = is_locked
  179. def can_overlap(self):
  180. """The agent can only walk over this cell when the door is open"""
  181. return self.is_open
  182. def see_behind(self):
  183. return self.is_open
  184. def toggle(self, env, pos):
  185. # If the player has the right key to open the door
  186. if self.is_locked:
  187. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  188. self.is_locked = False
  189. self.is_open = True
  190. return True
  191. return False
  192. self.is_open = not self.is_open
  193. return True
  194. def encode(self):
  195. """Encode the a description of this object as a 3-tuple of integers"""
  196. # State, 0: open, 1: closed, 2: locked
  197. if self.is_open:
  198. state = 0
  199. elif self.is_locked:
  200. state = 2
  201. elif not self.is_open:
  202. state = 1
  203. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  204. def render(self, img):
  205. c = COLORS[self.color]
  206. if self.is_open:
  207. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  208. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  209. return
  210. # Door frame and door
  211. if self.is_locked:
  212. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  213. fill_coords(img, point_in_rect(
  214. 0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  215. # Draw key slot
  216. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  217. else:
  218. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  219. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  220. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  221. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  222. # Draw door handle
  223. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  224. class Key(WorldObj):
  225. def __init__(self, color='blue'):
  226. super(Key, self).__init__('key', color)
  227. def can_pickup(self):
  228. return True
  229. def render(self, img):
  230. c = COLORS[self.color]
  231. # Vertical quad
  232. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  233. # Teeth
  234. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  235. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  236. # Ring
  237. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  238. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  239. class Ball(WorldObj):
  240. def __init__(self, color='blue'):
  241. super(Ball, self).__init__('ball', color)
  242. def can_pickup(self):
  243. return True
  244. def render(self, img):
  245. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  246. class Box(WorldObj):
  247. def __init__(self, color, contains=None):
  248. super(Box, self).__init__('box', color)
  249. self.contains = contains
  250. def can_pickup(self):
  251. return True
  252. def render(self, img):
  253. c = COLORS[self.color]
  254. # Outline
  255. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  256. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  257. # Horizontal slit
  258. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  259. def toggle(self, env, pos):
  260. # Replace the box by its contents
  261. env.grid.set(*pos, self.contains)
  262. return True
  263. class Grid:
  264. """
  265. Represent a grid and operations on it
  266. """
  267. # Static cache of pre-renderer tiles
  268. tile_cache = {}
  269. def __init__(self, width, height):
  270. assert width >= 3
  271. assert height >= 3
  272. self.width = width
  273. self.height = height
  274. self.grid = [None] * width * height
  275. def __contains__(self, key):
  276. if isinstance(key, WorldObj):
  277. for e in self.grid:
  278. if e is key:
  279. return True
  280. elif isinstance(key, tuple):
  281. for e in self.grid:
  282. if e is None:
  283. continue
  284. if (e.color, e.type) == key:
  285. return True
  286. if key[0] is None and key[1] == e.type:
  287. return True
  288. return False
  289. def __eq__(self, other):
  290. grid1 = self.encode()
  291. grid2 = other.encode()
  292. return np.array_equal(grid2, grid1)
  293. def __ne__(self, other):
  294. return not self == other
  295. def copy(self):
  296. from copy import deepcopy
  297. return deepcopy(self)
  298. def set(self, i, j, v):
  299. assert i >= 0 and i < self.width
  300. assert j >= 0 and j < self.height
  301. self.grid[j * self.width + i] = v
  302. def get(self, i, j):
  303. assert i >= 0 and i < self.width
  304. assert j >= 0 and j < self.height
  305. return self.grid[j * self.width + i]
  306. def horz_wall(self, x, y, length=None, obj_type=Wall):
  307. if length is None:
  308. length = self.width - x
  309. for i in range(0, length):
  310. self.set(x + i, y, obj_type())
  311. def vert_wall(self, x, y, length=None, obj_type=Wall):
  312. if length is None:
  313. length = self.height - y
  314. for j in range(0, length):
  315. self.set(x, y + j, obj_type())
  316. def wall_rect(self, x, y, w, h):
  317. self.horz_wall(x, y, w)
  318. self.horz_wall(x, y+h-1, w)
  319. self.vert_wall(x, y, h)
  320. self.vert_wall(x+w-1, y, h)
  321. def rotate_left(self):
  322. """
  323. Rotate the grid to the left (counter-clockwise)
  324. """
  325. grid = Grid(self.height, self.width)
  326. for i in range(self.width):
  327. for j in range(self.height):
  328. v = self.get(i, j)
  329. grid.set(j, grid.height - 1 - i, v)
  330. return grid
  331. def slice(self, topX, topY, width, height):
  332. """
  333. Get a subset of the grid
  334. """
  335. grid = Grid(width, height)
  336. for j in range(0, height):
  337. for i in range(0, width):
  338. x = topX + i
  339. y = topY + j
  340. if x >= 0 and x < self.width and \
  341. y >= 0 and y < self.height:
  342. v = self.get(x, y)
  343. else:
  344. v = Wall()
  345. grid.set(i, j, v)
  346. return grid
  347. @classmethod
  348. def render_tile(
  349. cls,
  350. obj,
  351. agent_dir=None,
  352. highlight=False,
  353. tile_size=TILE_PIXELS,
  354. subdivs=3
  355. ):
  356. """
  357. Render a tile and cache the result
  358. """
  359. # Hash map lookup key for the cache
  360. key = (agent_dir, highlight, tile_size)
  361. key = obj.encode() + key if obj else key
  362. if key in cls.tile_cache:
  363. return cls.tile_cache[key]
  364. img = np.zeros(shape=(tile_size * subdivs,
  365. tile_size * subdivs, 3), dtype=np.uint8)
  366. # Draw the grid lines (top and left edges)
  367. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  368. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  369. if obj != None:
  370. obj.render(img)
  371. # Overlay the agent on top
  372. if agent_dir is not None:
  373. tri_fn = point_in_triangle(
  374. (0.12, 0.19),
  375. (0.87, 0.50),
  376. (0.12, 0.81),
  377. )
  378. # Rotate the agent based on its direction
  379. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5,
  380. theta=0.5*math.pi*agent_dir)
  381. fill_coords(img, tri_fn, (255, 0, 0))
  382. # Highlight the cell if needed
  383. if highlight:
  384. highlight_img(img)
  385. # Downsample the image to perform supersampling/anti-aliasing
  386. img = downsample(img, subdivs)
  387. # Cache the rendered tile
  388. cls.tile_cache[key] = img
  389. return img
  390. def render(
  391. self,
  392. tile_size,
  393. agent_pos=None,
  394. agent_dir=None,
  395. highlight_mask=None
  396. ):
  397. """
  398. Render this grid at a given scale
  399. :param r: target renderer object
  400. :param tile_size: tile size in pixels
  401. """
  402. if highlight_mask is None:
  403. highlight_mask = np.zeros(
  404. shape=(self.width, self.height), dtype=bool)
  405. # Compute the total grid size
  406. width_px = self.width * tile_size
  407. height_px = self.height * tile_size
  408. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  409. # Render the grid
  410. for j in range(0, self.height):
  411. for i in range(0, self.width):
  412. cell = self.get(i, j)
  413. agent_here = np.array_equal(agent_pos, (i, j))
  414. tile_img = Grid.render_tile(
  415. cell,
  416. agent_dir=agent_dir if agent_here else None,
  417. highlight=highlight_mask[i, j],
  418. tile_size=tile_size
  419. )
  420. ymin = j * tile_size
  421. ymax = (j+1) * tile_size
  422. xmin = i * tile_size
  423. xmax = (i+1) * tile_size
  424. img[ymin:ymax, xmin:xmax, :] = tile_img
  425. return img
  426. def encode(self, vis_mask=None):
  427. """
  428. Produce a compact numpy encoding of the grid
  429. """
  430. if vis_mask is None:
  431. vis_mask = np.ones((self.width, self.height), dtype=bool)
  432. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  433. for i in range(self.width):
  434. for j in range(self.height):
  435. if vis_mask[i, j]:
  436. v = self.get(i, j)
  437. if v is None:
  438. array[i, j, 0] = OBJECT_TO_IDX['empty']
  439. array[i, j, 1] = 0
  440. array[i, j, 2] = 0
  441. else:
  442. array[i, j, :] = v.encode()
  443. return array
  444. @staticmethod
  445. def decode(array):
  446. """
  447. Decode an array grid encoding back into a grid
  448. """
  449. width, height, channels = array.shape
  450. assert channels == 3
  451. vis_mask = np.ones(shape=(width, height), dtype=bool)
  452. grid = Grid(width, height)
  453. for i in range(width):
  454. for j in range(height):
  455. type_idx, color_idx, state = array[i, j]
  456. v = WorldObj.decode(type_idx, color_idx, state)
  457. grid.set(i, j, v)
  458. vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
  459. return grid, vis_mask
  460. def process_vis(grid, agent_pos):
  461. mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)
  462. mask[agent_pos[0], agent_pos[1]] = True
  463. for j in reversed(range(0, grid.height)):
  464. for i in range(0, grid.width-1):
  465. if not mask[i, j]:
  466. continue
  467. cell = grid.get(i, j)
  468. if cell and not cell.see_behind():
  469. continue
  470. mask[i+1, j] = True
  471. if j > 0:
  472. mask[i+1, j-1] = True
  473. mask[i, j-1] = True
  474. for i in reversed(range(1, grid.width)):
  475. if not mask[i, j]:
  476. continue
  477. cell = grid.get(i, j)
  478. if cell and not cell.see_behind():
  479. continue
  480. mask[i-1, j] = True
  481. if j > 0:
  482. mask[i-1, j-1] = True
  483. mask[i, j-1] = True
  484. for j in range(0, grid.height):
  485. for i in range(0, grid.width):
  486. if not mask[i, j]:
  487. grid.set(i, j, None)
  488. return mask
  489. class MiniGridEnv(gym.Env):
  490. """
  491. 2D grid world game environment
  492. """
  493. metadata = {
  494. # Deprecated: use 'render_modes' instead
  495. 'render.modes': ['human', 'rgb_array'],
  496. 'video.frames_per_second': 10, # Deprecated: use 'render_fps' instead
  497. 'render_modes': ['human', 'rgb_array'],
  498. 'render_fps': 10
  499. }
  500. # Enumeration of possible actions
  501. class Actions(IntEnum):
  502. # Turn left, turn right, move forward
  503. left = 0
  504. right = 1
  505. forward = 2
  506. # Pick up an object
  507. pickup = 3
  508. # Drop an object
  509. drop = 4
  510. # Toggle/activate an object
  511. toggle = 5
  512. # Done completing task
  513. done = 6
  514. def __init__(
  515. self,
  516. grid_size=None,
  517. width=None,
  518. height=None,
  519. max_steps=100,
  520. see_through_walls=False,
  521. agent_view_size=7,
  522. render_mode=None,
  523. **kwargs
  524. ):
  525. # Can't set both grid_size and width/height
  526. if grid_size:
  527. assert width == None and height == None
  528. width = grid_size
  529. height = grid_size
  530. # Action enumeration for this environment
  531. self.actions = MiniGridEnv.Actions
  532. # Actions are discrete integer values
  533. self.action_space = spaces.Discrete(len(self.actions))
  534. # Number of cells (width and height) in the agent view
  535. assert agent_view_size % 2 == 1
  536. assert agent_view_size >= 3
  537. self.agent_view_size = agent_view_size
  538. # Observations are dictionaries containing an
  539. # encoding of the grid and a textual 'mission' string
  540. self.observation_space = spaces.Box(
  541. low=0,
  542. high=255,
  543. shape=(self.agent_view_size, self.agent_view_size, 3),
  544. dtype='uint8'
  545. )
  546. self.observation_space = spaces.Dict({
  547. 'image': self.observation_space,
  548. 'direction': spaces.Discrete(4),
  549. 'mission': spaces.Text(max_length=200,
  550. charset=string.ascii_letters + string.digits + ' .,!-'
  551. )
  552. })
  553. # render mode
  554. self.render_mode = render_mode
  555. # Range of possible rewards
  556. self.reward_range = (0, 1)
  557. # Window to use for human rendering mode
  558. self.window = None
  559. # Environment configuration
  560. self.width = width
  561. self.height = height
  562. self.max_steps = max_steps
  563. self.see_through_walls = see_through_walls
  564. # Current position and direction of the agent
  565. self.agent_pos = None
  566. self.agent_dir = None
  567. # Initialize the state
  568. self.reset()
  569. def reset(self, *, seed=None, return_info=False, options=None):
  570. super().reset(seed=seed)
  571. # Current position and direction of the agent
  572. self.agent_pos = None
  573. self.agent_dir = None
  574. # Generate a new random grid at the start of each episode
  575. self._gen_grid(self.width, self.height)
  576. # These fields should be defined by _gen_grid
  577. assert self.agent_pos is not None
  578. assert self.agent_dir is not None
  579. # Check that the agent doesn't overlap with an object
  580. start_cell = self.grid.get(*self.agent_pos)
  581. assert start_cell is None or start_cell.can_overlap()
  582. # Item picked up, being carried, initially nothing
  583. self.carrying = None
  584. # Step count since episode start
  585. self.step_count = 0
  586. # Return first observation
  587. obs = self.gen_obs()
  588. return obs
  589. def hash(self, size=16):
  590. """Compute a hash that uniquely identifies the current state of the environment.
  591. :param size: Size of the hashing
  592. """
  593. sample_hash = hashlib.sha256()
  594. to_encode = [self.grid.encode().tolist(), self.agent_pos,
  595. self.agent_dir]
  596. for item in to_encode:
  597. sample_hash.update(str(item).encode('utf8'))
  598. return sample_hash.hexdigest()[:size]
  599. @property
  600. def steps_remaining(self):
  601. return self.max_steps - self.step_count
  602. def __str__(self):
  603. """
  604. Produce a pretty string of the environment's grid along with the agent.
  605. A grid cell is represented by 2-character string, the first one for
  606. the object and the second one for the color.
  607. """
  608. # Map of object types to short string
  609. OBJECT_TO_STR = {
  610. 'wall': 'W',
  611. 'floor': 'F',
  612. 'door': 'D',
  613. 'key': 'K',
  614. 'ball': 'A',
  615. 'box': 'B',
  616. 'goal': 'G',
  617. 'lava': 'V',
  618. }
  619. # Short string for opened door
  620. OPENDED_DOOR_IDS = '_'
  621. # Map agent's direction to short string
  622. AGENT_DIR_TO_STR = {
  623. 0: '>',
  624. 1: 'V',
  625. 2: '<',
  626. 3: '^'
  627. }
  628. str = ''
  629. for j in range(self.grid.height):
  630. for i in range(self.grid.width):
  631. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  632. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  633. continue
  634. c = self.grid.get(i, j)
  635. if c == None:
  636. str += ' '
  637. continue
  638. if c.type == 'door':
  639. if c.is_open:
  640. str += '__'
  641. elif c.is_locked:
  642. str += 'L' + c.color[0].upper()
  643. else:
  644. str += 'D' + c.color[0].upper()
  645. continue
  646. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  647. if j < self.grid.height - 1:
  648. str += '\n'
  649. return str
  650. def _gen_grid(self, width, height):
  651. assert False, "_gen_grid needs to be implemented by each environment"
  652. def _reward(self):
  653. """
  654. Compute the reward to be given upon success
  655. """
  656. return 1 - 0.9 * (self.step_count / self.max_steps)
  657. def _rand_int(self, low, high):
  658. """
  659. Generate random integer in [low,high[
  660. """
  661. return self.np_random.integers(low, high)
  662. def _rand_float(self, low, high):
  663. """
  664. Generate random float in [low,high[
  665. """
  666. return self.np_random.uniform(low, high)
  667. def _rand_bool(self):
  668. """
  669. Generate random boolean value
  670. """
  671. return (self.np_random.integers(0, 2) == 0)
  672. def _rand_elem(self, iterable):
  673. """
  674. Pick a random element in a list
  675. """
  676. lst = list(iterable)
  677. idx = self._rand_int(0, len(lst))
  678. return lst[idx]
  679. def _rand_subset(self, iterable, num_elems):
  680. """
  681. Sample a random subset of distinct elements of a list
  682. """
  683. lst = list(iterable)
  684. assert num_elems <= len(lst)
  685. out = []
  686. while len(out) < num_elems:
  687. elem = self._rand_elem(lst)
  688. lst.remove(elem)
  689. out.append(elem)
  690. return out
  691. def _rand_color(self):
  692. """
  693. Generate a random color name (string)
  694. """
  695. return self._rand_elem(COLOR_NAMES)
  696. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  697. """
  698. Generate a random (x,y) position tuple
  699. """
  700. return (
  701. self.np_random.integers(xLow, xHigh),
  702. self.np_random.integers(yLow, yHigh)
  703. )
  704. def place_obj(self,
  705. obj,
  706. top=None,
  707. size=None,
  708. reject_fn=None,
  709. max_tries=math.inf
  710. ):
  711. """
  712. Place an object at an empty position in the grid
  713. :param top: top-left position of the rectangle where to place
  714. :param size: size of the rectangle where to place
  715. :param reject_fn: function to filter out potential positions
  716. """
  717. if top is None:
  718. top = (0, 0)
  719. else:
  720. top = (max(top[0], 0), max(top[1], 0))
  721. if size is None:
  722. size = (self.grid.width, self.grid.height)
  723. num_tries = 0
  724. while True:
  725. # This is to handle with rare cases where rejection sampling
  726. # gets stuck in an infinite loop
  727. if num_tries > max_tries:
  728. raise RecursionError('rejection sampling failed in place_obj')
  729. num_tries += 1
  730. pos = np.array((
  731. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  732. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  733. ))
  734. # Don't place the object on top of another object
  735. if self.grid.get(*pos) != None:
  736. continue
  737. # Don't place the object where the agent is
  738. if np.array_equal(pos, self.agent_pos):
  739. continue
  740. # Check if there is a filtering criterion
  741. if reject_fn and reject_fn(self, pos):
  742. continue
  743. break
  744. self.grid.set(*pos, obj)
  745. if obj is not None:
  746. obj.init_pos = pos
  747. obj.cur_pos = pos
  748. return pos
  749. def put_obj(self, obj, i, j):
  750. """
  751. Put an object at a specific position in the grid
  752. """
  753. self.grid.set(i, j, obj)
  754. obj.init_pos = (i, j)
  755. obj.cur_pos = (i, j)
  756. def place_agent(
  757. self,
  758. top=None,
  759. size=None,
  760. rand_dir=True,
  761. max_tries=math.inf
  762. ):
  763. """
  764. Set the agent's starting point at an empty position in the grid
  765. """
  766. self.agent_pos = None
  767. pos = self.place_obj(None, top, size, max_tries=max_tries)
  768. self.agent_pos = pos
  769. if rand_dir:
  770. self.agent_dir = self._rand_int(0, 4)
  771. return pos
  772. @property
  773. def dir_vec(self):
  774. """
  775. Get the direction vector for the agent, pointing in the direction
  776. of forward movement.
  777. """
  778. assert self.agent_dir >= 0 and self.agent_dir < 4
  779. return DIR_TO_VEC[self.agent_dir]
  780. @property
  781. def right_vec(self):
  782. """
  783. Get the vector pointing to the right of the agent.
  784. """
  785. dx, dy = self.dir_vec
  786. return np.array((-dy, dx))
  787. @property
  788. def front_pos(self):
  789. """
  790. Get the position of the cell that is right in front of the agent
  791. """
  792. return self.agent_pos + self.dir_vec
  793. def get_view_coords(self, i, j):
  794. """
  795. Translate and rotate absolute grid coordinates (i, j) into the
  796. agent's partially observable view (sub-grid). Note that the resulting
  797. coordinates may be negative or outside of the agent's view size.
  798. """
  799. ax, ay = self.agent_pos
  800. dx, dy = self.dir_vec
  801. rx, ry = self.right_vec
  802. # Compute the absolute coordinates of the top-left view corner
  803. sz = self.agent_view_size
  804. hs = self.agent_view_size // 2
  805. tx = ax + (dx * (sz-1)) - (rx * hs)
  806. ty = ay + (dy * (sz-1)) - (ry * hs)
  807. lx = i - tx
  808. ly = j - ty
  809. # Project the coordinates of the object relative to the top-left
  810. # corner onto the agent's own coordinate system
  811. vx = (rx*lx + ry*ly)
  812. vy = -(dx*lx + dy*ly)
  813. return vx, vy
  814. def get_view_exts(self, agent_view_size=None):
  815. """
  816. Get the extents of the square set of tiles visible to the agent
  817. Note: the bottom extent indices are not included in the set
  818. if agent_view_size is None, use self.agent_view_size
  819. """
  820. agent_view_size = agent_view_size or self.agent_view_size
  821. # Facing right
  822. if self.agent_dir == 0:
  823. topX = self.agent_pos[0]
  824. topY = self.agent_pos[1] - agent_view_size // 2
  825. # Facing down
  826. elif self.agent_dir == 1:
  827. topX = self.agent_pos[0] - agent_view_size // 2
  828. topY = self.agent_pos[1]
  829. # Facing left
  830. elif self.agent_dir == 2:
  831. topX = self.agent_pos[0] - agent_view_size + 1
  832. topY = self.agent_pos[1] - agent_view_size // 2
  833. # Facing up
  834. elif self.agent_dir == 3:
  835. topX = self.agent_pos[0] - agent_view_size // 2
  836. topY = self.agent_pos[1] - agent_view_size + 1
  837. else:
  838. assert False, "invalid agent direction"
  839. botX = topX + agent_view_size
  840. botY = topY + agent_view_size
  841. return (topX, topY, botX, botY)
  842. def relative_coords(self, x, y):
  843. """
  844. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  845. """
  846. vx, vy = self.get_view_coords(x, y)
  847. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  848. return None
  849. return vx, vy
  850. def in_view(self, x, y):
  851. """
  852. check if a grid position is visible to the agent
  853. """
  854. return self.relative_coords(x, y) is not None
  855. def agent_sees(self, x, y):
  856. """
  857. Check if a non-empty grid position is visible to the agent
  858. """
  859. coordinates = self.relative_coords(x, y)
  860. if coordinates is None:
  861. return False
  862. vx, vy = coordinates
  863. obs = self.gen_obs()
  864. obs_grid, _ = Grid.decode(obs['image'])
  865. obs_cell = obs_grid.get(vx, vy)
  866. world_cell = self.grid.get(x, y)
  867. return obs_cell is not None and obs_cell.type == world_cell.type
  868. def step(self, action):
  869. self.step_count += 1
  870. reward = 0
  871. done = False
  872. # Get the position in front of the agent
  873. fwd_pos = self.front_pos
  874. # Get the contents of the cell in front of the agent
  875. fwd_cell = self.grid.get(*fwd_pos)
  876. # Rotate left
  877. if action == self.actions.left:
  878. self.agent_dir -= 1
  879. if self.agent_dir < 0:
  880. self.agent_dir += 4
  881. # Rotate right
  882. elif action == self.actions.right:
  883. self.agent_dir = (self.agent_dir + 1) % 4
  884. # Move forward
  885. elif action == self.actions.forward:
  886. if fwd_cell == None or fwd_cell.can_overlap():
  887. self.agent_pos = fwd_pos
  888. if fwd_cell != None and fwd_cell.type == 'goal':
  889. done = True
  890. reward = self._reward()
  891. if fwd_cell != None and fwd_cell.type == 'lava':
  892. done = True
  893. # Pick up an object
  894. elif action == self.actions.pickup:
  895. if fwd_cell and fwd_cell.can_pickup():
  896. if self.carrying is None:
  897. self.carrying = fwd_cell
  898. self.carrying.cur_pos = np.array([-1, -1])
  899. self.grid.set(*fwd_pos, None)
  900. # Drop an object
  901. elif action == self.actions.drop:
  902. if not fwd_cell and self.carrying:
  903. self.grid.set(*fwd_pos, self.carrying)
  904. self.carrying.cur_pos = fwd_pos
  905. self.carrying = None
  906. # Toggle/activate an object
  907. elif action == self.actions.toggle:
  908. if fwd_cell:
  909. fwd_cell.toggle(self, fwd_pos)
  910. # Done action (not used by default)
  911. elif action == self.actions.done:
  912. pass
  913. else:
  914. assert False, "unknown action"
  915. if self.step_count >= self.max_steps:
  916. done = True
  917. obs = self.gen_obs()
  918. return obs, reward, done, {}
  919. def gen_obs_grid(self, agent_view_size=None):
  920. """
  921. Generate the sub-grid observed by the agent.
  922. This method also outputs a visibility mask telling us which grid
  923. cells the agent can actually see.
  924. if agent_view_size is None, self.agent_view_size is used
  925. """
  926. topX, topY, botX, botY = self.get_view_exts(agent_view_size)
  927. agent_view_size = agent_view_size or self.agent_view_size
  928. grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
  929. for i in range(self.agent_dir + 1):
  930. grid = grid.rotate_left()
  931. # Process occluders and visibility
  932. # Note that this incurs some performance cost
  933. if not self.see_through_walls:
  934. vis_mask = grid.process_vis(agent_pos=(
  935. agent_view_size // 2, agent_view_size - 1))
  936. else:
  937. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
  938. # Make it so the agent sees what it's carrying
  939. # We do this by placing the carried object at the agent's position
  940. # in the agent's partially observable view
  941. agent_pos = grid.width // 2, grid.height - 1
  942. if self.carrying:
  943. grid.set(*agent_pos, self.carrying)
  944. else:
  945. grid.set(*agent_pos, None)
  946. return grid, vis_mask
  947. def gen_obs(self):
  948. """
  949. Generate the agent's view (partially observable, low-resolution encoding)
  950. """
  951. grid, vis_mask = self.gen_obs_grid()
  952. # Encode the partially observable view into a numpy array
  953. image = grid.encode(vis_mask)
  954. assert hasattr(
  955. self, 'mission'), "environments must define a textual mission string"
  956. # Observations are dictionaries containing:
  957. # - an image (partially observable view of the environment)
  958. # - the agent's direction/orientation (acting as a compass)
  959. # - a textual mission string (instructions for the agent)
  960. obs = {
  961. 'image': image,
  962. 'direction': self.agent_dir,
  963. 'mission': self.mission
  964. }
  965. return obs
  966. def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
  967. """
  968. Render an agent observation for visualization
  969. """
  970. grid, vis_mask = Grid.decode(obs)
  971. # Render the whole grid
  972. img = grid.render(
  973. tile_size,
  974. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  975. agent_dir=3,
  976. highlight_mask=vis_mask
  977. )
  978. return img
  979. def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
  980. """
  981. Render the whole-grid human view
  982. """
  983. if self.render_mode is not None:
  984. mode = self.render_mode
  985. if close:
  986. if self.window:
  987. self.window.close()
  988. return
  989. if mode == 'human' and not self.window:
  990. import gym_minigrid.window
  991. self.window = gym_minigrid.window.Window('gym_minigrid')
  992. self.window.show(block=False)
  993. # Compute which cells are visible to the agent
  994. _, vis_mask = self.gen_obs_grid()
  995. # Compute the world coordinates of the bottom-left corner
  996. # of the agent's view area
  997. f_vec = self.dir_vec
  998. r_vec = self.right_vec
  999. top_left = self.agent_pos + f_vec * \
  1000. (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
  1001. # Mask of which cells to highlight
  1002. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  1003. # For each cell in the visibility mask
  1004. for vis_j in range(0, self.agent_view_size):
  1005. for vis_i in range(0, self.agent_view_size):
  1006. # If this cell is not visible, don't highlight it
  1007. if not vis_mask[vis_i, vis_j]:
  1008. continue
  1009. # Compute the world coordinates of this cell
  1010. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  1011. if abs_i < 0 or abs_i >= self.width:
  1012. continue
  1013. if abs_j < 0 or abs_j >= self.height:
  1014. continue
  1015. # Mark this cell to be highlighted
  1016. highlight_mask[abs_i, abs_j] = True
  1017. # Render the whole grid
  1018. img = self.grid.render(
  1019. tile_size,
  1020. self.agent_pos,
  1021. self.agent_dir,
  1022. highlight_mask=highlight_mask if highlight else None
  1023. )
  1024. if mode == 'human':
  1025. self.window.set_caption(self.mission)
  1026. self.window.show_img(img)
  1027. return img
  1028. def close(self):
  1029. if self.window:
  1030. self.window.close()
  1031. return