verifier.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. """
  2. Copied and adapted from https://github.com/mila-iqia/babyai
  3. """
  4. from __future__ import annotations
  5. import os
  6. from abc import ABC, abstractmethod
  7. import numpy as np
  8. from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC
  9. from minigrid.minigrid_env import MiniGridEnv
  10. # Object types we are allowed to describe in language
  11. OBJ_TYPES = ["box", "ball", "key", "door"]
  12. # Object types we are allowed to describe in language
  13. OBJ_TYPES_NOT_DOOR = list(filter(lambda t: t != "door", OBJ_TYPES))
  14. # Locations are all relative to the agent's starting position
  15. LOC_NAMES = ["left", "right", "front", "behind"]
  16. # Environment flag to indicate that done actions should be
  17. # used by the verifier
  18. use_done_actions = os.environ.get("BABYAI_DONE_ACTIONS", False)
  19. def dot_product(v1, v2):
  20. """
  21. Compute the dot product of the vectors v1 and v2.
  22. """
  23. return sum(i * j for i, j in zip(v1, v2))
  24. def pos_next_to(pos_a, pos_b):
  25. """
  26. Test if two positions are next to each other.
  27. The positions have to line up either horizontally or vertically,
  28. but positions that are diagonally adjacent are not counted.
  29. """
  30. xa, ya = pos_a
  31. xb, yb = pos_b
  32. d = abs(xa - xb) + abs(ya - yb)
  33. return d == 1
  34. class ObjDesc:
  35. """
  36. Description of a set of objects in an environment
  37. """
  38. def __init__(self, type, color=None, loc=None):
  39. assert type in [None, *OBJ_TYPES], type
  40. assert color in [None, *COLOR_NAMES], color
  41. assert loc in [None, *LOC_NAMES], loc
  42. self.color = color
  43. self.type = type
  44. self.loc = loc
  45. # Set of objects possibly matching the description
  46. self.obj_set = []
  47. # Set of initial object positions
  48. self.obj_poss = []
  49. def __repr__(self):
  50. return f"{self.color} {self.type} {self.loc}"
  51. def surface(self, env):
  52. """
  53. Generate a natural language representation of the object description
  54. """
  55. self.find_matching_objs(env)
  56. assert len(self.obj_set) > 0, "no object matching description"
  57. if self.type:
  58. s = str(self.type)
  59. else:
  60. s = "object"
  61. if self.color:
  62. s = self.color + " " + s
  63. if self.loc:
  64. if self.loc == "front":
  65. s = s + " in front of you"
  66. elif self.loc == "behind":
  67. s = s + " behind you"
  68. else:
  69. s = s + " on your " + self.loc
  70. # Singular vs plural
  71. if len(self.obj_set) > 1:
  72. s = "a " + s
  73. else:
  74. s = "the " + s
  75. return s
  76. def find_matching_objs(self, env, use_location=True):
  77. """
  78. Find the set of objects matching the description and their positions.
  79. When use_location is False, we only update the positions of already tracked objects, without taking into account
  80. the location of the object. e.g. A ball that was on "your right" initially will still be tracked as being "on
  81. your right" when you move.
  82. """
  83. env = env.unwrapped
  84. if use_location:
  85. self.obj_set = []
  86. # otherwise we keep the same obj_set
  87. self.obj_poss = []
  88. agent_room = env.room_from_pos(*env.agent_pos)
  89. for i in range(env.grid.width):
  90. for j in range(env.grid.height):
  91. cell = env.grid.get(i, j)
  92. if cell is None:
  93. continue
  94. if not use_location:
  95. # we should keep tracking the same objects initially tracked only
  96. already_tracked = any([cell is obj for obj in self.obj_set])
  97. if not already_tracked:
  98. continue
  99. # Check if object's type matches description
  100. if self.type is not None and cell.type != self.type:
  101. continue
  102. # Check if object's color matches description
  103. if self.color is not None and cell.color != self.color:
  104. continue
  105. # Check if object's position matches description
  106. if use_location and self.loc in ["left", "right", "front", "behind"]:
  107. # Locations apply only to objects in the same room
  108. # the agent starts in
  109. if not agent_room.pos_inside(i, j):
  110. continue
  111. # Direction from the agent to the object
  112. v = (i - env.agent_pos[0], j - env.agent_pos[1])
  113. # (d1, d2) is an oriented orthonormal basis
  114. d1 = DIR_TO_VEC[env.agent_dir]
  115. d2 = (-d1[1], d1[0])
  116. # Check if object's position matches with location
  117. pos_matches = {
  118. "left": dot_product(v, d2) < 0,
  119. "right": dot_product(v, d2) > 0,
  120. "front": dot_product(v, d1) > 0,
  121. "behind": dot_product(v, d1) < 0,
  122. }
  123. if not (pos_matches[self.loc]):
  124. continue
  125. if use_location:
  126. self.obj_set.append(cell)
  127. self.obj_poss.append((i, j))
  128. return self.obj_set, self.obj_poss
  129. class Instr(ABC):
  130. """
  131. Base class for all instructions in the baby language
  132. """
  133. def __init__(self):
  134. self.env: MiniGridEnv
  135. @abstractmethod
  136. def surface(self, env):
  137. """
  138. Produce a natural language representation of the instruction
  139. """
  140. raise NotImplementedError
  141. def reset_verifier(self, env):
  142. """
  143. Must be called at the beginning of the episode
  144. """
  145. self.env = env
  146. @abstractmethod
  147. def verify(self, action):
  148. """
  149. Verify if the task described by the instruction is incomplete,
  150. complete with success or failed. The return value is a string,
  151. one of: 'success', 'failure' or 'continue'.
  152. """
  153. raise NotImplementedError
  154. def update_objs_poss(self):
  155. """
  156. Update the position of objects present in the instruction if needed
  157. """
  158. potential_objects = ("desc", "desc_move", "desc_fixed")
  159. for attr in potential_objects:
  160. if hasattr(self, attr):
  161. getattr(self, attr).find_matching_objs(self.env, use_location=False)
  162. class ActionInstr(Instr, ABC):
  163. """
  164. Base class for all action instructions (clauses)
  165. """
  166. def __init__(self):
  167. super().__init__()
  168. # Indicates that the action was completed on the last step
  169. self.lastStepMatch = False
  170. def verify(self, action):
  171. """
  172. Verifies actions, with and without the done action.
  173. """
  174. if not use_done_actions:
  175. return self.verify_action(action)
  176. if action == self.env.actions.done:
  177. if self.lastStepMatch:
  178. return "success"
  179. return "failure"
  180. res = self.verify_action(action)
  181. self.lastStepMatch = res == "success"
  182. @abstractmethod
  183. def verify_action(self):
  184. """
  185. Each action instruction class should implement this method
  186. to verify the action.
  187. """
  188. raise NotImplementedError
  189. class OpenInstr(ActionInstr):
  190. def __init__(self, obj_desc, strict=False):
  191. super().__init__()
  192. assert obj_desc.type == "door"
  193. self.desc = obj_desc
  194. self.strict = strict
  195. def surface(self, env):
  196. return "open " + self.desc.surface(env)
  197. def reset_verifier(self, env):
  198. super().reset_verifier(env)
  199. # Identify set of possible matching objects in the environment
  200. self.desc.find_matching_objs(env)
  201. def verify_action(self, action):
  202. # Only verify when the toggle action is performed
  203. if action != self.env.actions.toggle:
  204. return "continue"
  205. # Get the contents of the cell in front of the agent
  206. front_cell = self.env.grid.get(*self.env.front_pos)
  207. for door in self.desc.obj_set:
  208. if front_cell and front_cell is door and door.is_open:
  209. return "success"
  210. # If in strict mode and the wrong door is opened, failure
  211. if self.strict:
  212. if front_cell and front_cell.type == "door":
  213. return "failure"
  214. return "continue"
  215. class GoToInstr(ActionInstr):
  216. """
  217. Go next to (and look towards) an object matching a given description
  218. eg: go to the door
  219. """
  220. def __init__(self, obj_desc):
  221. super().__init__()
  222. self.desc = obj_desc
  223. def surface(self, env):
  224. return "go to " + self.desc.surface(env)
  225. def reset_verifier(self, env):
  226. super().reset_verifier(env)
  227. # Identify set of possible matching objects in the environment
  228. self.desc.find_matching_objs(env)
  229. def verify_action(self, action):
  230. # For each object position
  231. for pos in self.desc.obj_poss:
  232. # If the agent is next to (and facing) the object
  233. if np.array_equal(pos, self.env.front_pos):
  234. return "success"
  235. return "continue"
  236. class PickupInstr(ActionInstr):
  237. """
  238. Pick up an object matching a given description
  239. eg: pick up the grey ball
  240. """
  241. def __init__(self, obj_desc, strict=False):
  242. super().__init__()
  243. assert obj_desc.type != "door"
  244. self.desc = obj_desc
  245. self.strict = strict
  246. def surface(self, env):
  247. return "pick up " + self.desc.surface(env)
  248. def reset_verifier(self, env):
  249. super().reset_verifier(env)
  250. # Object previously being carried
  251. self.preCarrying = None
  252. # Identify set of possible matching objects in the environment
  253. self.desc.find_matching_objs(env)
  254. def verify_action(self, action):
  255. # To keep track of what was carried at the last time step
  256. preCarrying = self.preCarrying
  257. self.preCarrying = self.env.carrying
  258. # Only verify when the pickup action is performed
  259. if action != self.env.actions.pickup:
  260. return "continue"
  261. for obj in self.desc.obj_set:
  262. if preCarrying is None and self.env.carrying is obj:
  263. return "success"
  264. # If in strict mode and the wrong door object is picked up, failure
  265. if self.strict:
  266. if self.env.carrying:
  267. return "failure"
  268. self.preCarrying = self.env.carrying
  269. return "continue"
  270. class PutNextInstr(ActionInstr):
  271. """
  272. Put an object next to another object
  273. eg: put the red ball next to the blue key
  274. """
  275. def __init__(self, obj_move, obj_fixed, strict=False):
  276. super().__init__()
  277. assert obj_move.type != "door"
  278. self.desc_move = obj_move
  279. self.desc_fixed = obj_fixed
  280. self.strict = strict
  281. def surface(self, env):
  282. return (
  283. "put "
  284. + self.desc_move.surface(env)
  285. + " next to "
  286. + self.desc_fixed.surface(env)
  287. )
  288. def reset_verifier(self, env):
  289. super().reset_verifier(env)
  290. # Object previously being carried
  291. self.preCarrying = None
  292. # Identify set of possible matching objects in the environment
  293. self.desc_move.find_matching_objs(env)
  294. self.desc_fixed.find_matching_objs(env)
  295. def objs_next(self):
  296. """
  297. Check if the objects are next to each other
  298. This is used for rejection sampling
  299. """
  300. for obj_a in self.desc_move.obj_set:
  301. pos_a = obj_a.cur_pos
  302. for pos_b in self.desc_fixed.obj_poss:
  303. if pos_next_to(pos_a, pos_b):
  304. return True
  305. return False
  306. def verify_action(self, action):
  307. # To keep track of what was carried at the last time step
  308. preCarrying = self.preCarrying
  309. self.preCarrying = self.env.carrying
  310. # In strict mode, picking up the wrong object fails
  311. if self.strict:
  312. if action == self.env.actions.pickup and self.env.carrying:
  313. return "failure"
  314. # Only verify when the drop action is performed
  315. if action != self.env.actions.drop:
  316. return "continue"
  317. for obj_a in self.desc_move.obj_set:
  318. if preCarrying is not obj_a:
  319. continue
  320. pos_a = obj_a.cur_pos
  321. for pos_b in self.desc_fixed.obj_poss:
  322. if pos_next_to(pos_a, pos_b):
  323. return "success"
  324. return "continue"
  325. class SeqInstr(Instr, ABC):
  326. """
  327. Base class for sequencing instructions (before, after, and)
  328. """
  329. def __init__(self, instr_a, instr_b, strict=False):
  330. assert isinstance(instr_a, ActionInstr) or isinstance(instr_a, AndInstr)
  331. assert isinstance(instr_b, ActionInstr) or isinstance(instr_b, AndInstr)
  332. self.instr_a = instr_a
  333. self.instr_b = instr_b
  334. self.strict = strict
  335. class BeforeInstr(SeqInstr):
  336. """
  337. Sequence two instructions in order:
  338. eg: go to the red door then pick up the blue ball
  339. """
  340. def surface(self, env):
  341. return self.instr_a.surface(env) + ", then " + self.instr_b.surface(env)
  342. def reset_verifier(self, env):
  343. super().reset_verifier(env)
  344. self.instr_a.reset_verifier(env)
  345. self.instr_b.reset_verifier(env)
  346. self.a_done = False
  347. self.b_done = False
  348. def verify(self, action):
  349. if self.a_done == "success":
  350. self.b_done = self.instr_b.verify(action)
  351. if self.b_done == "failure":
  352. return "failure"
  353. if self.b_done == "success":
  354. return "success"
  355. else:
  356. self.a_done = self.instr_a.verify(action)
  357. if self.a_done == "failure":
  358. return "failure"
  359. if self.a_done == "success":
  360. return self.verify(action)
  361. # In strict mode, completing b first means failure
  362. if self.strict:
  363. if self.instr_b.verify(action) == "success":
  364. return "failure"
  365. return "continue"
  366. class AfterInstr(SeqInstr):
  367. """
  368. Sequence two instructions in reverse order:
  369. eg: go to the red door after you pick up the blue ball
  370. """
  371. def surface(self, env):
  372. return self.instr_a.surface(env) + " after you " + self.instr_b.surface(env)
  373. def reset_verifier(self, env):
  374. super().reset_verifier(env)
  375. self.instr_a.reset_verifier(env)
  376. self.instr_b.reset_verifier(env)
  377. self.a_done = False
  378. self.b_done = False
  379. def verify(self, action):
  380. if self.b_done == "success":
  381. self.a_done = self.instr_a.verify(action)
  382. if self.a_done == "success":
  383. return "success"
  384. if self.a_done == "failure":
  385. return "failure"
  386. else:
  387. self.b_done = self.instr_b.verify(action)
  388. if self.b_done == "failure":
  389. return "failure"
  390. if self.b_done == "success":
  391. return self.verify(action)
  392. # In strict mode, completing a first means failure
  393. if self.strict:
  394. if self.instr_a.verify(action) == "success":
  395. return "failure"
  396. return "continue"
  397. class AndInstr(SeqInstr):
  398. """
  399. Conjunction of two actions, both can be completed in any other
  400. eg: go to the red door and pick up the blue ball
  401. """
  402. def __init__(self, instr_a, instr_b, strict=False):
  403. assert isinstance(instr_a, ActionInstr)
  404. assert isinstance(instr_b, ActionInstr)
  405. super().__init__(instr_a, instr_b, strict)
  406. def surface(self, env):
  407. return self.instr_a.surface(env) + " and " + self.instr_b.surface(env)
  408. def reset_verifier(self, env):
  409. super().reset_verifier(env)
  410. self.instr_a.reset_verifier(env)
  411. self.instr_b.reset_verifier(env)
  412. self.a_done = False
  413. self.b_done = False
  414. def verify(self, action):
  415. if self.a_done != "success":
  416. self.a_done = self.instr_a.verify(action)
  417. if self.b_done != "success":
  418. self.b_done = self.instr_b.verify(action)
  419. if use_done_actions and action is self.env.actions.done:
  420. if self.a_done == "failure" and self.b_done == "failure":
  421. return "failure"
  422. if self.a_done == "success" and self.b_done == "success":
  423. return "success"
  424. return "continue"