verifier.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. """
  2. Copied and adapted from https://github.com/mila-iqia/babyai
  3. """
  4. from __future__ import annotations
  5. import os
  6. from abc import ABC, abstractmethod
  7. import numpy as np
  8. from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC
  9. from minigrid.minigrid_env import MiniGridEnv
  10. # Object types we are allowed to describe in language
  11. OBJ_TYPES = ["box", "ball", "key", "door"]
  12. # Object types we are allowed to describe in language
  13. OBJ_TYPES_NOT_DOOR = list(filter(lambda t: t != "door", OBJ_TYPES))
  14. # Locations are all relative to the agent's starting position
  15. LOC_NAMES = ["left", "right", "front", "behind"]
  16. # Environment flag to indicate that done actions should be
  17. # used by the verifier
  18. use_done_actions = os.environ.get("BABYAI_DONE_ACTIONS", False)
  19. def dot_product(v1, v2):
  20. """
  21. Compute the dot product of the vectors v1 and v2.
  22. """
  23. return sum(i * j for i, j in zip(v1, v2))
  24. def pos_next_to(pos_a, pos_b):
  25. """
  26. Test if two positions are next to each other.
  27. The positions have to line up either horizontally or vertically,
  28. but positions that are diagonally adjacent are not counted.
  29. """
  30. xa, ya = pos_a
  31. xb, yb = pos_b
  32. d = abs(xa - xb) + abs(ya - yb)
  33. return d == 1
  34. class ObjDesc:
  35. """
  36. Description of a set of objects in an environment
  37. """
  38. def __init__(self, type, color=None, loc=None):
  39. assert type in [None, *OBJ_TYPES], type
  40. assert color in [None, *COLOR_NAMES], color
  41. assert loc in [None, *LOC_NAMES], loc
  42. self.color = color
  43. self.type = type
  44. self.loc = loc
  45. # Set of objects possibly matching the description
  46. self.obj_set = []
  47. # Set of initial object positions
  48. self.obj_poss = []
  49. def __repr__(self):
  50. return f"{self.color} {self.type} {self.loc}"
  51. def surface(self, env):
  52. """
  53. Generate a natural language representation of the object description
  54. """
  55. self.find_matching_objs(env)
  56. assert len(self.obj_set) > 0, "no object matching description"
  57. if self.type:
  58. s = str(self.type)
  59. else:
  60. s = "object"
  61. if self.color:
  62. s = self.color + " " + s
  63. if self.loc:
  64. if self.loc == "front":
  65. s = s + " in front of you"
  66. elif self.loc == "behind":
  67. s = s + " behind you"
  68. else:
  69. s = s + " on your " + self.loc
  70. # Singular vs plural
  71. if len(self.obj_set) > 1:
  72. s = "a " + s
  73. else:
  74. s = "the " + s
  75. return s
  76. def find_matching_objs(self, env, use_location=True):
  77. """
  78. Find the set of objects matching the description and their positions.
  79. When use_location is False, we only update the positions of already tracked objects, without taking into account
  80. the location of the object. e.g. A ball that was on "your right" initially will still be tracked as being "on
  81. your right" when you move.
  82. """
  83. if use_location:
  84. self.obj_set = []
  85. # otherwise we keep the same obj_set
  86. self.obj_poss = []
  87. agent_room = env.room_from_pos(*env.agent_pos)
  88. for i in range(env.grid.width):
  89. for j in range(env.grid.height):
  90. cell = env.grid.get(i, j)
  91. if cell is None:
  92. continue
  93. if not use_location:
  94. # we should keep tracking the same objects initially tracked only
  95. already_tracked = any([cell is obj for obj in self.obj_set])
  96. if not already_tracked:
  97. continue
  98. # Check if object's type matches description
  99. if self.type is not None and cell.type != self.type:
  100. continue
  101. # Check if object's color matches description
  102. if self.color is not None and cell.color != self.color:
  103. continue
  104. # Check if object's position matches description
  105. if use_location and self.loc in ["left", "right", "front", "behind"]:
  106. # Locations apply only to objects in the same room
  107. # the agent starts in
  108. if not agent_room.pos_inside(i, j):
  109. continue
  110. # Direction from the agent to the object
  111. v = (i - env.agent_pos[0], j - env.agent_pos[1])
  112. # (d1, d2) is an oriented orthonormal basis
  113. d1 = DIR_TO_VEC[env.agent_dir]
  114. d2 = (-d1[1], d1[0])
  115. # Check if object's position matches with location
  116. pos_matches = {
  117. "left": dot_product(v, d2) < 0,
  118. "right": dot_product(v, d2) > 0,
  119. "front": dot_product(v, d1) > 0,
  120. "behind": dot_product(v, d1) < 0,
  121. }
  122. if not (pos_matches[self.loc]):
  123. continue
  124. if use_location:
  125. self.obj_set.append(cell)
  126. self.obj_poss.append((i, j))
  127. return self.obj_set, self.obj_poss
  128. class Instr(ABC):
  129. """
  130. Base class for all instructions in the baby language
  131. """
  132. def __init__(self):
  133. self.env: MiniGridEnv
  134. @abstractmethod
  135. def surface(self, env):
  136. """
  137. Produce a natural language representation of the instruction
  138. """
  139. raise NotImplementedError
  140. def reset_verifier(self, env):
  141. """
  142. Must be called at the beginning of the episode
  143. """
  144. self.env = env
  145. @abstractmethod
  146. def verify(self, action):
  147. """
  148. Verify if the task described by the instruction is incomplete,
  149. complete with success or failed. The return value is a string,
  150. one of: 'success', 'failure' or 'continue'.
  151. """
  152. raise NotImplementedError
  153. def update_objs_poss(self):
  154. """
  155. Update the position of objects present in the instruction if needed
  156. """
  157. potential_objects = ("desc", "desc_move", "desc_fixed")
  158. for attr in potential_objects:
  159. if hasattr(self, attr):
  160. getattr(self, attr).find_matching_objs(self.env, use_location=False)
  161. class ActionInstr(Instr, ABC):
  162. """
  163. Base class for all action instructions (clauses)
  164. """
  165. def __init__(self):
  166. super().__init__()
  167. # Indicates that the action was completed on the last step
  168. self.lastStepMatch = False
  169. def verify(self, action):
  170. """
  171. Verifies actions, with and without the done action.
  172. """
  173. if not use_done_actions:
  174. return self.verify_action(action)
  175. if action == self.env.actions.done:
  176. if self.lastStepMatch:
  177. return "success"
  178. return "failure"
  179. res = self.verify_action(action)
  180. self.lastStepMatch = res == "success"
  181. @abstractmethod
  182. def verify_action(self):
  183. """
  184. Each action instruction class should implement this method
  185. to verify the action.
  186. """
  187. raise NotImplementedError
  188. class OpenInstr(ActionInstr):
  189. def __init__(self, obj_desc, strict=False):
  190. super().__init__()
  191. assert obj_desc.type == "door"
  192. self.desc = obj_desc
  193. self.strict = strict
  194. def surface(self, env):
  195. return "open " + self.desc.surface(env)
  196. def reset_verifier(self, env):
  197. super().reset_verifier(env)
  198. # Identify set of possible matching objects in the environment
  199. self.desc.find_matching_objs(env)
  200. def verify_action(self, action):
  201. # Only verify when the toggle action is performed
  202. if action != self.env.actions.toggle:
  203. return "continue"
  204. # Get the contents of the cell in front of the agent
  205. front_cell = self.env.grid.get(*self.env.front_pos)
  206. for door in self.desc.obj_set:
  207. if front_cell and front_cell is door and door.is_open:
  208. return "success"
  209. # If in strict mode and the wrong door is opened, failure
  210. if self.strict:
  211. if front_cell and front_cell.type == "door":
  212. return "failure"
  213. return "continue"
  214. class GoToInstr(ActionInstr):
  215. """
  216. Go next to (and look towards) an object matching a given description
  217. eg: go to the door
  218. """
  219. def __init__(self, obj_desc):
  220. super().__init__()
  221. self.desc = obj_desc
  222. def surface(self, env):
  223. return "go to " + self.desc.surface(env)
  224. def reset_verifier(self, env):
  225. super().reset_verifier(env)
  226. # Identify set of possible matching objects in the environment
  227. self.desc.find_matching_objs(env)
  228. def verify_action(self, action):
  229. # For each object position
  230. for pos in self.desc.obj_poss:
  231. # If the agent is next to (and facing) the object
  232. if np.array_equal(pos, self.env.front_pos):
  233. return "success"
  234. return "continue"
  235. class PickupInstr(ActionInstr):
  236. """
  237. Pick up an object matching a given description
  238. eg: pick up the grey ball
  239. """
  240. def __init__(self, obj_desc, strict=False):
  241. super().__init__()
  242. assert obj_desc.type != "door"
  243. self.desc = obj_desc
  244. self.strict = strict
  245. def surface(self, env):
  246. return "pick up " + self.desc.surface(env)
  247. def reset_verifier(self, env):
  248. super().reset_verifier(env)
  249. # Object previously being carried
  250. self.preCarrying = None
  251. # Identify set of possible matching objects in the environment
  252. self.desc.find_matching_objs(env)
  253. def verify_action(self, action):
  254. # To keep track of what was carried at the last time step
  255. preCarrying = self.preCarrying
  256. self.preCarrying = self.env.carrying
  257. # Only verify when the pickup action is performed
  258. if action != self.env.actions.pickup:
  259. return "continue"
  260. for obj in self.desc.obj_set:
  261. if preCarrying is None and self.env.carrying is obj:
  262. return "success"
  263. # If in strict mode and the wrong door object is picked up, failure
  264. if self.strict:
  265. if self.env.carrying:
  266. return "failure"
  267. self.preCarrying = self.env.carrying
  268. return "continue"
  269. class PutNextInstr(ActionInstr):
  270. """
  271. Put an object next to another object
  272. eg: put the red ball next to the blue key
  273. """
  274. def __init__(self, obj_move, obj_fixed, strict=False):
  275. super().__init__()
  276. assert obj_move.type != "door"
  277. self.desc_move = obj_move
  278. self.desc_fixed = obj_fixed
  279. self.strict = strict
  280. def surface(self, env):
  281. return (
  282. "put "
  283. + self.desc_move.surface(env)
  284. + " next to "
  285. + self.desc_fixed.surface(env)
  286. )
  287. def reset_verifier(self, env):
  288. super().reset_verifier(env)
  289. # Object previously being carried
  290. self.preCarrying = None
  291. # Identify set of possible matching objects in the environment
  292. self.desc_move.find_matching_objs(env)
  293. self.desc_fixed.find_matching_objs(env)
  294. def objs_next(self):
  295. """
  296. Check if the objects are next to each other
  297. This is used for rejection sampling
  298. """
  299. for obj_a in self.desc_move.obj_set:
  300. pos_a = obj_a.cur_pos
  301. for pos_b in self.desc_fixed.obj_poss:
  302. if pos_next_to(pos_a, pos_b):
  303. return True
  304. return False
  305. def verify_action(self, action):
  306. # To keep track of what was carried at the last time step
  307. preCarrying = self.preCarrying
  308. self.preCarrying = self.env.carrying
  309. # In strict mode, picking up the wrong object fails
  310. if self.strict:
  311. if action == self.env.actions.pickup and self.env.carrying:
  312. return "failure"
  313. # Only verify when the drop action is performed
  314. if action != self.env.actions.drop:
  315. return "continue"
  316. for obj_a in self.desc_move.obj_set:
  317. if preCarrying is not obj_a:
  318. continue
  319. pos_a = obj_a.cur_pos
  320. for pos_b in self.desc_fixed.obj_poss:
  321. if pos_next_to(pos_a, pos_b):
  322. return "success"
  323. return "continue"
  324. class SeqInstr(Instr, ABC):
  325. """
  326. Base class for sequencing instructions (before, after, and)
  327. """
  328. def __init__(self, instr_a, instr_b, strict=False):
  329. assert isinstance(instr_a, ActionInstr) or isinstance(instr_a, AndInstr)
  330. assert isinstance(instr_b, ActionInstr) or isinstance(instr_b, AndInstr)
  331. self.instr_a = instr_a
  332. self.instr_b = instr_b
  333. self.strict = strict
  334. class BeforeInstr(SeqInstr):
  335. """
  336. Sequence two instructions in order:
  337. eg: go to the red door then pick up the blue ball
  338. """
  339. def surface(self, env):
  340. return self.instr_a.surface(env) + ", then " + self.instr_b.surface(env)
  341. def reset_verifier(self, env):
  342. super().reset_verifier(env)
  343. self.instr_a.reset_verifier(env)
  344. self.instr_b.reset_verifier(env)
  345. self.a_done = False
  346. self.b_done = False
  347. def verify(self, action):
  348. if self.a_done == "success":
  349. self.b_done = self.instr_b.verify(action)
  350. if self.b_done == "failure":
  351. return "failure"
  352. if self.b_done == "success":
  353. return "success"
  354. else:
  355. self.a_done = self.instr_a.verify(action)
  356. if self.a_done == "failure":
  357. return "failure"
  358. if self.a_done == "success":
  359. return self.verify(action)
  360. # In strict mode, completing b first means failure
  361. if self.strict:
  362. if self.instr_b.verify(action) == "success":
  363. return "failure"
  364. return "continue"
  365. class AfterInstr(SeqInstr):
  366. """
  367. Sequence two instructions in reverse order:
  368. eg: go to the red door after you pick up the blue ball
  369. """
  370. def surface(self, env):
  371. return self.instr_a.surface(env) + " after you " + self.instr_b.surface(env)
  372. def reset_verifier(self, env):
  373. super().reset_verifier(env)
  374. self.instr_a.reset_verifier(env)
  375. self.instr_b.reset_verifier(env)
  376. self.a_done = False
  377. self.b_done = False
  378. def verify(self, action):
  379. if self.b_done == "success":
  380. self.a_done = self.instr_a.verify(action)
  381. if self.a_done == "success":
  382. return "success"
  383. if self.a_done == "failure":
  384. return "failure"
  385. else:
  386. self.b_done = self.instr_b.verify(action)
  387. if self.b_done == "failure":
  388. return "failure"
  389. if self.b_done == "success":
  390. return self.verify(action)
  391. # In strict mode, completing a first means failure
  392. if self.strict:
  393. if self.instr_a.verify(action) == "success":
  394. return "failure"
  395. return "continue"
  396. class AndInstr(SeqInstr):
  397. """
  398. Conjunction of two actions, both can be completed in any other
  399. eg: go to the red door and pick up the blue ball
  400. """
  401. def __init__(self, instr_a, instr_b, strict=False):
  402. assert isinstance(instr_a, ActionInstr)
  403. assert isinstance(instr_b, ActionInstr)
  404. super().__init__(instr_a, instr_b, strict)
  405. def surface(self, env):
  406. return self.instr_a.surface(env) + " and " + self.instr_b.surface(env)
  407. def reset_verifier(self, env):
  408. super().reset_verifier(env)
  409. self.instr_a.reset_verifier(env)
  410. self.instr_b.reset_verifier(env)
  411. self.a_done = False
  412. self.b_done = False
  413. def verify(self, action):
  414. if self.a_done != "success":
  415. self.a_done = self.instr_a.verify(action)
  416. if self.b_done != "success":
  417. self.b_done = self.instr_b.verify(action)
  418. if use_done_actions and action is self.env.actions.done:
  419. if self.a_done == "failure" and self.b_done == "failure":
  420. return "failure"
  421. if self.a_done == "success" and self.b_done == "success":
  422. return "success"
  423. return "continue"