minigrid.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287
  1. import math
  2. import gym
  3. from enum import IntEnum
  4. import numpy as np
  5. from gym import error, spaces, utils
  6. from gym.utils import seeding
  7. from .rendering import *
  8. # Size in pixels of a tile in the full-scale human view
  9. TILE_PIXELS = 32
  10. # Map of color names to RGB values
  11. COLORS = {
  12. 'red' : np.array([255, 0, 0]),
  13. 'green' : np.array([0, 255, 0]),
  14. 'blue' : np.array([0, 0, 255]),
  15. 'purple': np.array([112, 39, 195]),
  16. 'yellow': np.array([255, 255, 0]),
  17. 'grey' : np.array([100, 100, 100])
  18. }
  19. COLOR_NAMES = sorted(list(COLORS.keys()))
  20. # Used to map colors to integers
  21. COLOR_TO_IDX = {
  22. 'red' : 0,
  23. 'green' : 1,
  24. 'blue' : 2,
  25. 'purple': 3,
  26. 'yellow': 4,
  27. 'grey' : 5
  28. }
  29. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  30. # Map of object type to integers
  31. OBJECT_TO_IDX = {
  32. 'unseen' : 0,
  33. 'empty' : 1,
  34. 'wall' : 2,
  35. 'floor' : 3,
  36. 'door' : 4,
  37. 'key' : 5,
  38. 'ball' : 6,
  39. 'box' : 7,
  40. 'goal' : 8,
  41. 'lava' : 9,
  42. 'agent' : 10,
  43. }
  44. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  45. # Map of state names to integers
  46. STATE_TO_IDX = {
  47. 'open' : 0,
  48. 'closed': 1,
  49. 'locked': 2,
  50. }
  51. # Map of agent direction indices to vectors
  52. DIR_TO_VEC = [
  53. # Pointing right (positive X)
  54. np.array((1, 0)),
  55. # Down (positive Y)
  56. np.array((0, 1)),
  57. # Pointing left (negative X)
  58. np.array((-1, 0)),
  59. # Up (negative Y)
  60. np.array((0, -1)),
  61. ]
  62. class WorldObj:
  63. """
  64. Base class for grid world objects
  65. """
  66. def __init__(self, type, color):
  67. assert type in OBJECT_TO_IDX, type
  68. assert color in COLOR_TO_IDX, color
  69. self.type = type
  70. self.color = color
  71. self.contains = None
  72. # Initial position of the object
  73. self.init_pos = None
  74. # Current position of the object
  75. self.cur_pos = None
  76. def can_overlap(self):
  77. """Can the agent overlap with this?"""
  78. return False
  79. def can_pickup(self):
  80. """Can the agent pick this up?"""
  81. return False
  82. def can_contain(self):
  83. """Can this contain another object?"""
  84. return False
  85. def see_behind(self):
  86. """Can the agent see behind this object?"""
  87. return True
  88. def toggle(self, env, pos):
  89. """Method to trigger/toggle an action this object performs"""
  90. return False
  91. def encode(self):
  92. """Encode the a description of this object as a 3-tuple of integers"""
  93. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  94. @staticmethod
  95. def decode(type_idx, color_idx, state):
  96. """Create an object from a 3-tuple state description"""
  97. obj_type = IDX_TO_OBJECT[type_idx]
  98. color = IDX_TO_COLOR[color_idx]
  99. if obj_type == 'empty' or obj_type == 'unseen':
  100. return None
  101. # State, 0: open, 1: closed, 2: locked
  102. is_open = state == 0
  103. is_locked = state == 2
  104. if obj_type == 'wall':
  105. v = Wall(color)
  106. elif obj_type == 'floor':
  107. v = Floor(color)
  108. elif obj_type == 'ball':
  109. v = Ball(color)
  110. elif obj_type == 'key':
  111. v = Key(color)
  112. elif obj_type == 'box':
  113. v = Box(color)
  114. elif obj_type == 'door':
  115. v = Door(color, is_open, is_locked)
  116. elif obj_type == 'goal':
  117. v = Goal()
  118. elif obj_type == 'lava':
  119. v = Lava()
  120. else:
  121. assert False, "unknown object type in decode '%s'" % objType
  122. return v
  123. def render(self, r):
  124. """Draw this object with the given renderer"""
  125. raise NotImplementedError
  126. class Goal(WorldObj):
  127. def __init__(self):
  128. super().__init__('goal', 'green')
  129. def can_overlap(self):
  130. return True
  131. def render(self, img):
  132. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  133. class Floor(WorldObj):
  134. """
  135. Colored floor tile the agent can walk over
  136. """
  137. def __init__(self, color='blue'):
  138. super().__init__('floor', color)
  139. def can_overlap(self):
  140. return True
  141. def render(self, r):
  142. # Give the floor a pale color
  143. c = COLORS[self.color]
  144. r.setLineColor(100, 100, 100, 0)
  145. r.setColor(*c/2)
  146. r.drawPolygon([
  147. (1 , TILE_PIXELS),
  148. (TILE_PIXELS, TILE_PIXELS),
  149. (TILE_PIXELS, 1),
  150. (1 , 1)
  151. ])
  152. class Lava(WorldObj):
  153. def __init__(self):
  154. super().__init__('lava', 'red')
  155. def can_overlap(self):
  156. return True
  157. def render(self, img):
  158. c = (255, 128, 0)
  159. # Background color
  160. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  161. # Little waves
  162. for i in range(3):
  163. ylo = 0.3 + 0.2 * i
  164. yhi = 0.4 + 0.2 * i
  165. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
  166. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
  167. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
  168. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
  169. class Wall(WorldObj):
  170. def __init__(self, color='grey'):
  171. super().__init__('wall', color)
  172. def see_behind(self):
  173. return False
  174. def render(self, img):
  175. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  176. class Door(WorldObj):
  177. def __init__(self, color, is_open=False, is_locked=False):
  178. super().__init__('door', color)
  179. self.is_open = is_open
  180. self.is_locked = is_locked
  181. def can_overlap(self):
  182. """The agent can only walk over this cell when the door is open"""
  183. return self.is_open
  184. def see_behind(self):
  185. return self.is_open
  186. def toggle(self, env, pos):
  187. # If the player has the right key to open the door
  188. if self.is_locked:
  189. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  190. self.is_locked = False
  191. self.is_open = True
  192. return True
  193. return False
  194. self.is_open = not self.is_open
  195. return True
  196. def encode(self):
  197. """Encode the a description of this object as a 3-tuple of integers"""
  198. # State, 0: open, 1: closed, 2: locked
  199. if self.is_open:
  200. state = 0
  201. elif self.is_locked:
  202. state = 2
  203. elif not self.is_open:
  204. state = 1
  205. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  206. def render(self, img):
  207. c = COLORS[self.color]
  208. if self.is_open:
  209. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  210. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
  211. return
  212. # Door frame and door
  213. if self.is_locked:
  214. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  215. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  216. # Draw key slot
  217. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  218. else:
  219. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  220. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
  221. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  222. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))
  223. # Draw door handle
  224. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  225. class Key(WorldObj):
  226. def __init__(self, color='blue'):
  227. super(Key, self).__init__('key', color)
  228. def can_pickup(self):
  229. return True
  230. def render(self, img):
  231. c = COLORS[self.color]
  232. # Vertical quad
  233. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  234. # Teeth
  235. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  236. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  237. # Ring
  238. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  239. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
  240. class Ball(WorldObj):
  241. def __init__(self, color='blue'):
  242. super(Ball, self).__init__('ball', color)
  243. def can_pickup(self):
  244. return True
  245. def render(self, img):
  246. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  247. class Box(WorldObj):
  248. def __init__(self, color, contains=None):
  249. super(Box, self).__init__('box', color)
  250. self.contains = contains
  251. def can_pickup(self):
  252. return True
  253. def render(self, img):
  254. c = COLORS[self.color]
  255. # Outline
  256. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  257. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))
  258. # Horizontal slit
  259. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  260. def toggle(self, env, pos):
  261. # Replace the box by its contents
  262. env.grid.set(*pos, self.contains)
  263. return True
  264. class Grid:
  265. """
  266. Represent a grid and operations on it
  267. """
  268. # Static cache of pre-renderer tiles
  269. tile_cache = {}
  270. def __init__(self, width, height):
  271. assert width >= 3
  272. assert height >= 3
  273. self.width = width
  274. self.height = height
  275. self.grid = [None] * width * height
  276. def __contains__(self, key):
  277. if isinstance(key, WorldObj):
  278. for e in self.grid:
  279. if e is key:
  280. return True
  281. elif isinstance(key, tuple):
  282. for e in self.grid:
  283. if e is None:
  284. continue
  285. if (e.color, e.type) == key:
  286. return True
  287. if key[0] is None and key[1] == e.type:
  288. return True
  289. return False
  290. def __eq__(self, other):
  291. grid1 = self.encode()
  292. grid2 = other.encode()
  293. return np.array_equal(grid2, grid1)
  294. def __ne__(self, other):
  295. return not self == other
  296. def copy(self):
  297. from copy import deepcopy
  298. return deepcopy(self)
  299. def set(self, i, j, v):
  300. assert i >= 0 and i < self.width
  301. assert j >= 0 and j < self.height
  302. self.grid[j * self.width + i] = v
  303. def get(self, i, j):
  304. assert i >= 0 and i < self.width
  305. assert j >= 0 and j < self.height
  306. return self.grid[j * self.width + i]
  307. def horz_wall(self, x, y, length=None, obj_type=Wall):
  308. if length is None:
  309. length = self.width - x
  310. for i in range(0, length):
  311. self.set(x + i, y, obj_type())
  312. def vert_wall(self, x, y, length=None, obj_type=Wall):
  313. if length is None:
  314. length = self.height - y
  315. for j in range(0, length):
  316. self.set(x, y + j, obj_type())
  317. def wall_rect(self, x, y, w, h):
  318. self.horz_wall(x, y, w)
  319. self.horz_wall(x, y+h-1, w)
  320. self.vert_wall(x, y, h)
  321. self.vert_wall(x+w-1, y, h)
  322. def rotate_left(self):
  323. """
  324. Rotate the grid to the left (counter-clockwise)
  325. """
  326. grid = Grid(self.height, self.width)
  327. for i in range(self.width):
  328. for j in range(self.height):
  329. v = self.get(i, j)
  330. grid.set(j, grid.height - 1 - i, v)
  331. return grid
  332. def slice(self, topX, topY, width, height):
  333. """
  334. Get a subset of the grid
  335. """
  336. grid = Grid(width, height)
  337. for j in range(0, height):
  338. for i in range(0, width):
  339. x = topX + i
  340. y = topY + j
  341. if x >= 0 and x < self.width and \
  342. y >= 0 and y < self.height:
  343. v = self.get(x, y)
  344. else:
  345. v = Wall()
  346. grid.set(i, j, v)
  347. return grid
  348. @classmethod
  349. def render_tile(
  350. cls,
  351. obj,
  352. agent_dir=None,
  353. highlight=False,
  354. tile_size=TILE_PIXELS,
  355. subdivs=3
  356. ):
  357. """
  358. Render a tile and cache the result
  359. """
  360. # Hash map lookup key for the cache
  361. key = (agent_dir, highlight, tile_size)
  362. key = obj.encode() + key if obj else key
  363. if key in cls.tile_cache:
  364. return cls.tile_cache[key]
  365. img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
  366. # Draw the grid lines (top and left edges)
  367. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  368. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  369. if obj != None:
  370. obj.render(img)
  371. # Overlay the agent on top
  372. if agent_dir is not None:
  373. tri_fn = point_in_triangle(
  374. (0.12, 0.19),
  375. (0.87, 0.50),
  376. (0.12, 0.81),
  377. )
  378. # Rotate the agent based on its direction
  379. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
  380. fill_coords(img, tri_fn, (255, 0, 0))
  381. # Highlight the cell if needed
  382. if highlight:
  383. highlight_img(img)
  384. # Downsample the image to perform supersampling/anti-aliasing
  385. img = downsample(img, subdivs)
  386. # Cache the rendered tile
  387. cls.tile_cache[key] = img
  388. return img
  389. def render(
  390. self,
  391. tile_size,
  392. agent_pos=None,
  393. agent_dir=None,
  394. highlight_mask=None
  395. ):
  396. """
  397. Render this grid at a given scale
  398. :param r: target renderer object
  399. :param tile_size: tile size in pixels
  400. """
  401. if highlight_mask is None:
  402. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
  403. # Compute the total grid size
  404. width_px = self.width * tile_size
  405. height_px = self.height * tile_size
  406. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  407. # Render the grid
  408. for j in range(0, self.height):
  409. for i in range(0, self.width):
  410. cell = self.get(i, j)
  411. agent_here = np.array_equal(agent_pos, (i, j))
  412. tile_img = Grid.render_tile(
  413. cell,
  414. agent_dir=agent_dir if agent_here else None,
  415. highlight=highlight_mask[i, j],
  416. tile_size=tile_size
  417. )
  418. ymin = j * tile_size
  419. ymax = (j+1) * tile_size
  420. xmin = i * tile_size
  421. xmax = (i+1) * tile_size
  422. img[ymin:ymax, xmin:xmax, :] = tile_img
  423. return img
  424. def encode(self, vis_mask=None):
  425. """
  426. Produce a compact numpy encoding of the grid
  427. """
  428. if vis_mask is None:
  429. vis_mask = np.ones((self.width, self.height), dtype=bool)
  430. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  431. for i in range(self.width):
  432. for j in range(self.height):
  433. if vis_mask[i, j]:
  434. v = self.get(i, j)
  435. if v is None:
  436. array[i, j, 0] = OBJECT_TO_IDX['empty']
  437. array[i, j, 1] = 0
  438. array[i, j, 2] = 0
  439. else:
  440. array[i, j, :] = v.encode()
  441. return array
  442. @staticmethod
  443. def decode(array):
  444. """
  445. Decode an array grid encoding back into a grid
  446. """
  447. width, height, channels = array.shape
  448. assert channels == 3
  449. vis_mask = np.ones(shape=(width, height), dtype=np.bool)
  450. grid = Grid(width, height)
  451. for i in range(width):
  452. for j in range(height):
  453. type_idx, color_idx, state = array[i, j]
  454. v = WorldObj.decode(type_idx, color_idx, state)
  455. grid.set(i, j, v)
  456. vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
  457. return grid, vis_mask
  458. def process_vis(grid, agent_pos):
  459. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  460. mask[agent_pos[0], agent_pos[1]] = True
  461. for j in reversed(range(0, grid.height)):
  462. for i in range(0, grid.width-1):
  463. if not mask[i, j]:
  464. continue
  465. cell = grid.get(i, j)
  466. if cell and not cell.see_behind():
  467. continue
  468. mask[i+1, j] = True
  469. if j > 0:
  470. mask[i+1, j-1] = True
  471. mask[i, j-1] = True
  472. for i in reversed(range(1, grid.width)):
  473. if not mask[i, j]:
  474. continue
  475. cell = grid.get(i, j)
  476. if cell and not cell.see_behind():
  477. continue
  478. mask[i-1, j] = True
  479. if j > 0:
  480. mask[i-1, j-1] = True
  481. mask[i, j-1] = True
  482. for j in range(0, grid.height):
  483. for i in range(0, grid.width):
  484. if not mask[i, j]:
  485. grid.set(i, j, None)
  486. return mask
  487. class MiniGridEnv(gym.Env):
  488. """
  489. 2D grid world game environment
  490. """
  491. metadata = {
  492. 'render.modes': ['human', 'rgb_array'],
  493. 'video.frames_per_second' : 10
  494. }
    # Enumeration of possible actions
    class Actions(IntEnum):
        """
        Integer action codes understood by step(); the number of members
        determines the size of the discrete action space.
        """
        # Turn left, turn right, move forward
        left = 0
        right = 1
        forward = 2

        # Pick up an object
        pickup = 3
        # Drop an object
        drop = 4
        # Toggle/activate an object
        toggle = 5

        # Done completing task
        done = 6
  509. def __init__(
  510. self,
  511. grid_size=None,
  512. width=None,
  513. height=None,
  514. max_steps=100,
  515. see_through_walls=False,
  516. seed=1337,
  517. agent_view_size=7
  518. ):
  519. # Can't set both grid_size and width/height
  520. if grid_size:
  521. assert width == None and height == None
  522. width = grid_size
  523. height = grid_size
  524. # Action enumeration for this environment
  525. self.actions = MiniGridEnv.Actions
  526. # Actions are discrete integer values
  527. self.action_space = spaces.Discrete(len(self.actions))
  528. # Number of cells (width and height) in the agent view
  529. self.agent_view_size = agent_view_size
  530. # Observations are dictionaries containing an
  531. # encoding of the grid and a textual 'mission' string
  532. self.observation_space = spaces.Box(
  533. low=0,
  534. high=255,
  535. shape=(self.agent_view_size, self.agent_view_size, 3),
  536. dtype='uint8'
  537. )
  538. self.observation_space = spaces.Dict({
  539. 'image': self.observation_space
  540. })
  541. # Range of possible rewards
  542. self.reward_range = (0, 1)
  543. # Window to use for human rendering mode
  544. self.window = None
  545. # Environment configuration
  546. self.width = width
  547. self.height = height
  548. self.max_steps = max_steps
  549. self.see_through_walls = see_through_walls
  550. # Current position and direction of the agent
  551. self.agent_pos = None
  552. self.agent_dir = None
  553. # Initialize the RNG
  554. self.seed(seed=seed)
  555. # Initialize the state
  556. self.reset()
  557. def reset(self):
  558. # Current position and direction of the agent
  559. self.agent_pos = None
  560. self.agent_dir = None
  561. # Generate a new random grid at the start of each episode
  562. # To keep the same grid for each episode, call env.seed() with
  563. # the same seed before calling env.reset()
  564. self._gen_grid(self.width, self.height)
  565. # These fields should be defined by _gen_grid
  566. assert self.agent_pos is not None
  567. assert self.agent_dir is not None
  568. # Check that the agent doesn't overlap with an object
  569. start_cell = self.grid.get(*self.agent_pos)
  570. assert start_cell is None or start_cell.can_overlap()
  571. # Item picked up, being carried, initially nothing
  572. self.carrying = None
  573. # Step count since episode start
  574. self.step_count = 0
  575. # Return first observation
  576. obs = self.gen_obs()
  577. return obs
  578. def seed(self, seed=1337):
  579. # Seed the random number generator
  580. self.np_random, _ = seeding.np_random(seed)
  581. return [seed]
  582. @property
  583. def steps_remaining(self):
  584. return self.max_steps - self.step_count
  585. def __str__(self):
  586. """
  587. Produce a pretty string of the environment's grid along with the agent.
  588. A grid cell is represented by 2-character string, the first one for
  589. the object and the second one for the color.
  590. """
  591. # Map of object types to short string
  592. OBJECT_TO_STR = {
  593. 'wall' : 'W',
  594. 'floor' : 'F',
  595. 'door' : 'D',
  596. 'key' : 'K',
  597. 'ball' : 'A',
  598. 'box' : 'B',
  599. 'goal' : 'G',
  600. 'lava' : 'V',
  601. }
  602. # Short string for opened door
  603. OPENDED_DOOR_IDS = '_'
  604. # Map agent's direction to short string
  605. AGENT_DIR_TO_STR = {
  606. 0: '>',
  607. 1: 'V',
  608. 2: '<',
  609. 3: '^'
  610. }
  611. str = ''
  612. for j in range(self.grid.height):
  613. for i in range(self.grid.width):
  614. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  615. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  616. continue
  617. c = self.grid.get(i, j)
  618. if c == None:
  619. str += ' '
  620. continue
  621. if c.type == 'door':
  622. if c.is_open:
  623. str += '__'
  624. elif c.is_locked:
  625. str += 'L' + c.color[0].upper()
  626. else:
  627. str += 'D' + c.color[0].upper()
  628. continue
  629. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  630. if j < self.grid.height - 1:
  631. str += '\n'
  632. return str
  633. def _gen_grid(self, width, height):
  634. assert False, "_gen_grid needs to be implemented by each environment"
  635. def _reward(self):
  636. """
  637. Compute the reward to be given upon success
  638. """
  639. return 1 - 0.9 * (self.step_count / self.max_steps)
  640. def _rand_int(self, low, high):
  641. """
  642. Generate random integer in [low,high[
  643. """
  644. return self.np_random.randint(low, high)
  645. def _rand_float(self, low, high):
  646. """
  647. Generate random float in [low,high[
  648. """
  649. return self.np_random.uniform(low, high)
  650. def _rand_bool(self):
  651. """
  652. Generate random boolean value
  653. """
  654. return (self.np_random.randint(0, 2) == 0)
  655. def _rand_elem(self, iterable):
  656. """
  657. Pick a random element in a list
  658. """
  659. lst = list(iterable)
  660. idx = self._rand_int(0, len(lst))
  661. return lst[idx]
  662. def _rand_subset(self, iterable, num_elems):
  663. """
  664. Sample a random subset of distinct elements of a list
  665. """
  666. lst = list(iterable)
  667. assert num_elems <= len(lst)
  668. out = []
  669. while len(out) < num_elems:
  670. elem = self._rand_elem(lst)
  671. lst.remove(elem)
  672. out.append(elem)
  673. return out
  674. def _rand_color(self):
  675. """
  676. Generate a random color name (string)
  677. """
  678. return self._rand_elem(COLOR_NAMES)
  679. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  680. """
  681. Generate a random (x,y) position tuple
  682. """
  683. return (
  684. self.np_random.randint(xLow, xHigh),
  685. self.np_random.randint(yLow, yHigh)
  686. )
  687. def place_obj(self,
  688. obj,
  689. top=None,
  690. size=None,
  691. reject_fn=None,
  692. max_tries=math.inf
  693. ):
  694. """
  695. Place an object at an empty position in the grid
  696. :param top: top-left position of the rectangle where to place
  697. :param size: size of the rectangle where to place
  698. :param reject_fn: function to filter out potential positions
  699. """
  700. if top is None:
  701. top = (0, 0)
  702. else:
  703. top = (max(top[0], 0), max(top[1], 0))
  704. if size is None:
  705. size = (self.grid.width, self.grid.height)
  706. num_tries = 0
  707. while True:
  708. # This is to handle with rare cases where rejection sampling
  709. # gets stuck in an infinite loop
  710. if num_tries > max_tries:
  711. raise RecursionError('rejection sampling failed in place_obj')
  712. num_tries += 1
  713. pos = np.array((
  714. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  715. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  716. ))
  717. # Don't place the object on top of another object
  718. if self.grid.get(*pos) != None:
  719. continue
  720. # Don't place the object where the agent is
  721. if np.array_equal(pos, self.agent_pos):
  722. continue
  723. # Check if there is a filtering criterion
  724. if reject_fn and reject_fn(self, pos):
  725. continue
  726. break
  727. self.grid.set(*pos, obj)
  728. if obj is not None:
  729. obj.init_pos = pos
  730. obj.cur_pos = pos
  731. return pos
  732. def put_obj(self, obj, i, j):
  733. """
  734. Put an object at a specific position in the grid
  735. """
  736. self.grid.set(i, j, obj)
  737. obj.init_pos = (i, j)
  738. obj.cur_pos = (i, j)
  739. def place_agent(
  740. self,
  741. top=None,
  742. size=None,
  743. rand_dir=True,
  744. max_tries=math.inf
  745. ):
  746. """
  747. Set the agent's starting point at an empty position in the grid
  748. """
  749. self.agent_pos = None
  750. pos = self.place_obj(None, top, size, max_tries=max_tries)
  751. self.agent_pos = pos
  752. if rand_dir:
  753. self.agent_dir = self._rand_int(0, 4)
  754. return pos
  755. @property
  756. def dir_vec(self):
  757. """
  758. Get the direction vector for the agent, pointing in the direction
  759. of forward movement.
  760. """
  761. assert self.agent_dir >= 0 and self.agent_dir < 4
  762. return DIR_TO_VEC[self.agent_dir]
  763. @property
  764. def right_vec(self):
  765. """
  766. Get the vector pointing to the right of the agent.
  767. """
  768. dx, dy = self.dir_vec
  769. return np.array((-dy, dx))
  770. @property
  771. def front_pos(self):
  772. """
  773. Get the position of the cell that is right in front of the agent
  774. """
  775. return self.agent_pos + self.dir_vec
  776. def get_view_coords(self, i, j):
  777. """
  778. Translate and rotate absolute grid coordinates (i, j) into the
  779. agent's partially observable view (sub-grid). Note that the resulting
  780. coordinates may be negative or outside of the agent's view size.
  781. """
  782. ax, ay = self.agent_pos
  783. dx, dy = self.dir_vec
  784. rx, ry = self.right_vec
  785. # Compute the absolute coordinates of the top-left view corner
  786. sz = self.agent_view_size
  787. hs = self.agent_view_size // 2
  788. tx = ax + (dx * (sz-1)) - (rx * hs)
  789. ty = ay + (dy * (sz-1)) - (ry * hs)
  790. lx = i - tx
  791. ly = j - ty
  792. # Project the coordinates of the object relative to the top-left
  793. # corner onto the agent's own coordinate system
  794. vx = (rx*lx + ry*ly)
  795. vy = -(dx*lx + dy*ly)
  796. return vx, vy
  797. def get_view_exts(self):
  798. """
  799. Get the extents of the square set of tiles visible to the agent
  800. Note: the bottom extent indices are not included in the set
  801. """
  802. # Facing right
  803. if self.agent_dir == 0:
  804. topX = self.agent_pos[0]
  805. topY = self.agent_pos[1] - self.agent_view_size // 2
  806. # Facing down
  807. elif self.agent_dir == 1:
  808. topX = self.agent_pos[0] - self.agent_view_size // 2
  809. topY = self.agent_pos[1]
  810. # Facing left
  811. elif self.agent_dir == 2:
  812. topX = self.agent_pos[0] - self.agent_view_size + 1
  813. topY = self.agent_pos[1] - self.agent_view_size // 2
  814. # Facing up
  815. elif self.agent_dir == 3:
  816. topX = self.agent_pos[0] - self.agent_view_size // 2
  817. topY = self.agent_pos[1] - self.agent_view_size + 1
  818. else:
  819. assert False, "invalid agent direction"
  820. botX = topX + self.agent_view_size
  821. botY = topY + self.agent_view_size
  822. return (topX, topY, botX, botY)
  823. def relative_coords(self, x, y):
  824. """
  825. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  826. """
  827. vx, vy = self.get_view_coords(x, y)
  828. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  829. return None
  830. return vx, vy
  831. def in_view(self, x, y):
  832. """
  833. check if a grid position is visible to the agent
  834. """
  835. return self.relative_coords(x, y) is not None
  def agent_sees(self, x, y):
      """
      Check if a non-empty grid position is visible to the agent.

      Returns True only when (x, y) falls inside the agent's view, the
      world grid holds an object there, and the cell decoded from the
      agent's observation at the matching view coordinates has the same
      type. Returns False in every other case, including an empty world
      cell.
      """
      coordinates = self.relative_coords(x, y)
      if coordinates is None:
          return False
      vx, vy = coordinates

      obs = self.gen_obs()
      obs_grid, _ = Grid.decode(obs['image'])
      obs_cell = obs_grid.get(vx, vy)
      world_cell = self.grid.get(x, y)

      # The world cell may be empty while the observation is not (the
      # observation places the carried object on the agent's own cell);
      # guard against dereferencing None.type in that case.
      if obs_cell is None or world_cell is None:
          return False

      return obs_cell.type == world_cell.type
  def step(self, action):
      """
      Advance the environment by one time step.

      Applies `action` (compared against the members of self.actions) and
      returns the tuple (obs, reward, done, info):
        - obs is the result of self.gen_obs() after the action
        - reward is self._reward() when the agent moves forward onto a
          'goal' cell, and 0 otherwise
        - done is True after moving onto a 'goal' or 'lava' cell, or once
          step_count reaches max_steps
        - info is an empty dict

      Raises AssertionError if `action` matches no known action.
      """
      self.step_count += 1

      reward = 0
      done = False

      # Get the position in front of the agent
      fwd_pos = self.front_pos

      # Get the contents of the cell in front of the agent
      fwd_cell = self.grid.get(*fwd_pos)

      # Rotate left
      if action == self.actions.left:
          self.agent_dir = (self.agent_dir - 1) % 4

      # Rotate right
      elif action == self.actions.right:
          self.agent_dir = (self.agent_dir + 1) % 4

      # Move forward
      elif action == self.actions.forward:
          if fwd_cell is None or fwd_cell.can_overlap():
              self.agent_pos = fwd_pos
          if fwd_cell is not None and fwd_cell.type == 'goal':
              done = True
              reward = self._reward()
          if fwd_cell is not None and fwd_cell.type == 'lava':
              done = True

      # Pick up an object
      elif action == self.actions.pickup:
          if fwd_cell and fwd_cell.can_pickup():
              if self.carrying is None:
                  self.carrying = fwd_cell
                  # A carried object is no longer on the grid
                  self.carrying.cur_pos = np.array([-1, -1])
                  self.grid.set(*fwd_pos, None)

      # Drop an object
      elif action == self.actions.drop:
          if not fwd_cell and self.carrying:
              self.grid.set(*fwd_pos, self.carrying)
              self.carrying.cur_pos = fwd_pos
              self.carrying = None

      # Toggle/activate an object
      elif action == self.actions.toggle:
          if fwd_cell:
              fwd_cell.toggle(self, fwd_pos)

      # Done action (not used by default)
      elif action == self.actions.done:
          pass

      else:
          # Raised explicitly so an unknown action is still rejected when
          # assert statements are stripped (python -O).
          raise AssertionError("unknown action")

      if self.step_count >= self.max_steps:
          done = True

      obs = self.gen_obs()

      return obs, reward, done, {}
  def gen_obs_grid(self):
      """
      Generate the sub-grid observed by the agent.
      This method also outputs a visibility mask telling us which grid
      cells the agent can actually see.

      Returns a (grid, vis_mask) tuple. When see_through_walls is set the
      mask is an all-True boolean array of shape (grid.width, grid.height);
      otherwise it is whatever grid.process_vis returns.
      """
      # Only the top-left corner is needed; the slice size is the view size
      topX, topY, _botX, _botY = self.get_view_exts()

      grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)

      # Rotate the slice (agent_dir + 1) times to bring it into the
      # agent's own frame
      for _ in range(self.agent_dir + 1):
          grid = grid.rotate_left()

      # Process occluders and visibility
      # Note that this incurs some performance cost
      if not self.see_through_walls:
          vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1))
      else:
          # Builtin bool: the np.bool alias was removed in NumPy 1.24
          vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

      # Make it so the agent sees what it's carrying
      # We do this by placing the carried object at the agent's position
      # in the agent's partially observable view
      agent_pos = grid.width // 2, grid.height - 1
      if self.carrying:
          grid.set(*agent_pos, self.carrying)
      else:
          grid.set(*agent_pos, None)

      return grid, vis_mask
  def gen_obs(self):
      """
      Build the agent's observation (partially observable, low-resolution encoding)

      The result is a dictionary holding:
      - 'image': the encoded partially observable view of the environment
      - 'direction': the agent's direction/orientation (acting as a compass)
      - 'mission': a textual mission string (instructions for the agent)
      """
      view_grid, visible = self.gen_obs_grid()

      # Turn the partially observable view into its array encoding
      encoded_view = view_grid.encode(visible)

      assert hasattr(self, 'mission'), "environments must define a textual mission string"

      return dict(
          image=encoded_view,
          direction=self.agent_dir,
          mission=self.mission,
      )
  def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
      """
      Produce an image of an agent observation for visualization

      The observation is decoded into a grid plus visibility mask, and the
      grid is rendered with the agent drawn in the middle column of the
      last row using direction 3, highlighting the cells in the mask.
      """
      decoded_grid, visible = Grid.decode(obs)

      # Fixed agent placement within the view
      agent_cell = (self.agent_view_size // 2, self.agent_view_size - 1)

      # Render the whole grid
      return decoded_grid.render(
          tile_size,
          agent_pos=agent_cell,
          agent_dir=3,
          highlight_mask=visible,
      )
  def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
      """
      Render the whole-grid human view.

      mode: when 'human', a window is created on first use and the image
          and mission caption are shown in it.
      close: when true, close the window (if any) and return None without
          rendering.
      highlight: when true, cells in the agent's visibility mask are
          passed to the grid renderer as a highlight mask.
      tile_size: forwarded to self.grid.render.

      Returns the image produced by self.grid.render.
      """
      if close:
          if self.window:
              self.window.close()
          return

      if mode == 'human' and not self.window:
          # Imported lazily so the window module is only needed when a
          # human-mode window is actually opened
          import gym_minigrid.window
          self.window = gym_minigrid.window.Window('gym_minigrid')
          self.window.show(block=False)

      # Compute which cells are visible to the agent
      _, vis_mask = self.gen_obs_grid()

      # Compute the world coordinates of the top-left corner of the
      # agent's view area: (agent_view_size - 1) cells along dir_vec and
      # agent_view_size // 2 cells against right_vec from the agent
      f_vec = self.dir_vec
      r_vec = self.right_vec
      top_left = self.agent_pos + f_vec * (self.agent_view_size - 1) - r_vec * (self.agent_view_size // 2)

      # Mask of which cells to highlight
      # (builtin bool: the np.bool alias was removed in NumPy 1.24)
      highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

      # For each cell in the visibility mask
      for vis_j in range(self.agent_view_size):
          for vis_i in range(self.agent_view_size):
              # If this cell is not visible, don't highlight it
              if not vis_mask[vis_i, vis_j]:
                  continue

              # Compute the world coordinates of this cell
              abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

              # Skip view cells that fall outside the world grid
              if abs_i < 0 or abs_i >= self.width:
                  continue
              if abs_j < 0 or abs_j >= self.height:
                  continue

              # Mark this cell to be highlighted
              highlight_mask[abs_i, abs_j] = True

      # Render the whole grid
      img = self.grid.render(
          tile_size,
          self.agent_pos,
          self.agent_dir,
          highlight_mask=highlight_mask if highlight else None
      )

      if mode == 'human':
          self.window.show_img(img)
          self.window.set_caption(self.mission)

      return img