# minigrid.py (35 KB)
import math
from copy import deepcopy
from enum import IntEnum

import gym
import numpy as np
from gym import error, spaces, utils
from gym.utils import seeding

from .rendering import *
  8. # Size in pixels of a tile in the full-scale human view
  9. TILE_PIXELS = 32
  10. # Map of color names to RGB values
  11. COLORS = {
  12. 'red' : np.array([255, 0, 0]),
  13. 'green' : np.array([0, 255, 0]),
  14. 'blue' : np.array([0, 0, 255]),
  15. 'purple': np.array([112, 39, 195]),
  16. 'yellow': np.array([255, 255, 0]),
  17. 'grey' : np.array([100, 100, 100])
  18. }
  19. COLOR_NAMES = sorted(list(COLORS.keys()))
  20. # Used to map colors to integers
  21. COLOR_TO_IDX = {
  22. 'red' : 0,
  23. 'green' : 1,
  24. 'blue' : 2,
  25. 'purple': 3,
  26. 'yellow': 4,
  27. 'grey' : 5
  28. }
  29. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  30. # Map of object type to integers
  31. OBJECT_TO_IDX = {
  32. 'unseen' : 0,
  33. 'empty' : 1,
  34. 'wall' : 2,
  35. 'floor' : 3,
  36. 'door' : 4,
  37. 'key' : 5,
  38. 'ball' : 6,
  39. 'box' : 7,
  40. 'goal' : 8,
  41. 'lava' : 9,
  42. 'agent' : 10,
  43. }
  44. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  45. # Map of state names to integers
  46. STATE_TO_IDX = {
  47. 'open' : 0,
  48. 'closed': 1,
  49. 'locked': 2,
  50. }
  51. # Map of agent direction indices to vectors
  52. DIR_TO_VEC = [
  53. # Pointing right (positive X)
  54. np.array((1, 0)),
  55. # Down (positive Y)
  56. np.array((0, 1)),
  57. # Pointing left (negative X)
  58. np.array((-1, 0)),
  59. # Up (negative Y)
  60. np.array((0, -1)),
  61. ]
  62. class WorldObj:
  63. """
  64. Base class for grid world objects
  65. """
  66. def __init__(self, type, color):
  67. assert type in OBJECT_TO_IDX, type
  68. assert color in COLOR_TO_IDX, color
  69. self.type = type
  70. self.color = color
  71. self.contains = None
  72. # Initial position of the object
  73. self.init_pos = None
  74. # Current position of the object
  75. self.cur_pos = None
  76. def can_overlap(self):
  77. """Can the agent overlap with this?"""
  78. return False
  79. def can_pickup(self):
  80. """Can the agent pick this up?"""
  81. return False
  82. def can_contain(self):
  83. """Can this contain another object?"""
  84. return False
  85. def see_behind(self):
  86. """Can the agent see behind this object?"""
  87. return True
  88. def toggle(self, env, pos):
  89. """Method to trigger/toggle an action this object performs"""
  90. return False
  91. def encode(self):
  92. """Encode the a description of this object as a 3-tuple of integers"""
  93. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  94. @staticmethod
  95. def decode(type_idx, color_idx, state):
  96. """Create an object from a 3-tuple state description"""
  97. obj_type = IDX_TO_OBJECT[type_idx]
  98. color = IDX_TO_COLOR[color_idx]
  99. if obj_type == 'empty' or obj_type == 'unseen':
  100. return None
  101. # State, 0: open, 1: closed, 2: locked
  102. is_open = state == 0
  103. is_locked = state == 2
  104. if obj_type == 'wall':
  105. v = Wall(color)
  106. elif obj_type == 'floor':
  107. v = Floor(color)
  108. elif obj_type == 'ball':
  109. v = Ball(color)
  110. elif obj_type == 'key':
  111. v = Key(color)
  112. elif obj_type == 'box':
  113. v = Box(color)
  114. elif obj_type == 'door':
  115. v = Door(color, is_open, is_locked)
  116. elif obj_type == 'goal':
  117. v = Goal()
  118. elif obj_type == 'lava':
  119. v = Lava()
  120. else:
  121. assert False, "unknown object type in decode '%s'" % obj_type
  122. return v
  123. def render(self, r):
  124. """Draw this object with the given renderer"""
  125. raise NotImplementedError
  126. class Goal(WorldObj):
  127. def __init__(self):
  128. super().__init__('goal', 'green')
  129. def can_overlap(self):
  130. return True
  131. def render(self, img):
  132. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  133. class Floor(WorldObj):
  134. """
  135. Colored floor tile the agent can walk over
  136. """
  137. def __init__(self, color='blue'):
  138. super().__init__('floor', color)
  139. def can_overlap(self):
  140. return True
  141. def render(self, img):
  142. # Give the floor a pale color
  143. color = COLORS[self.color] / 2
  144. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  145. class Lava(WorldObj):
  146. def __init__(self):
  147. super().__init__('lava', 'red')
  148. def can_overlap(self):
  149. return True
  150. def render(self, img):
  151. c = (255, 128, 0)
  152. # Background color
  153. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  154. # Little waves
  155. for i in range(3):
  156. ylo = 0.3 + 0.2 * i
  157. yhi = 0.4 + 0.2 * i
  158. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
  159. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
  160. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
  161. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
  162. class Wall(WorldObj):
  163. def __init__(self, color='grey'):
  164. super().__init__('wall', color)
  165. def see_behind(self):
  166. return False
  167. def render(self, img):
  168. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  169. class Door(WorldObj):
  170. def __init__(self, color, is_open=False, is_locked=False):
  171. super().__init__('door', color)
  172. self.is_open = is_open
  173. self.is_locked = is_locked
  174. def can_overlap(self):
  175. """The agent can only walk over this cell when the door is open"""
  176. return self.is_open
  177. def see_behind(self):
  178. return self.is_open
  179. def toggle(self, env, pos):
  180. # If the player has the right key to open the door
  181. if self.is_locked:
  182. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  183. self.is_locked = False
  184. self.is_open = True
  185. return True
  186. return False
  187. self.is_open = not self.is_open
  188. return True
  189. def encode(self):
  190. """Encode the a description of this object as a 3-tuple of integers"""
  191. # State, 0: open, 1: closed, 2: locked
  192. if self.is_open:
  193. state = 0
  194. elif self.is_locked:
  195. state = 2
  196. elif not self.is_open:
  197. state = 1
  198. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  199. def render(self, img):
  200. c = COLORS[self.color]
  201. if self.is_open:
  202. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  203. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
  204. return
  205. # Door frame and door
  206. if self.is_locked:
  207. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  208. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  209. # Draw key slot
  210. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  211. else:
  212. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  213. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
  214. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  215. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))
  216. # Draw door handle
  217. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  218. class Key(WorldObj):
  219. def __init__(self, color='blue'):
  220. super(Key, self).__init__('key', color)
  221. def can_pickup(self):
  222. return True
  223. def render(self, img):
  224. c = COLORS[self.color]
  225. # Vertical quad
  226. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  227. # Teeth
  228. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  229. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  230. # Ring
  231. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  232. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))
  233. class Ball(WorldObj):
  234. def __init__(self, color='blue'):
  235. super(Ball, self).__init__('ball', color)
  236. def can_pickup(self):
  237. return True
  238. def render(self, img):
  239. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  240. class Box(WorldObj):
  241. def __init__(self, color, contains=None):
  242. super(Box, self).__init__('box', color)
  243. self.contains = contains
  244. def can_pickup(self):
  245. return True
  246. def render(self, img):
  247. c = COLORS[self.color]
  248. # Outline
  249. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  250. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))
  251. # Horizontal slit
  252. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  253. def toggle(self, env, pos):
  254. # Replace the box by its contents
  255. env.grid.set(*pos, self.contains)
  256. return True
  257. class Grid:
  258. """
  259. Represent a grid and operations on it
  260. """
  261. # Static cache of pre-renderer tiles
  262. tile_cache = {}
  263. def __init__(self, width, height):
  264. assert width >= 3
  265. assert height >= 3
  266. self.width = width
  267. self.height = height
  268. self.grid = [None] * width * height
  269. def __contains__(self, key):
  270. if isinstance(key, WorldObj):
  271. for e in self.grid:
  272. if e is key:
  273. return True
  274. elif isinstance(key, tuple):
  275. for e in self.grid:
  276. if e is None:
  277. continue
  278. if (e.color, e.type) == key:
  279. return True
  280. if key[0] is None and key[1] == e.type:
  281. return True
  282. return False
  283. def __eq__(self, other):
  284. grid1 = self.encode()
  285. grid2 = other.encode()
  286. return np.array_equal(grid2, grid1)
  287. def __ne__(self, other):
  288. return not self == other
  289. def copy(self):
  290. from copy import deepcopy
  291. return deepcopy(self)
  292. def set(self, i, j, v):
  293. assert i >= 0 and i < self.width
  294. assert j >= 0 and j < self.height
  295. self.grid[j * self.width + i] = v
  296. def get(self, i, j):
  297. assert i >= 0 and i < self.width
  298. assert j >= 0 and j < self.height
  299. return self.grid[j * self.width + i]
  300. def horz_wall(self, x, y, length=None, obj_type=Wall):
  301. if length is None:
  302. length = self.width - x
  303. for i in range(0, length):
  304. self.set(x + i, y, obj_type())
  305. def vert_wall(self, x, y, length=None, obj_type=Wall):
  306. if length is None:
  307. length = self.height - y
  308. for j in range(0, length):
  309. self.set(x, y + j, obj_type())
  310. def wall_rect(self, x, y, w, h):
  311. self.horz_wall(x, y, w)
  312. self.horz_wall(x, y+h-1, w)
  313. self.vert_wall(x, y, h)
  314. self.vert_wall(x+w-1, y, h)
  315. def rotate_left(self):
  316. """
  317. Rotate the grid to the left (counter-clockwise)
  318. """
  319. grid = Grid(self.height, self.width)
  320. for i in range(self.width):
  321. for j in range(self.height):
  322. v = self.get(i, j)
  323. grid.set(j, grid.height - 1 - i, v)
  324. return grid
  325. def slice(self, topX, topY, width, height):
  326. """
  327. Get a subset of the grid
  328. """
  329. grid = Grid(width, height)
  330. for j in range(0, height):
  331. for i in range(0, width):
  332. x = topX + i
  333. y = topY + j
  334. if x >= 0 and x < self.width and \
  335. y >= 0 and y < self.height:
  336. v = self.get(x, y)
  337. else:
  338. v = Wall()
  339. grid.set(i, j, v)
  340. return grid
  341. @classmethod
  342. def render_tile(
  343. cls,
  344. obj,
  345. agent_dir=None,
  346. highlight=False,
  347. tile_size=TILE_PIXELS,
  348. subdivs=3
  349. ):
  350. """
  351. Render a tile and cache the result
  352. """
  353. # Hash map lookup key for the cache
  354. key = (agent_dir, highlight, tile_size)
  355. key = obj.encode() + key if obj else key
  356. if key in cls.tile_cache:
  357. return cls.tile_cache[key]
  358. img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
  359. # Draw the grid lines (top and left edges)
  360. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  361. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  362. if obj != None:
  363. obj.render(img)
  364. # Overlay the agent on top
  365. if agent_dir is not None:
  366. tri_fn = point_in_triangle(
  367. (0.12, 0.19),
  368. (0.87, 0.50),
  369. (0.12, 0.81),
  370. )
  371. # Rotate the agent based on its direction
  372. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
  373. fill_coords(img, tri_fn, (255, 0, 0))
  374. # Highlight the cell if needed
  375. if highlight:
  376. highlight_img(img)
  377. # Downsample the image to perform supersampling/anti-aliasing
  378. img = downsample(img, subdivs)
  379. # Cache the rendered tile
  380. cls.tile_cache[key] = img
  381. return img
  382. def render(
  383. self,
  384. tile_size,
  385. agent_pos=None,
  386. agent_dir=None,
  387. highlight_mask=None
  388. ):
  389. """
  390. Render this grid at a given scale
  391. :param r: target renderer object
  392. :param tile_size: tile size in pixels
  393. """
  394. if highlight_mask is None:
  395. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
  396. # Compute the total grid size
  397. width_px = self.width * tile_size
  398. height_px = self.height * tile_size
  399. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  400. # Render the grid
  401. for j in range(0, self.height):
  402. for i in range(0, self.width):
  403. cell = self.get(i, j)
  404. agent_here = np.array_equal(agent_pos, (i, j))
  405. tile_img = Grid.render_tile(
  406. cell,
  407. agent_dir=agent_dir if agent_here else None,
  408. highlight=highlight_mask[i, j],
  409. tile_size=tile_size
  410. )
  411. ymin = j * tile_size
  412. ymax = (j+1) * tile_size
  413. xmin = i * tile_size
  414. xmax = (i+1) * tile_size
  415. img[ymin:ymax, xmin:xmax, :] = tile_img
  416. return img
  417. def encode(self, vis_mask=None):
  418. """
  419. Produce a compact numpy encoding of the grid
  420. """
  421. if vis_mask is None:
  422. vis_mask = np.ones((self.width, self.height), dtype=bool)
  423. array = np.zeros((self.width, self.height, 3), dtype='uint8')
  424. for i in range(self.width):
  425. for j in range(self.height):
  426. if vis_mask[i, j]:
  427. v = self.get(i, j)
  428. if v is None:
  429. array[i, j, 0] = OBJECT_TO_IDX['empty']
  430. array[i, j, 1] = 0
  431. array[i, j, 2] = 0
  432. else:
  433. array[i, j, :] = v.encode()
  434. return array
  435. @staticmethod
  436. def decode(array):
  437. """
  438. Decode an array grid encoding back into a grid
  439. """
  440. width, height, channels = array.shape
  441. assert channels == 3
  442. vis_mask = np.ones(shape=(width, height), dtype=np.bool)
  443. grid = Grid(width, height)
  444. for i in range(width):
  445. for j in range(height):
  446. type_idx, color_idx, state = array[i, j]
  447. v = WorldObj.decode(type_idx, color_idx, state)
  448. grid.set(i, j, v)
  449. vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])
  450. return grid, vis_mask
  451. def process_vis(grid, agent_pos):
  452. mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)
  453. mask[agent_pos[0], agent_pos[1]] = True
  454. for j in reversed(range(0, grid.height)):
  455. for i in range(0, grid.width-1):
  456. if not mask[i, j]:
  457. continue
  458. cell = grid.get(i, j)
  459. if cell and not cell.see_behind():
  460. continue
  461. mask[i+1, j] = True
  462. if j > 0:
  463. mask[i+1, j-1] = True
  464. mask[i, j-1] = True
  465. for i in reversed(range(1, grid.width)):
  466. if not mask[i, j]:
  467. continue
  468. cell = grid.get(i, j)
  469. if cell and not cell.see_behind():
  470. continue
  471. mask[i-1, j] = True
  472. if j > 0:
  473. mask[i-1, j-1] = True
  474. mask[i, j-1] = True
  475. for j in range(0, grid.height):
  476. for i in range(0, grid.width):
  477. if not mask[i, j]:
  478. grid.set(i, j, None)
  479. return mask
  480. class MiniGridEnv(gym.Env):
  481. """
  482. 2D grid world game environment
  483. """
  484. metadata = {
  485. 'render.modes': ['human', 'rgb_array'],
  486. 'video.frames_per_second' : 10
  487. }
  488. # Enumeration of possible actions
  489. class Actions(IntEnum):
  490. # Turn left, turn right, move forward
  491. left = 0
  492. right = 1
  493. forward = 2
  494. # Pick up an object
  495. pickup = 3
  496. # Drop an object
  497. drop = 4
  498. # Toggle/activate an object
  499. toggle = 5
  500. # Done completing task
  501. done = 6
  502. def __init__(
  503. self,
  504. grid_size=None,
  505. width=None,
  506. height=None,
  507. max_steps=100,
  508. see_through_walls=False,
  509. seed=1337,
  510. agent_view_size=7
  511. ):
  512. # Can't set both grid_size and width/height
  513. if grid_size:
  514. assert width == None and height == None
  515. width = grid_size
  516. height = grid_size
  517. # Action enumeration for this environment
  518. self.actions = MiniGridEnv.Actions
  519. # Actions are discrete integer values
  520. self.action_space = spaces.Discrete(len(self.actions))
  521. # Number of cells (width and height) in the agent view
  522. self.agent_view_size = agent_view_size
  523. # Observations are dictionaries containing an
  524. # encoding of the grid and a textual 'mission' string
  525. self.observation_space = spaces.Box(
  526. low=0,
  527. high=255,
  528. shape=(self.agent_view_size, self.agent_view_size, 3),
  529. dtype='uint8'
  530. )
  531. self.observation_space = spaces.Dict({
  532. 'image': self.observation_space
  533. })
  534. # Range of possible rewards
  535. self.reward_range = (0, 1)
  536. # Window to use for human rendering mode
  537. self.window = None
  538. # Environment configuration
  539. self.width = width
  540. self.height = height
  541. self.max_steps = max_steps
  542. self.see_through_walls = see_through_walls
  543. # Current position and direction of the agent
  544. self.agent_pos = None
  545. self.agent_dir = None
  546. # Initialize the RNG
  547. self.seed(seed=seed)
  548. # Initialize the state
  549. self.reset()
  550. def reset(self):
  551. # Current position and direction of the agent
  552. self.agent_pos = None
  553. self.agent_dir = None
  554. # Generate a new random grid at the start of each episode
  555. # To keep the same grid for each episode, call env.seed() with
  556. # the same seed before calling env.reset()
  557. self._gen_grid(self.width, self.height)
  558. # These fields should be defined by _gen_grid
  559. assert self.agent_pos is not None
  560. assert self.agent_dir is not None
  561. # Check that the agent doesn't overlap with an object
  562. start_cell = self.grid.get(*self.agent_pos)
  563. assert start_cell is None or start_cell.can_overlap()
  564. # Item picked up, being carried, initially nothing
  565. self.carrying = None
  566. # Step count since episode start
  567. self.step_count = 0
  568. # Return first observation
  569. obs = self.gen_obs()
  570. return obs
  571. def seed(self, seed=1337):
  572. # Seed the random number generator
  573. self.np_random, _ = seeding.np_random(seed)
  574. return [seed]
  575. @property
  576. def steps_remaining(self):
  577. return self.max_steps - self.step_count
  578. def __str__(self):
  579. """
  580. Produce a pretty string of the environment's grid along with the agent.
  581. A grid cell is represented by 2-character string, the first one for
  582. the object and the second one for the color.
  583. """
  584. # Map of object types to short string
  585. OBJECT_TO_STR = {
  586. 'wall' : 'W',
  587. 'floor' : 'F',
  588. 'door' : 'D',
  589. 'key' : 'K',
  590. 'ball' : 'A',
  591. 'box' : 'B',
  592. 'goal' : 'G',
  593. 'lava' : 'V',
  594. }
  595. # Short string for opened door
  596. OPENDED_DOOR_IDS = '_'
  597. # Map agent's direction to short string
  598. AGENT_DIR_TO_STR = {
  599. 0: '>',
  600. 1: 'V',
  601. 2: '<',
  602. 3: '^'
  603. }
  604. str = ''
  605. for j in range(self.grid.height):
  606. for i in range(self.grid.width):
  607. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  608. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  609. continue
  610. c = self.grid.get(i, j)
  611. if c == None:
  612. str += ' '
  613. continue
  614. if c.type == 'door':
  615. if c.is_open:
  616. str += '__'
  617. elif c.is_locked:
  618. str += 'L' + c.color[0].upper()
  619. else:
  620. str += 'D' + c.color[0].upper()
  621. continue
  622. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  623. if j < self.grid.height - 1:
  624. str += '\n'
  625. return str
  626. def _gen_grid(self, width, height):
  627. assert False, "_gen_grid needs to be implemented by each environment"
  628. def _reward(self):
  629. """
  630. Compute the reward to be given upon success
  631. """
  632. return 1 - 0.9 * (self.step_count / self.max_steps)
  633. def _rand_int(self, low, high):
  634. """
  635. Generate random integer in [low,high[
  636. """
  637. return self.np_random.randint(low, high)
  638. def _rand_float(self, low, high):
  639. """
  640. Generate random float in [low,high[
  641. """
  642. return self.np_random.uniform(low, high)
  643. def _rand_bool(self):
  644. """
  645. Generate random boolean value
  646. """
  647. return (self.np_random.randint(0, 2) == 0)
  648. def _rand_elem(self, iterable):
  649. """
  650. Pick a random element in a list
  651. """
  652. lst = list(iterable)
  653. idx = self._rand_int(0, len(lst))
  654. return lst[idx]
  655. def _rand_subset(self, iterable, num_elems):
  656. """
  657. Sample a random subset of distinct elements of a list
  658. """
  659. lst = list(iterable)
  660. assert num_elems <= len(lst)
  661. out = []
  662. while len(out) < num_elems:
  663. elem = self._rand_elem(lst)
  664. lst.remove(elem)
  665. out.append(elem)
  666. return out
  667. def _rand_color(self):
  668. """
  669. Generate a random color name (string)
  670. """
  671. return self._rand_elem(COLOR_NAMES)
  672. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  673. """
  674. Generate a random (x,y) position tuple
  675. """
  676. return (
  677. self.np_random.randint(xLow, xHigh),
  678. self.np_random.randint(yLow, yHigh)
  679. )
  680. def place_obj(self,
  681. obj,
  682. top=None,
  683. size=None,
  684. reject_fn=None,
  685. max_tries=math.inf
  686. ):
  687. """
  688. Place an object at an empty position in the grid
  689. :param top: top-left position of the rectangle where to place
  690. :param size: size of the rectangle where to place
  691. :param reject_fn: function to filter out potential positions
  692. """
  693. if top is None:
  694. top = (0, 0)
  695. else:
  696. top = (max(top[0], 0), max(top[1], 0))
  697. if size is None:
  698. size = (self.grid.width, self.grid.height)
  699. num_tries = 0
  700. while True:
  701. # This is to handle with rare cases where rejection sampling
  702. # gets stuck in an infinite loop
  703. if num_tries > max_tries:
  704. raise RecursionError('rejection sampling failed in place_obj')
  705. num_tries += 1
  706. pos = np.array((
  707. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  708. self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
  709. ))
  710. # Don't place the object on top of another object
  711. if self.grid.get(*pos) != None:
  712. continue
  713. # Don't place the object where the agent is
  714. if np.array_equal(pos, self.agent_pos):
  715. continue
  716. # Check if there is a filtering criterion
  717. if reject_fn and reject_fn(self, pos):
  718. continue
  719. break
  720. self.grid.set(*pos, obj)
  721. if obj is not None:
  722. obj.init_pos = pos
  723. obj.cur_pos = pos
  724. return pos
  725. def put_obj(self, obj, i, j):
  726. """
  727. Put an object at a specific position in the grid
  728. """
  729. self.grid.set(i, j, obj)
  730. obj.init_pos = (i, j)
  731. obj.cur_pos = (i, j)
  732. def place_agent(
  733. self,
  734. top=None,
  735. size=None,
  736. rand_dir=True,
  737. max_tries=math.inf
  738. ):
  739. """
  740. Set the agent's starting point at an empty position in the grid
  741. """
  742. self.agent_pos = None
  743. pos = self.place_obj(None, top, size, max_tries=max_tries)
  744. self.agent_pos = pos
  745. if rand_dir:
  746. self.agent_dir = self._rand_int(0, 4)
  747. return pos
  748. @property
  749. def dir_vec(self):
  750. """
  751. Get the direction vector for the agent, pointing in the direction
  752. of forward movement.
  753. """
  754. assert self.agent_dir >= 0 and self.agent_dir < 4
  755. return DIR_TO_VEC[self.agent_dir]
  756. @property
  757. def right_vec(self):
  758. """
  759. Get the vector pointing to the right of the agent.
  760. """
  761. dx, dy = self.dir_vec
  762. return np.array((-dy, dx))
  763. @property
  764. def front_pos(self):
  765. """
  766. Get the position of the cell that is right in front of the agent
  767. """
  768. return self.agent_pos + self.dir_vec
  769. def get_view_coords(self, i, j):
  770. """
  771. Translate and rotate absolute grid coordinates (i, j) into the
  772. agent's partially observable view (sub-grid). Note that the resulting
  773. coordinates may be negative or outside of the agent's view size.
  774. """
  775. ax, ay = self.agent_pos
  776. dx, dy = self.dir_vec
  777. rx, ry = self.right_vec
  778. # Compute the absolute coordinates of the top-left view corner
  779. sz = self.agent_view_size
  780. hs = self.agent_view_size // 2
  781. tx = ax + (dx * (sz-1)) - (rx * hs)
  782. ty = ay + (dy * (sz-1)) - (ry * hs)
  783. lx = i - tx
  784. ly = j - ty
  785. # Project the coordinates of the object relative to the top-left
  786. # corner onto the agent's own coordinate system
  787. vx = (rx*lx + ry*ly)
  788. vy = -(dx*lx + dy*ly)
  789. return vx, vy
  790. def get_view_exts(self):
  791. """
  792. Get the extents of the square set of tiles visible to the agent
  793. Note: the bottom extent indices are not included in the set
  794. """
  795. # Facing right
  796. if self.agent_dir == 0:
  797. topX = self.agent_pos[0]
  798. topY = self.agent_pos[1] - self.agent_view_size // 2
  799. # Facing down
  800. elif self.agent_dir == 1:
  801. topX = self.agent_pos[0] - self.agent_view_size // 2
  802. topY = self.agent_pos[1]
  803. # Facing left
  804. elif self.agent_dir == 2:
  805. topX = self.agent_pos[0] - self.agent_view_size + 1
  806. topY = self.agent_pos[1] - self.agent_view_size // 2
  807. # Facing up
  808. elif self.agent_dir == 3:
  809. topX = self.agent_pos[0] - self.agent_view_size // 2
  810. topY = self.agent_pos[1] - self.agent_view_size + 1
  811. else:
  812. assert False, "invalid agent direction"
  813. botX = topX + self.agent_view_size
  814. botY = topY + self.agent_view_size
  815. return (topX, topY, botX, botY)
  816. def relative_coords(self, x, y):
  817. """
  818. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  819. """
  820. vx, vy = self.get_view_coords(x, y)
  821. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  822. return None
  823. return vx, vy
  824. def in_view(self, x, y):
  825. """
  826. check if a grid position is visible to the agent
  827. """
  828. return self.relative_coords(x, y) is not None
  829. def agent_sees(self, x, y):
  830. """
  831. Check if a non-empty grid position is visible to the agent
  832. """
  833. coordinates = self.relative_coords(x, y)
  834. if coordinates is None:
  835. return False
  836. vx, vy = coordinates
  837. obs = self.gen_obs()
  838. obs_grid, _ = Grid.decode(obs['image'])
  839. obs_cell = obs_grid.get(vx, vy)
  840. world_cell = self.grid.get(x, y)
  841. return obs_cell is not None and obs_cell.type == world_cell.type
  842. def step(self, action):
  843. self.step_count += 1
  844. reward = 0
  845. done = False
  846. # Get the position in front of the agent
  847. fwd_pos = self.front_pos
  848. # Get the contents of the cell in front of the agent
  849. fwd_cell = self.grid.get(*fwd_pos)
  850. # Rotate left
  851. if action == self.actions.left:
  852. self.agent_dir -= 1
  853. if self.agent_dir < 0:
  854. self.agent_dir += 4
  855. # Rotate right
  856. elif action == self.actions.right:
  857. self.agent_dir = (self.agent_dir + 1) % 4
  858. # Move forward
  859. elif action == self.actions.forward:
  860. if fwd_cell == None or fwd_cell.can_overlap():
  861. self.agent_pos = fwd_pos
  862. if fwd_cell != None and fwd_cell.type == 'goal':
  863. done = True
  864. reward = self._reward()
  865. if fwd_cell != None and fwd_cell.type == 'lava':
  866. done = True
  867. # Pick up an object
  868. elif action == self.actions.pickup:
  869. if fwd_cell and fwd_cell.can_pickup():
  870. if self.carrying is None:
  871. self.carrying = fwd_cell
  872. self.carrying.cur_pos = np.array([-1, -1])
  873. self.grid.set(*fwd_pos, None)
  874. # Drop an object
  875. elif action == self.actions.drop:
  876. if not fwd_cell and self.carrying:
  877. self.grid.set(*fwd_pos, self.carrying)
  878. self.carrying.cur_pos = fwd_pos
  879. self.carrying = None
  880. # Toggle/activate an object
  881. elif action == self.actions.toggle:
  882. if fwd_cell:
  883. fwd_cell.toggle(self, fwd_pos)
  884. # Done action (not used by default)
  885. elif action == self.actions.done:
  886. pass
  887. else:
  888. assert False, "unknown action"
  889. if self.step_count >= self.max_steps:
  890. done = True
  891. obs = self.gen_obs()
  892. return obs, reward, done, {}
  893. def gen_obs_grid(self):
  894. """
  895. Generate the sub-grid observed by the agent.
  896. This method also outputs a visibility mask telling us which grid
  897. cells the agent can actually see.
  898. """
  899. topX, topY, botX, botY = self.get_view_exts()
  900. grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)
  901. for i in range(self.agent_dir + 1):
  902. grid = grid.rotate_left()
  903. # Process occluders and visibility
  904. # Note that this incurs some performance cost
  905. if not self.see_through_walls:
  906. vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))
  907. else:
  908. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)
  909. # Make it so the agent sees what it's carrying
  910. # We do this by placing the carried object at the agent's position
  911. # in the agent's partially observable view
  912. agent_pos = grid.width // 2, grid.height - 1
  913. if self.carrying:
  914. grid.set(*agent_pos, self.carrying)
  915. else:
  916. grid.set(*agent_pos, None)
  917. return grid, vis_mask
  918. def gen_obs(self):
  919. """
  920. Generate the agent's view (partially observable, low-resolution encoding)
  921. """
  922. grid, vis_mask = self.gen_obs_grid()
  923. # Encode the partially observable view into a numpy array
  924. image = grid.encode(vis_mask)
  925. assert hasattr(self, 'mission'), "environments must define a textual mission string"
  926. # Observations are dictionaries containing:
  927. # - an image (partially observable view of the environment)
  928. # - the agent's direction/orientation (acting as a compass)
  929. # - a textual mission string (instructions for the agent)
  930. obs = {
  931. 'image': image,
  932. 'direction': self.agent_dir,
  933. 'mission': self.mission
  934. }
  935. return obs
  936. def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
  937. """
  938. Render an agent observation for visualization
  939. """
  940. grid, vis_mask = Grid.decode(obs)
  941. # Render the whole grid
  942. img = grid.render(
  943. tile_size,
  944. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  945. agent_dir=3,
  946. highlight_mask=vis_mask
  947. )
  948. return img
  949. def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
  950. """
  951. Render the whole-grid human view
  952. """
  953. if close:
  954. if self.window:
  955. self.window.close()
  956. return
  957. if mode == 'human' and not self.window:
  958. import gym_minigrid.window
  959. self.window = gym_minigrid.window.Window('gym_minigrid')
  960. self.window.show(block=False)
  961. # Compute which cells are visible to the agent
  962. _, vis_mask = self.gen_obs_grid()
  963. # Compute the world coordinates of the bottom-left corner
  964. # of the agent's view area
  965. f_vec = self.dir_vec
  966. r_vec = self.right_vec
  967. top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)
  968. # Mask of which cells to highlight
  969. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)
  970. # For each cell in the visibility mask
  971. for vis_j in range(0, self.agent_view_size):
  972. for vis_i in range(0, self.agent_view_size):
  973. # If this cell is not visible, don't highlight it
  974. if not vis_mask[vis_i, vis_j]:
  975. continue
  976. # Compute the world coordinates of this cell
  977. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  978. if abs_i < 0 or abs_i >= self.width:
  979. continue
  980. if abs_j < 0 or abs_j >= self.height:
  981. continue
  982. # Mark this cell to be highlighted
  983. highlight_mask[abs_i, abs_j] = True
  984. # Render the whole grid
  985. img = self.grid.render(
  986. tile_size,
  987. self.agent_pos,
  988. self.agent_dir,
  989. highlight_mask=highlight_mask if highlight else None
  990. )
  991. if mode == 'human':
  992. self.window.show_img(img)
  993. self.window.set_caption(self.mission)
  994. return img