# minigrid.py
  1. import math
  2. import hashlib
  3. import string
  4. import gym
  5. from enum import IntEnum
  6. import numpy as np
  7. from gym import error, spaces, utils
  8. from gym.utils import seeding
  9. from .rendering import *
  10. # Size in pixels of a tile in the full-scale human view
  11. TILE_PIXELS = 32
  12. # Map of color names to RGB values
  13. COLORS = {
  14. 'red' : np.array([255, 0, 0]),
  15. 'green' : np.array([0, 255, 0]),
  16. 'blue' : np.array([0, 0, 255]),
  17. 'purple': np.array([112, 39, 195]),
  18. 'yellow': np.array([255, 255, 0]),
  19. 'grey' : np.array([100, 100, 100])
  20. }
  21. COLOR_NAMES = sorted(list(COLORS.keys()))
  22. # Used to map colors to integers
  23. COLOR_TO_IDX = {
  24. 'red' : 0,
  25. 'green' : 1,
  26. 'blue' : 2,
  27. 'purple': 3,
  28. 'yellow': 4,
  29. 'grey' : 5
  30. }
  31. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  32. # Map of object type to integers
  33. OBJECT_TO_IDX = {
  34. 'unseen' : 0,
  35. 'empty' : 1,
  36. 'wall' : 2,
  37. 'floor' : 3,
  38. 'door' : 4,
  39. 'key' : 5,
  40. 'ball' : 6,
  41. 'box' : 7,
  42. 'goal' : 8,
  43. 'lava' : 9,
  44. 'agent' : 10,
  45. }
  46. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  47. # Map of state names to integers
  48. STATE_TO_IDX = {
  49. 'open' : 0,
  50. 'closed': 1,
  51. 'locked': 2,
  52. }
  53. # Map of agent direction indices to vectors
  54. DIR_TO_VEC = [
  55. # Pointing right (positive X)
  56. np.array((1, 0)),
  57. # Down (positive Y)
  58. np.array((0, 1)),
  59. # Pointing left (negative X)
  60. np.array((-1, 0)),
  61. # Up (negative Y)
  62. np.array((0, -1)),
  63. ]
  64. class WorldObj:
  65. """
  66. Base class for grid world objects
  67. """
  68. def __init__(self, type, color):
  69. assert type in OBJECT_TO_IDX, type
  70. assert color in COLOR_TO_IDX, color
  71. self.type = type
  72. self.color = color
  73. self.contains = None
  74. # Initial position of the object
  75. self.init_pos = None
  76. # Current position of the object
  77. self.cur_pos = None
  78. def can_overlap(self):
  79. """Can the agent overlap with this?"""
  80. return False
  81. def can_pickup(self):
  82. """Can the agent pick this up?"""
  83. return False
  84. def can_contain(self):
  85. """Can this contain another object?"""
  86. return False
  87. def see_behind(self):
  88. """Can the agent see behind this object?"""
  89. return True
  90. def toggle(self, env, pos):
  91. """Method to trigger/toggle an action this object performs"""
  92. return False
  93. def encode(self):
  94. """Encode the a description of this object as a 3-tuple of integers"""
  95. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  96. @staticmethod
  97. def decode(type_idx, color_idx, state):
  98. """Create an object from a 3-tuple state description"""
  99. obj_type = IDX_TO_OBJECT[type_idx]
  100. color = IDX_TO_COLOR[color_idx]
  101. if obj_type == 'empty' or obj_type == 'unseen':
  102. return None
  103. # State, 0: open, 1: closed, 2: locked
  104. is_open = state == 0
  105. is_locked = state == 2
  106. if obj_type == 'wall':
  107. v = Wall(color)
  108. elif obj_type == 'floor':
  109. v = Floor(color)
  110. elif obj_type == 'ball':
  111. v = Ball(color)
  112. elif obj_type == 'key':
  113. v = Key(color)
  114. elif obj_type == 'box':
  115. v = Box(color)
  116. elif obj_type == 'door':
  117. v = Door(color, is_open, is_locked)
  118. elif obj_type == 'goal':
  119. v = Goal()
  120. elif obj_type == 'lava':
  121. v = Lava()
  122. else:
  123. assert False, "unknown object type in decode '%s'" % obj_type
  124. return v
  125. def render(self, r):
  126. """Draw this object with the given renderer"""
  127. raise NotImplementedError
  128. class Goal(WorldObj):
  129. def __init__(self):
  130. super().__init__('goal', 'green')
  131. def can_overlap(self):
  132. return True
  133. def render(self, img):
  134. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  135. class Floor(WorldObj):
  136. """
  137. Colored floor tile the agent can walk over
  138. """
  139. def __init__(self, color='blue'):
  140. super().__init__('floor', color)
  141. def can_overlap(self):
  142. return True
  143. def render(self, img):
  144. # Give the floor a pale color
  145. color = COLORS[self.color] / 2
  146. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  147. class Lava(WorldObj):
  148. def __init__(self):
  149. super().__init__('lava', 'red')
  150. def can_overlap(self):
  151. return True
  152. def render(self, img):
  153. c = (255, 128, 0)
  154. # Background color
  155. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  156. # Little waves
  157. for i in range(3):
  158. ylo = 0.3 + 0.2 * i
  159. yhi = 0.4 + 0.2 * i
  160. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
  161. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
  162. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
  163. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
  164. class Wall(WorldObj):
  165. def __init__(self, color='grey'):
  166. super().__init__('wall', color)
  167. def see_behind(self):
  168. return False
  169. def render(self, img):
  170. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  171. class Door(WorldObj):
  172. def __init__(self, color, is_open=False, is_locked=False):
  173. super().__init__('door', color)
  174. self.is_open = is_open
  175. self.is_locked = is_locked
  176. def can_overlap(self):
  177. """The agent can only walk over this cell when the door is open"""
  178. return self.is_open
  179. def see_behind(self):
  180. return self.is_open
  181. def toggle(self, env, pos):
  182. # If the player has the right key to open the door
  183. if self.is_locked:
  184. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  185. self.is_locked = False
  186. self.is_open = True
  187. return True
  188. return False
  189. self.is_open = not self.is_open
  190. return True
  191. def encode(self):
  192. """Encode the a description of this object as a 3-tuple of integers"""
  193. # State, 0: open, 1: closed, 2: locked
  194. if self.is_open:
  195. state = 0
  196. elif self.is_locked:
  197. state = 2
  198. elif not self.is_open:
  199. state = 1
  200. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  201. def render(self, img):
  202. c = COLORS[self.color]
  203. if self.is_open:
  204. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  205. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
  206. return
  207. # Door frame and door
  208. if self.is_locked:
  209. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  210. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  211. # Draw key slot
  212. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  213. else:
  214. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  215. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
  216. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  217. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))
  218. # Draw door handle
  219. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  220. class Key(WorldObj):
  221. def __init__(self, color='blue'):
  222. super(Key, self).__init__('key', color)
  223. def can_pickup(self):
  224. return True
  225. def render(self, img):
  226. c = COLORS[self.color]
  227. # Vertical quad
  228. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  229. # Teeth
  230. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  231. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  232. # Ring
  233. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  234. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
  235. class Ball(WorldObj):
  236. def __init__(self, color='blue'):
  237. super(Ball, self).__init__('ball', color)
  238. def can_pickup(self):
  239. return True
  240. def render(self, img):
  241. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  242. class Box(WorldObj):
  243. def __init__(self, color, contains=None):
  244. super(Box, self).__init__('box', color)
  245. self.contains = contains
  246. def can_pickup(self):
  247. return True
  248. def render(self, img):
  249. c = COLORS[self.color]
  250. # Outline
  251. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  252. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))
  253. # Horizontal slit
  254. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  255. def toggle(self, env, pos):
  256. # Replace the box by its contents
  257. env.grid.set(*pos, self.contains)
  258. return True
  259. class Grid:
  260. """
  261. Represent a grid and operations on it
  262. """
  263. # Static cache of pre-renderer tiles
  264. tile_cache = {}
  265. def __init__(self, width, height):
  266. assert width >= 3
  267. assert height >= 3
  268. self.width = width
  269. self.height = height
  270. self.grid = [None] * width * height
  271. def __contains__(self, key):
  272. if isinstance(key, WorldObj):
  273. for e in self.grid:
  274. if e is key:
  275. return True
  276. elif isinstance(key, tuple):
  277. for e in self.grid:
  278. if e is None:
  279. continue
  280. if (e.color, e.type) == key:
  281. return True
  282. if key[0] is None and key[1] == e.type:
  283. return True
  284. return False
  285. def __eq__(self, other):
  286. grid1 = self.encode()
  287. grid2 = other.encode()
  288. return np.array_equal(grid2, grid1)
  289. def __ne__(self, other):
  290. return not self == other
  291. def copy(self):
  292. from copy import deepcopy
  293. return deepcopy(self)
  294. def set(self, i, j, v):
  295. assert i >= 0 and i < self.width
  296. assert j >= 0 and j < self.height
  297. self.grid[j * self.width + i] = v
  298. def get(self, i, j):
  299. assert i >= 0 and i < self.width
  300. assert j >= 0 and j < self.height
  301. return self.grid[j * self.width + i]
  302. def horz_wall(self, x, y, length=None, obj_type=Wall):
  303. if length is None:
  304. length = self.width - x
  305. for i in range(0, length):
  306. self.set(x + i, y, obj_type())
  307. def vert_wall(self, x, y, length=None, obj_type=Wall):
  308. if length is None:
  309. length = self.height - y
  310. for j in range(0, length):
  311. self.set(x, y + j, obj_type())
  312. def wall_rect(self, x, y, w, h):
  313. self.horz_wall(x, y, w)
  314. self.horz_wall(x, y+h-1, w)
  315. self.vert_wall(x, y, h)
  316. self.vert_wall(x+w-1, y, h)
  317. def rotate_left(self):
  318. """
  319. Rotate the grid to the left (counter-clockwise)
  320. """
  321. grid = Grid(self.height, self.width)
  322. for i in range(self.width):
  323. for j in range(self.height):
  324. v = self.get(i, j)
  325. grid.set(j, grid.height - 1 - i, v)
  326. return grid
  327. def slice(self, topX, topY, width, height):
  328. """
  329. Get a subset of the grid
  330. """
  331. grid = Grid(width, height)
  332. for j in range(0, height):
  333. for i in range(0, width):
  334. x = topX + i
  335. y = topY + j
  336. if x >= 0 and x < self.width and \
  337. y >= 0 and y < self.height:
  338. v = self.get(x, y)
  339. else:
  340. v = Wall()
  341. grid.set(i, j, v)
  342. return grid
  343. @classmethod
  344. def render_tile(
  345. cls,
  346. obj,
  347. agent_dir=None,
  348. highlight=False,
  349. tile_size=TILE_PIXELS,
  350. subdivs=3
  351. ):
  352. """
  353. Render a tile and cache the result
  354. """
  355. # Hash map lookup key for the cache
  356. key = (agent_dir, highlight, tile_size)
  357. key = obj.encode() + key if obj else key
  358. if key in cls.tile_cache:
  359. return cls.tile_cache[key]
  360. img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
  361. # Draw the grid lines (top and left edges)
  362. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  363. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  364. if obj != None:
  365. obj.render(img)
  366. # Overlay the agent on top
  367. if agent_dir is not None:
  368. tri_fn = point_in_triangle(
  369. (0.12, 0.19),
  370. (0.87, 0.50),
  371. (0.12, 0.81),
  372. )
  373. # Rotate the agent based on its direction
  374. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
  375. fill_coords(img, tri_fn, (255, 0, 0))
  376. # Highlight the cell if needed
  377. if highlight:
  378. highlight_img(img)
  379. # Downsample the image to perform supersampling/anti-aliasing
  380. img = downsample(img, subdivs)
  381. # Cache the rendered tile
  382. cls.tile_cache[key] = img
  383. return img
  384. def render(
  385. self,
  386. tile_size,
  387. agent_pos=None,
  388. agent_dir=None,
  389. highlight_mask=None
  390. ):
  391. """
  392. Render this grid at a given scale
  393. :param r: target renderer object
  394. :param tile_size: tile size in pixels
  395. """
  396. if highlight_mask is None:
  397. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  398. # Compute the total grid size
  399. width_px = self.width * tile_size
  400. height_px = self.height * tile_size
  401. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  402. # Render the grid
  403. for j in range(0, self.height):
  404. for i in range(0, self.width):
  405. cell = self.get(i, j)
  406. agent_here = np.array_equal(agent_pos, (i, j))
  407. tile_img = Grid.render_tile(
  408. cell,
  409. agent_dir=agent_dir if agent_here else None,
  410. highlight=highlight_mask[i, j],
  411. tile_size=tile_size
  412. )
  413. ymin = j * tile_size
  414. ymax = (j+1) * tile_size
  415. xmin = i * tile_size
  416. xmax = (i+1) * tile_size
  417. img[ymin:ymax, xmin:xmax, :] = tile_img
  418. return img
  419. def encode(self, vis_mask=None):
  420. """
  421. Produce a compact numpy encoding of the grid
  422. """
  423. if vis_mask is None:
  424. vis_mask = np.ones((self.width, self.height), dtype=bool)
  425. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  426. for i in range(self.width):
  427. for j in range(self.height):
  428. if vis_mask[i, j]:
  429. v = self.get(i, j)
  430. if v is None:
  431. array[i, j, 0] = OBJECT_TO_IDX['empty']
  432. array[i, j, 1] = 0
  433. array[i, j, 2] = 0
  434. else:
  435. array[i, j, :] = v.encode()
  436. return array
  437. @staticmethod
  438. def decode(array):
  439. """
  440. Decode an array grid encoding back into a grid
  441. """
  442. width, height, channels = array.shape
  443. assert channels == 3
  444. vis_mask = np.ones(shape=(width, height), dtype=bool)
  445. grid = Grid(width, height)
  446. for i in range(width):
  447. for j in range(height):
  448. type_idx, color_idx, state = array[i, j]
  449. v = WorldObj.decode(type_idx, color_idx, state)
  450. grid.set(i, j, v)
  451. vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
  452. return grid, vis_mask
  453. def process_vis(grid, agent_pos):
  454. mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)
  455. mask[agent_pos[0], agent_pos[1]] = True
  456. for j in reversed(range(0, grid.height)):
  457. for i in range(0, grid.width-1):
  458. if not mask[i, j]:
  459. continue
  460. cell = grid.get(i, j)
  461. if cell and not cell.see_behind():
  462. continue
  463. mask[i+1, j] = True
  464. if j > 0:
  465. mask[i+1, j-1] = True
  466. mask[i, j-1] = True
  467. for i in reversed(range(1, grid.width)):
  468. if not mask[i, j]:
  469. continue
  470. cell = grid.get(i, j)
  471. if cell and not cell.see_behind():
  472. continue
  473. mask[i-1, j] = True
  474. if j > 0:
  475. mask[i-1, j-1] = True
  476. mask[i, j-1] = True
  477. for j in range(0, grid.height):
  478. for i in range(0, grid.width):
  479. if not mask[i, j]:
  480. grid.set(i, j, None)
  481. return mask
  482. class StringGymSpace(gym.spaces.space.Space):
  483. """
  484. A gym space that represents a string of characters of bounded length
  485. """
  486. def __init__(self, min_length=0, max_length=1000):
  487. self.min_length = min_length
  488. self.max_length = max_length
  489. self.letters = string.ascii_letters + string.digits + ' .,!- '
  490. self._shape = ()
  491. self.dtype = np.dtype('U')
  492. def sample(self):
  493. length = np.random.randint(self.min_length, self.max_length)
  494. string = ''.join(np.random.choice(self.letters, size=length))
  495. return string
  496. def contains(self, x):
  497. return isinstance(x, str) and len(x) >= self.min_length and len(x) <= self.max_length
  498. def __repr__(self):
  499. return "StringGymSpace(min_length={}, max_length={})".format(self.min_length, self.max_length)
  500. class MiniGridEnv(gym.Env):
  501. """
  502. 2D grid world game environment
  503. """
  504. metadata = {
  505. 'render.modes': ['human', 'rgb_array'],
  506. 'video.frames_per_second' : 10
  507. }
  508. # Enumeration of possible actions
  509. class Actions(IntEnum):
  510. # Turn left, turn right, move forward
  511. left = 0
  512. right = 1
  513. forward = 2
  514. # Pick up an object
  515. pickup = 3
  516. # Drop an object
  517. drop = 4
  518. # Toggle/activate an object
  519. toggle = 5
  520. # Done completing task
  521. done = 6
  522. def __init__(
  523. self,
  524. grid_size=None,
  525. width=None,
  526. height=None,
  527. max_steps=100,
  528. see_through_walls=False,
  529. seed=1337,
  530. agent_view_size=7
  531. ):
  532. # Can't set both grid_size and width/height
  533. if grid_size:
  534. assert width == None and height == None
  535. width = grid_size
  536. height = grid_size
  537. # Action enumeration for this environment
  538. self.actions = MiniGridEnv.Actions
  539. # Actions are discrete integer values
  540. self.action_space = spaces.Discrete(len(self.actions))
  541. # Number of cells (width and height) in the agent view
  542. assert agent_view_size % 2 == 1
  543. assert agent_view_size >= 3
  544. self.agent_view_size = agent_view_size
  545. # Observations are dictionaries containing an
  546. # encoding of the grid and a textual 'mission' string
  547. self.observation_space = spaces.Box(
  548. low=0,
  549. high=255,
  550. shape=(self.agent_view_size, self.agent_view_size, 3),
  551. dtype='uint8'
  552. )
  553. self.observation_space = spaces.Dict({
  554. 'image': self.observation_space,
  555. 'direction': spaces.Discrete(4),
  556. 'mission': StringGymSpace(min_length=0, max_length=200),
  557. })
  558. # Range of possible rewards
  559. self.reward_range = (0, 1)
  560. # Window to use for human rendering mode
  561. self.window = None
  562. # Environment configuration
  563. self.width = width
  564. self.height = height
  565. self.max_steps = max_steps
  566. self.see_through_walls = see_through_walls
  567. # Current position and direction of the agent
  568. self.agent_pos = None
  569. self.agent_dir = None
  570. # Initialize the RNG
  571. self.seed(seed=seed)
  572. # Initialize the state
  573. self.reset()
  574. def reset(self):
  575. # Current position and direction of the agent
  576. self.agent_pos = None
  577. self.agent_dir = None
  578. # Generate a new random grid at the start of each episode
  579. # To keep the same grid for each episode, call env.seed() with
  580. # the same seed before calling env.reset()
  581. self._gen_grid(self.width, self.height)
  582. # These fields should be defined by _gen_grid
  583. assert self.agent_pos is not None
  584. assert self.agent_dir is not None
  585. # Check that the agent doesn't overlap with an object
  586. start_cell = self.grid.get(*self.agent_pos)
  587. assert start_cell is None or start_cell.can_overlap()
  588. # Item picked up, being carried, initially nothing
  589. self.carrying = None
  590. # Step count since episode start
  591. self.step_count = 0
  592. # Return first observation
  593. obs = self.gen_obs()
  594. return obs
  595. def seed(self, seed=1337):
  596. # Seed the random number generator
  597. self.np_random, _ = seeding.np_random(seed)
  598. return [seed]
  599. def hash(self, size=16):
  600. """Compute a hash that uniquely identifies the current state of the environment.
  601. :param size: Size of the hashing
  602. """
  603. sample_hash = hashlib.sha256()
  604. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  605. for item in to_encode:
  606. sample_hash.update(str(item).encode('utf8'))
  607. return sample_hash.hexdigest()[:size]
  608. @property
  609. def steps_remaining(self):
  610. return self.max_steps - self.step_count
  611. def __str__(self):
  612. """
  613. Produce a pretty string of the environment's grid along with the agent.
  614. A grid cell is represented by 2-character string, the first one for
  615. the object and the second one for the color.
  616. """
  617. # Map of object types to short string
  618. OBJECT_TO_STR = {
  619. 'wall' : 'W',
  620. 'floor' : 'F',
  621. 'door' : 'D',
  622. 'key' : 'K',
  623. 'ball' : 'A',
  624. 'box' : 'B',
  625. 'goal' : 'G',
  626. 'lava' : 'V',
  627. }
  628. # Short string for opened door
  629. OPENDED_DOOR_IDS = '_'
  630. # Map agent's direction to short string
  631. AGENT_DIR_TO_STR = {
  632. 0: '>',
  633. 1: 'V',
  634. 2: '<',
  635. 3: '^'
  636. }
  637. str = ''
  638. for j in range(self.grid.height):
  639. for i in range(self.grid.width):
  640. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  641. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  642. continue
  643. c = self.grid.get(i, j)
  644. if c == None:
  645. str += ' '
  646. continue
  647. if c.type == 'door':
  648. if c.is_open:
  649. str += '__'
  650. elif c.is_locked:
  651. str += 'L' + c.color[0].upper()
  652. else:
  653. str += 'D' + c.color[0].upper()
  654. continue
  655. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  656. if j < self.grid.height - 1:
  657. str += '\n'
  658. return str
  659. def _gen_grid(self, width, height):
  660. assert False, "_gen_grid needs to be implemented by each environment"
  661. def _reward(self):
  662. """
  663. Compute the reward to be given upon success
  664. """
  665. return 1 - 0.9 * (self.step_count / self.max_steps)
  666. def _rand_int(self, low, high):
  667. """
  668. Generate random integer in [low,high[
  669. """
  670. return self.np_random.integers(low, high)
  671. def _rand_float(self, low, high):
  672. """
  673. Generate random float in [low,high[
  674. """
  675. return self.np_random.uniform(low, high)
  676. def _rand_bool(self):
  677. """
  678. Generate random boolean value
  679. """
  680. return (self.np_random.integers(0, 2) == 0)
  681. def _rand_elem(self, iterable):
  682. """
  683. Pick a random element in a list
  684. """
  685. lst = list(iterable)
  686. idx = self._rand_int(0, len(lst))
  687. return lst[idx]
  688. def _rand_subset(self, iterable, num_elems):
  689. """
  690. Sample a random subset of distinct elements of a list
  691. """
  692. lst = list(iterable)
  693. assert num_elems <= len(lst)
  694. out = []
  695. while len(out) < num_elems:
  696. elem = self._rand_elem(lst)
  697. lst.remove(elem)
  698. out.append(elem)
  699. return out
  700. def _rand_color(self):
  701. """
  702. Generate a random color name (string)
  703. """
  704. return self._rand_elem(COLOR_NAMES)
  705. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  706. """
  707. Generate a random (x,y) position tuple
  708. """
  709. return (
  710. self.np_random.integers(xLow, xHigh),
  711. self.np_random.integers(yLow, yHigh)
  712. )
  713. def place_obj(self,
  714. obj,
  715. top=None,
  716. size=None,
  717. reject_fn=None,
  718. max_tries=math.inf
  719. ):
  720. """
  721. Place an object at an empty position in the grid
  722. :param top: top-left position of the rectangle where to place
  723. :param size: size of the rectangle where to place
  724. :param reject_fn: function to filter out potential positions
  725. """
  726. if top is None:
  727. top = (0, 0)
  728. else:
  729. top = (max(top[0], 0), max(top[1], 0))
  730. if size is None:
  731. size = (self.grid.width, self.grid.height)
  732. num_tries = 0
  733. while True:
  734. # This is to handle with rare cases where rejection sampling
  735. # gets stuck in an infinite loop
  736. if num_tries > max_tries:
  737. raise RecursionError('rejection sampling failed in place_obj')
  738. num_tries += 1
  739. pos = np.array((
  740. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  741. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  742. ))
  743. # Don't place the object on top of another object
  744. if self.grid.get(*pos) != None:
  745. continue
  746. # Don't place the object where the agent is
  747. if np.array_equal(pos, self.agent_pos):
  748. continue
  749. # Check if there is a filtering criterion
  750. if reject_fn and reject_fn(self, pos):
  751. continue
  752. break
  753. self.grid.set(*pos, obj)
  754. if obj is not None:
  755. obj.init_pos = pos
  756. obj.cur_pos = pos
  757. return pos
  758. def put_obj(self, obj, i, j):
  759. """
  760. Put an object at a specific position in the grid
  761. """
  762. self.grid.set(i, j, obj)
  763. obj.init_pos = (i, j)
  764. obj.cur_pos = (i, j)
  765. def place_agent(
  766. self,
  767. top=None,
  768. size=None,
  769. rand_dir=True,
  770. max_tries=math.inf
  771. ):
  772. """
  773. Set the agent's starting point at an empty position in the grid
  774. """
  775. self.agent_pos = None
  776. pos = self.place_obj(None, top, size, max_tries=max_tries)
  777. self.agent_pos = pos
  778. if rand_dir:
  779. self.agent_dir = self._rand_int(0, 4)
  780. return pos
  781. @property
  782. def dir_vec(self):
  783. """
  784. Get the direction vector for the agent, pointing in the direction
  785. of forward movement.
  786. """
  787. assert self.agent_dir >= 0 and self.agent_dir < 4
  788. return DIR_TO_VEC[self.agent_dir]
  789. @property
  790. def right_vec(self):
  791. """
  792. Get the vector pointing to the right of the agent.
  793. """
  794. dx, dy = self.dir_vec
  795. return np.array((-dy, dx))
  796. @property
  797. def front_pos(self):
  798. """
  799. Get the position of the cell that is right in front of the agent
  800. """
  801. return self.agent_pos + self.dir_vec
  802. def get_view_coords(self, i, j):
  803. """
  804. Translate and rotate absolute grid coordinates (i, j) into the
  805. agent's partially observable view (sub-grid). Note that the resulting
  806. coordinates may be negative or outside of the agent's view size.
  807. """
  808. ax, ay = self.agent_pos
  809. dx, dy = self.dir_vec
  810. rx, ry = self.right_vec
  811. # Compute the absolute coordinates of the top-left view corner
  812. sz = self.agent_view_size
  813. hs = self.agent_view_size // 2
  814. tx = ax + (dx * (sz-1)) - (rx * hs)
  815. ty = ay + (dy * (sz-1)) - (ry * hs)
  816. lx = i - tx
  817. ly = j - ty
  818. # Project the coordinates of the object relative to the top-left
  819. # corner onto the agent's own coordinate system
  820. vx = (rx*lx + ry*ly)
  821. vy = -(dx*lx + dy*ly)
  822. return vx, vy
  823. def get_view_exts(self):
  824. """
  825. Get the extents of the square set of tiles visible to the agent
  826. Note: the bottom extent indices are not included in the set
  827. """
  828. # Facing right
  829. if self.agent_dir == 0:
  830. topX = self.agent_pos[0]
  831. topY = self.agent_pos[1] - self.agent_view_size // 2
  832. # Facing down
  833. elif self.agent_dir == 1:
  834. topX = self.agent_pos[0] - self.agent_view_size // 2
  835. topY = self.agent_pos[1]
  836. # Facing left
  837. elif self.agent_dir == 2:
  838. topX = self.agent_pos[0] - self.agent_view_size + 1
  839. topY = self.agent_pos[1] - self.agent_view_size // 2
  840. # Facing up
  841. elif self.agent_dir == 3:
  842. topX = self.agent_pos[0] - self.agent_view_size // 2
  843. topY = self.agent_pos[1] - self.agent_view_size + 1
  844. else:
  845. assert False, "invalid agent direction"
  846. botX = topX + self.agent_view_size
  847. botY = topY + self.agent_view_size
  848. return (topX, topY, botX, botY)
  849. def relative_coords(self, x, y):
  850. """
  851. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  852. """
  853. vx, vy = self.get_view_coords(x, y)
  854. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  855. return None
  856. return vx, vy
  857. def in_view(self, x, y):
  858. """
  859. check if a grid position is visible to the agent
  860. """
  861. return self.relative_coords(x, y) is not None
  862. def agent_sees(self, x, y):
  863. """
  864. Check if a non-empty grid position is visible to the agent
  865. """
  866. coordinates = self.relative_coords(x, y)
  867. if coordinates is None:
  868. return False
  869. vx, vy = coordinates
  870. obs = self.gen_obs()
  871. obs_grid, _ = Grid.decode(obs['image'])
  872. obs_cell = obs_grid.get(vx, vy)
  873. world_cell = self.grid.get(x, y)
  874. return obs_cell is not None and obs_cell.type == world_cell.type
  875. def step(self, action):
  876. self.step_count += 1
  877. reward = 0
  878. done = False
  879. # Get the position in front of the agent
  880. fwd_pos = self.front_pos
  881. # Get the contents of the cell in front of the agent
  882. fwd_cell = self.grid.get(*fwd_pos)
  883. # Rotate left
  884. if action == self.actions.left:
  885. self.agent_dir -= 1
  886. if self.agent_dir < 0:
  887. self.agent_dir += 4
  888. # Rotate right
  889. elif action == self.actions.right:
  890. self.agent_dir = (self.agent_dir + 1) % 4
  891. # Move forward
  892. elif action == self.actions.forward:
  893. if fwd_cell == None or fwd_cell.can_overlap():
  894. self.agent_pos = fwd_pos
  895. if fwd_cell != None and fwd_cell.type == 'goal':
  896. done = True
  897. reward = self._reward()
  898. if fwd_cell != None and fwd_cell.type == 'lava':
  899. done = True
  900. # Pick up an object
  901. elif action == self.actions.pickup:
  902. if fwd_cell and fwd_cell.can_pickup():
  903. if self.carrying is None:
  904. self.carrying = fwd_cell
  905. self.carrying.cur_pos = np.array([-1, -1])
  906. self.grid.set(*fwd_pos, None)
  907. # Drop an object
  908. elif action == self.actions.drop:
  909. if not fwd_cell and self.carrying:
  910. self.grid.set(*fwd_pos, self.carrying)
  911. self.carrying.cur_pos = fwd_pos
  912. self.carrying = None
  913. # Toggle/activate an object
  914. elif action == self.actions.toggle:
  915. if fwd_cell:
  916. fwd_cell.toggle(self, fwd_pos)
  917. # Done action (not used by default)
  918. elif action == self.actions.done:
  919. pass
  920. else:
  921. assert False, "unknown action"
  922. if self.step_count >= self.max_steps:
  923. done = True
  924. obs = self.gen_obs()
  925. return obs, reward, done, {}
  926. def gen_obs_grid(self):
  927. """
  928. Generate the sub-grid observed by the agent.
  929. This method also outputs a visibility mask telling us which grid
  930. cells the agent can actually see.
  931. """
  932. topX, topY, botX, botY = self.get_view_exts()
  933. grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
  934. for i in range(self.agent_dir + 1):
  935. grid = grid.rotate_left()
  936. # Process occluders and visibility
  937. # Note that this incurs some performance cost
  938. if not self.see_through_walls:
  939. vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
  940. else:
  941. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
  942. # Make it so the agent sees what it's carrying
  943. # We do this by placing the carried object at the agent's position
  944. # in the agent's partially observable view
  945. agent_pos = grid.width // 2, grid.height - 1
  946. if self.carrying:
  947. grid.set(*agent_pos, self.carrying)
  948. else:
  949. grid.set(*agent_pos, None)
  950. return grid, vis_mask
  951. def gen_obs(self):
  952. """
  953. Generate the agent's view (partially observable, low-resolution encoding)
  954. """
  955. grid, vis_mask = self.gen_obs_grid()
  956. # Encode the partially observable view into a numpy array
  957. image = grid.encode(vis_mask)
  958. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  959. # Observations are dictionaries containing:
  960. # - an image (partially observable view of the environment)
  961. # - the agent's direction/orientation (acting as a compass)
  962. # - a textual mission string (instructions for the agent)
  963. obs = {
  964. 'image': image,
  965. 'direction': self.agent_dir,
  966. 'mission': self.mission
  967. }
  968. return obs
  969. def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
  970. """
  971. Render an agent observation for visualization
  972. """
  973. grid, vis_mask = Grid.decode(obs)
  974. # Render the whole grid
  975. img = grid.render(
  976. tile_size,
  977. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  978. agent_dir=3,
  979. highlight_mask=vis_mask
  980. )
  981. return img
  982. def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
  983. """
  984. Render the whole-grid human view
  985. """
  986. if close:
  987. if self.window:
  988. self.window.close()
  989. return
  990. if mode == 'human' and not self.window:
  991. import gym_minigrid.window
  992. self.window = gym_minigrid.window.Window('gym_minigrid')
  993. self.window.show(block=False)
  994. # Compute which cells are visible to the agent
  995. _, vis_mask = self.gen_obs_grid()
  996. # Compute the world coordinates of the bottom-left corner
  997. # of the agent's view area
  998. f_vec = self.dir_vec
  999. r_vec = self.right_vec
  1000. top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
  1001. # Mask of which cells to highlight
  1002. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  1003. # For each cell in the visibility mask
  1004. for vis_j in range(0, self.agent_view_size):
  1005. for vis_i in range(0, self.agent_view_size):
  1006. # If this cell is not visible, don't highlight it
  1007. if not vis_mask[vis_i, vis_j]:
  1008. continue
  1009. # Compute the world coordinates of this cell
  1010. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  1011. if abs_i < 0 or abs_i >= self.width:
  1012. continue
  1013. if abs_j < 0 or abs_j >= self.height:
  1014. continue
  1015. # Mark this cell to be highlighted
  1016. highlight_mask[abs_i, abs_j] = True
  1017. # Render the whole grid
  1018. img = self.grid.render(
  1019. tile_size,
  1020. self.agent_pos,
  1021. self.agent_dir,
  1022. highlight_mask=highlight_mask if highlight else None
  1023. )
  1024. if mode == 'human':
  1025. self.window.set_caption(self.mission)
  1026. self.window.show_img(img)
  1027. return img
  1028. def close(self):
  1029. if self.window:
  1030. self.window.close()
  1031. return