verifier.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. """
  2. Copied and adapted from https://github.com/mila-iqia/babyai
  3. """
  4. import os
  5. from abc import ABC, abstractmethod
  6. import numpy as np
  7. from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC
  8. from minigrid.minigrid_env import MiniGridEnv
  9. # Object types we are allowed to describe in language
  10. OBJ_TYPES = ["box", "ball", "key", "door"]
  11. # Object types we are allowed to describe in language
  12. OBJ_TYPES_NOT_DOOR = list(filter(lambda t: t != "door", OBJ_TYPES))
  13. # Locations are all relative to the agent's starting position
  14. LOC_NAMES = ["left", "right", "front", "behind"]
  15. # Environment flag to indicate that done actions should be
  16. # used by the verifier
  17. use_done_actions = os.environ.get("BABYAI_DONE_ACTIONS", False)
  18. def dot_product(v1, v2):
  19. """
  20. Compute the dot product of the vectors v1 and v2.
  21. """
  22. return sum(i * j for i, j in zip(v1, v2))
  23. def pos_next_to(pos_a, pos_b):
  24. """
  25. Test if two positions are next to each other.
  26. The positions have to line up either horizontally or vertically,
  27. but positions that are diagonally adjacent are not counted.
  28. """
  29. xa, ya = pos_a
  30. xb, yb = pos_b
  31. d = abs(xa - xb) + abs(ya - yb)
  32. return d == 1
  33. class ObjDesc:
  34. """
  35. Description of a set of objects in an environment
  36. """
  37. def __init__(self, type, color=None, loc=None):
  38. assert type in [None, *OBJ_TYPES], type
  39. assert color in [None, *COLOR_NAMES], color
  40. assert loc in [None, *LOC_NAMES], loc
  41. self.color = color
  42. self.type = type
  43. self.loc = loc
  44. # Set of objects possibly matching the description
  45. self.obj_set = []
  46. # Set of initial object positions
  47. self.obj_poss = []
  48. def __repr__(self):
  49. return f"{self.color} {self.type} {self.loc}"
  50. def surface(self, env):
  51. """
  52. Generate a natural language representation of the object description
  53. """
  54. self.find_matching_objs(env)
  55. assert len(self.obj_set) > 0, "no object matching description"
  56. if self.type:
  57. s = str(self.type)
  58. else:
  59. s = "object"
  60. if self.color:
  61. s = self.color + " " + s
  62. if self.loc:
  63. if self.loc == "front":
  64. s = s + " in front of you"
  65. elif self.loc == "behind":
  66. s = s + " behind you"
  67. else:
  68. s = s + " on your " + self.loc
  69. # Singular vs plural
  70. if len(self.obj_set) > 1:
  71. s = "a " + s
  72. else:
  73. s = "the " + s
  74. return s
  75. def find_matching_objs(self, env, use_location=True):
  76. """
  77. Find the set of objects matching the description and their positions.
  78. When use_location is False, we only update the positions of already tracked objects, without taking into account
  79. the location of the object. e.g. A ball that was on "your right" initially will still be tracked as being "on
  80. your right" when you move.
  81. """
  82. if use_location:
  83. self.obj_set = []
  84. # otherwise we keep the same obj_set
  85. self.obj_poss = []
  86. agent_room = env.room_from_pos(*env.agent_pos)
  87. for i in range(env.grid.width):
  88. for j in range(env.grid.height):
  89. cell = env.grid.get(i, j)
  90. if cell is None:
  91. continue
  92. if not use_location:
  93. # we should keep tracking the same objects initially tracked only
  94. already_tracked = any([cell is obj for obj in self.obj_set])
  95. if not already_tracked:
  96. continue
  97. # Check if object's type matches description
  98. if self.type is not None and cell.type != self.type:
  99. continue
  100. # Check if object's color matches description
  101. if self.color is not None and cell.color != self.color:
  102. continue
  103. # Check if object's position matches description
  104. if use_location and self.loc in ["left", "right", "front", "behind"]:
  105. # Locations apply only to objects in the same room
  106. # the agent starts in
  107. if not agent_room.pos_inside(i, j):
  108. continue
  109. # Direction from the agent to the object
  110. v = (i - env.agent_pos[0], j - env.agent_pos[1])
  111. # (d1, d2) is an oriented orthonormal basis
  112. d1 = DIR_TO_VEC[env.agent_dir]
  113. d2 = (-d1[1], d1[0])
  114. # Check if object's position matches with location
  115. pos_matches = {
  116. "left": dot_product(v, d2) < 0,
  117. "right": dot_product(v, d2) > 0,
  118. "front": dot_product(v, d1) > 0,
  119. "behind": dot_product(v, d1) < 0,
  120. }
  121. if not (pos_matches[self.loc]):
  122. continue
  123. if use_location:
  124. self.obj_set.append(cell)
  125. self.obj_poss.append((i, j))
  126. return self.obj_set, self.obj_poss
  127. class Instr(ABC):
  128. """
  129. Base class for all instructions in the baby language
  130. """
  131. def __init__(self):
  132. self.env: MiniGridEnv
  133. @abstractmethod
  134. def surface(self, env):
  135. """
  136. Produce a natural language representation of the instruction
  137. """
  138. raise NotImplementedError
  139. def reset_verifier(self, env):
  140. """
  141. Must be called at the beginning of the episode
  142. """
  143. self.env = env
  144. @abstractmethod
  145. def verify(self, action):
  146. """
  147. Verify if the task described by the instruction is incomplete,
  148. complete with success or failed. The return value is a string,
  149. one of: 'success', 'failure' or 'continue'.
  150. """
  151. raise NotImplementedError
  152. def update_objs_poss(self):
  153. """
  154. Update the position of objects present in the instruction if needed
  155. """
  156. potential_objects = ("desc", "desc_move", "desc_fixed")
  157. for attr in potential_objects:
  158. if hasattr(self, attr):
  159. getattr(self, attr).find_matching_objs(self.env, use_location=False)
  160. class ActionInstr(Instr, ABC):
  161. """
  162. Base class for all action instructions (clauses)
  163. """
  164. def __init__(self):
  165. super().__init__()
  166. # Indicates that the action was completed on the last step
  167. self.lastStepMatch = False
  168. def verify(self, action):
  169. """
  170. Verifies actions, with and without the done action.
  171. """
  172. if not use_done_actions:
  173. return self.verify_action(action)
  174. if action == self.env.actions.done:
  175. if self.lastStepMatch:
  176. return "success"
  177. return "failure"
  178. res = self.verify_action(action)
  179. self.lastStepMatch = res == "success"
  180. @abstractmethod
  181. def verify_action(self):
  182. """
  183. Each action instruction class should implement this method
  184. to verify the action.
  185. """
  186. raise NotImplementedError
  187. class OpenInstr(ActionInstr):
  188. def __init__(self, obj_desc, strict=False):
  189. super().__init__()
  190. assert obj_desc.type == "door"
  191. self.desc = obj_desc
  192. self.strict = strict
  193. def surface(self, env):
  194. return "open " + self.desc.surface(env)
  195. def reset_verifier(self, env):
  196. super().reset_verifier(env)
  197. # Identify set of possible matching objects in the environment
  198. self.desc.find_matching_objs(env)
  199. def verify_action(self, action):
  200. # Only verify when the toggle action is performed
  201. if action != self.env.actions.toggle:
  202. return "continue"
  203. # Get the contents of the cell in front of the agent
  204. front_cell = self.env.grid.get(*self.env.front_pos)
  205. for door in self.desc.obj_set:
  206. if front_cell and front_cell is door and door.is_open:
  207. return "success"
  208. # If in strict mode and the wrong door is opened, failure
  209. if self.strict:
  210. if front_cell and front_cell.type == "door":
  211. return "failure"
  212. return "continue"
  213. class GoToInstr(ActionInstr):
  214. """
  215. Go next to (and look towards) an object matching a given description
  216. eg: go to the door
  217. """
  218. def __init__(self, obj_desc):
  219. super().__init__()
  220. self.desc = obj_desc
  221. def surface(self, env):
  222. return "go to " + self.desc.surface(env)
  223. def reset_verifier(self, env):
  224. super().reset_verifier(env)
  225. # Identify set of possible matching objects in the environment
  226. self.desc.find_matching_objs(env)
  227. def verify_action(self, action):
  228. # For each object position
  229. for pos in self.desc.obj_poss:
  230. # If the agent is next to (and facing) the object
  231. if np.array_equal(pos, self.env.front_pos):
  232. return "success"
  233. return "continue"
  234. class PickupInstr(ActionInstr):
  235. """
  236. Pick up an object matching a given description
  237. eg: pick up the grey ball
  238. """
  239. def __init__(self, obj_desc, strict=False):
  240. super().__init__()
  241. assert obj_desc.type != "door"
  242. self.desc = obj_desc
  243. self.strict = strict
  244. def surface(self, env):
  245. return "pick up " + self.desc.surface(env)
  246. def reset_verifier(self, env):
  247. super().reset_verifier(env)
  248. # Object previously being carried
  249. self.preCarrying = None
  250. # Identify set of possible matching objects in the environment
  251. self.desc.find_matching_objs(env)
  252. def verify_action(self, action):
  253. # To keep track of what was carried at the last time step
  254. preCarrying = self.preCarrying
  255. self.preCarrying = self.env.carrying
  256. # Only verify when the pickup action is performed
  257. if action != self.env.actions.pickup:
  258. return "continue"
  259. for obj in self.desc.obj_set:
  260. if preCarrying is None and self.env.carrying is obj:
  261. return "success"
  262. # If in strict mode and the wrong door object is picked up, failure
  263. if self.strict:
  264. if self.env.carrying:
  265. return "failure"
  266. self.preCarrying = self.env.carrying
  267. return "continue"
  268. class PutNextInstr(ActionInstr):
  269. """
  270. Put an object next to another object
  271. eg: put the red ball next to the blue key
  272. """
  273. def __init__(self, obj_move, obj_fixed, strict=False):
  274. super().__init__()
  275. assert obj_move.type != "door"
  276. self.desc_move = obj_move
  277. self.desc_fixed = obj_fixed
  278. self.strict = strict
  279. def surface(self, env):
  280. return (
  281. "put "
  282. + self.desc_move.surface(env)
  283. + " next to "
  284. + self.desc_fixed.surface(env)
  285. )
  286. def reset_verifier(self, env):
  287. super().reset_verifier(env)
  288. # Object previously being carried
  289. self.preCarrying = None
  290. # Identify set of possible matching objects in the environment
  291. self.desc_move.find_matching_objs(env)
  292. self.desc_fixed.find_matching_objs(env)
  293. def objs_next(self):
  294. """
  295. Check if the objects are next to each other
  296. This is used for rejection sampling
  297. """
  298. for obj_a in self.desc_move.obj_set:
  299. pos_a = obj_a.cur_pos
  300. for pos_b in self.desc_fixed.obj_poss:
  301. if pos_next_to(pos_a, pos_b):
  302. return True
  303. return False
  304. def verify_action(self, action):
  305. # To keep track of what was carried at the last time step
  306. preCarrying = self.preCarrying
  307. self.preCarrying = self.env.carrying
  308. # In strict mode, picking up the wrong object fails
  309. if self.strict:
  310. if action == self.env.actions.pickup and self.env.carrying:
  311. return "failure"
  312. # Only verify when the drop action is performed
  313. if action != self.env.actions.drop:
  314. return "continue"
  315. for obj_a in self.desc_move.obj_set:
  316. if preCarrying is not obj_a:
  317. continue
  318. pos_a = obj_a.cur_pos
  319. for pos_b in self.desc_fixed.obj_poss:
  320. if pos_next_to(pos_a, pos_b):
  321. return "success"
  322. return "continue"
  323. class SeqInstr(Instr, ABC):
  324. """
  325. Base class for sequencing instructions (before, after, and)
  326. """
  327. def __init__(self, instr_a, instr_b, strict=False):
  328. assert isinstance(instr_a, ActionInstr) or isinstance(instr_a, AndInstr)
  329. assert isinstance(instr_b, ActionInstr) or isinstance(instr_b, AndInstr)
  330. self.instr_a = instr_a
  331. self.instr_b = instr_b
  332. self.strict = strict
  333. class BeforeInstr(SeqInstr):
  334. """
  335. Sequence two instructions in order:
  336. eg: go to the red door then pick up the blue ball
  337. """
  338. def surface(self, env):
  339. return self.instr_a.surface(env) + ", then " + self.instr_b.surface(env)
  340. def reset_verifier(self, env):
  341. super().reset_verifier(env)
  342. self.instr_a.reset_verifier(env)
  343. self.instr_b.reset_verifier(env)
  344. self.a_done = False
  345. self.b_done = False
  346. def verify(self, action):
  347. if self.a_done == "success":
  348. self.b_done = self.instr_b.verify(action)
  349. if self.b_done == "failure":
  350. return "failure"
  351. if self.b_done == "success":
  352. return "success"
  353. else:
  354. self.a_done = self.instr_a.verify(action)
  355. if self.a_done == "failure":
  356. return "failure"
  357. if self.a_done == "success":
  358. return self.verify(action)
  359. # In strict mode, completing b first means failure
  360. if self.strict:
  361. if self.instr_b.verify(action) == "success":
  362. return "failure"
  363. return "continue"
  364. class AfterInstr(SeqInstr):
  365. """
  366. Sequence two instructions in reverse order:
  367. eg: go to the red door after you pick up the blue ball
  368. """
  369. def surface(self, env):
  370. return self.instr_a.surface(env) + " after you " + self.instr_b.surface(env)
  371. def reset_verifier(self, env):
  372. super().reset_verifier(env)
  373. self.instr_a.reset_verifier(env)
  374. self.instr_b.reset_verifier(env)
  375. self.a_done = False
  376. self.b_done = False
  377. def verify(self, action):
  378. if self.b_done == "success":
  379. self.a_done = self.instr_a.verify(action)
  380. if self.a_done == "success":
  381. return "success"
  382. if self.a_done == "failure":
  383. return "failure"
  384. else:
  385. self.b_done = self.instr_b.verify(action)
  386. if self.b_done == "failure":
  387. return "failure"
  388. if self.b_done == "success":
  389. return self.verify(action)
  390. # In strict mode, completing a first means failure
  391. if self.strict:
  392. if self.instr_a.verify(action) == "success":
  393. return "failure"
  394. return "continue"
  395. class AndInstr(SeqInstr):
  396. """
  397. Conjunction of two actions, both can be completed in any other
  398. eg: go to the red door and pick up the blue ball
  399. """
  400. def __init__(self, instr_a, instr_b, strict=False):
  401. assert isinstance(instr_a, ActionInstr)
  402. assert isinstance(instr_b, ActionInstr)
  403. super().__init__(instr_a, instr_b, strict)
  404. def surface(self, env):
  405. return self.instr_a.surface(env) + " and " + self.instr_b.surface(env)
  406. def reset_verifier(self, env):
  407. super().reset_verifier(env)
  408. self.instr_a.reset_verifier(env)
  409. self.instr_b.reset_verifier(env)
  410. self.a_done = False
  411. self.b_done = False
  412. def verify(self, action):
  413. if self.a_done != "success":
  414. self.a_done = self.instr_a.verify(action)
  415. if self.b_done != "success":
  416. self.b_done = self.instr_b.verify(action)
  417. if use_done_actions and action is self.env.actions.done:
  418. if self.a_done == "failure" and self.b_done == "failure":
  419. return "failure"
  420. if self.a_done == "success" and self.b_done == "success":
  421. return "success"
  422. return "continue"