mission.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. from __future__ import annotations
  2. from typing import Any, Callable
  3. from gymnasium import spaces
  4. from gymnasium.utils import seeding
  5. def check_if_no_duplicate(duplicate_list: list) -> bool:
  6. """Check if given list contains any duplicates"""
  7. return len(set(duplicate_list)) == len(duplicate_list)
  8. class MissionSpace(spaces.Space[str]):
  9. r"""A space representing a mission for the Gym-Minigrid environments.
  10. The space allows generating random mission strings constructed with an input placeholder list.
  11. Example Usage::
  12. >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
  13. ... ordered_placeholders=[["green", "blue"]])
  14. >>> _ = observation_space.seed(123)
  15. >>> observation_space.sample()
  16. 'Get the green ball.'
  17. >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.",
  18. ... ordered_placeholders=None)
  19. >>> observation_space.sample()
  20. 'Get the ball.'
  21. """
  22. def __init__(
  23. self,
  24. mission_func: Callable[..., str],
  25. ordered_placeholders: list[list[str]] | None = None,
  26. seed: int | seeding.RandomNumberGenerator | None = None,
  27. ):
  28. r"""Constructor of :class:`MissionSpace` space.
  29. Args:
  30. mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
  31. ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
  32. seed: seed: The seed for sampling from the space.
  33. """
  34. # Check that the ordered placeholders and mission function are well defined.
  35. if ordered_placeholders is not None:
  36. assert (
  37. len(ordered_placeholders) == mission_func.__code__.co_argcount
  38. ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
  39. for placeholder_list in ordered_placeholders:
  40. assert check_if_no_duplicate(
  41. placeholder_list
  42. ), "Make sure that the placeholders don't have any duplicate values."
  43. else:
  44. assert (
  45. mission_func.__code__.co_argcount == 0
  46. ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
  47. self.ordered_placeholders = ordered_placeholders
  48. self.mission_func = mission_func
  49. super().__init__(dtype=str, seed=seed)
  50. # Check that mission_func returns a string
  51. sampled_mission = self.sample()
  52. assert isinstance(
  53. sampled_mission, str
  54. ), f"mission_func must return type str not {type(sampled_mission)}"
  55. def sample(self) -> str:
  56. """Sample a random mission string."""
  57. if self.ordered_placeholders is not None:
  58. placeholders = []
  59. for rand_var_list in self.ordered_placeholders:
  60. idx = self.np_random.integers(0, len(rand_var_list))
  61. placeholders.append(rand_var_list[idx])
  62. return self.mission_func(*placeholders)
  63. else:
  64. return self.mission_func()
  65. def contains(self, x: Any) -> bool:
  66. """Return boolean specifying if x is a valid member of this space."""
  67. # Store a list of all the placeholders from self.ordered_placeholders that appear in x
  68. if self.ordered_placeholders is not None:
  69. check_placeholder_list = []
  70. for placeholder_list in self.ordered_placeholders:
  71. for placeholder in placeholder_list:
  72. if placeholder in x:
  73. check_placeholder_list.append(placeholder)
  74. # Remove duplicates from the list
  75. check_placeholder_list = list(set(check_placeholder_list))
  76. start_id_placeholder = []
  77. end_id_placeholder = []
  78. # Get the starting and ending id of the identified placeholders with possible duplicates
  79. new_check_placeholder_list = []
  80. for placeholder in check_placeholder_list:
  81. new_start_id_placeholder = [
  82. i for i in range(len(x)) if x.startswith(placeholder, i)
  83. ]
  84. new_check_placeholder_list += [placeholder] * len(
  85. new_start_id_placeholder
  86. )
  87. end_id_placeholder += [
  88. start_id + len(placeholder) - 1
  89. for start_id in new_start_id_placeholder
  90. ]
  91. start_id_placeholder += new_start_id_placeholder
  92. # Order by starting id the placeholders
  93. ordered_placeholder_list = sorted(
  94. zip(
  95. start_id_placeholder, end_id_placeholder, new_check_placeholder_list
  96. )
  97. )
  98. # Check for repeated placeholders contained in each other
  99. remove_placeholder_id = []
  100. for i, placeholder_1 in enumerate(ordered_placeholder_list):
  101. starting_id = i + 1
  102. for j, placeholder_2 in enumerate(
  103. ordered_placeholder_list[starting_id:]
  104. ):
  105. # Check if place holder ids overlap and keep the longest
  106. if max(placeholder_1[0], placeholder_2[0]) < min(
  107. placeholder_1[1], placeholder_2[1]
  108. ):
  109. remove_placeholder = min(
  110. placeholder_1[2], placeholder_2[2], key=len
  111. )
  112. if remove_placeholder == placeholder_1[2]:
  113. remove_placeholder_id.append(i)
  114. else:
  115. remove_placeholder_id.append(i + j + 1)
  116. for id in remove_placeholder_id:
  117. del ordered_placeholder_list[id]
  118. final_placeholders = [
  119. placeholder[2] for placeholder in ordered_placeholder_list
  120. ]
  121. # Check that the identified final placeholders are in the same order as the original placeholders.
  122. for orered_placeholder, final_placeholder in zip(
  123. self.ordered_placeholders, final_placeholders
  124. ):
  125. if final_placeholder in orered_placeholder:
  126. continue
  127. else:
  128. return False
  129. try:
  130. mission_string_with_placeholders = self.mission_func(
  131. *final_placeholders
  132. )
  133. except Exception as e:
  134. print(
  135. f"{x} is not contained in MissionSpace due to the following exception: {e}"
  136. )
  137. return False
  138. return bool(mission_string_with_placeholders == x)
  139. else:
  140. return bool(self.mission_func() == x)
  141. def __repr__(self) -> str:
  142. """Gives a string representation of this space."""
  143. return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
  144. def __eq__(self, other) -> bool:
  145. """Check whether ``other`` is equivalent to this instance."""
  146. if isinstance(other, MissionSpace):
  147. # Check that place holder lists are the same
  148. if self.ordered_placeholders is not None:
  149. # Check length
  150. if (
  151. len(self.ordered_placeholders) == len(other.ordered_placeholders)
  152. ) and (
  153. all(
  154. set(i) == set(j)
  155. for i, j in zip(
  156. self.ordered_placeholders, other.ordered_placeholders
  157. )
  158. )
  159. ):
  160. # Check mission string is the same with dummy space placeholders
  161. test_placeholders = [""] * len(self.ordered_placeholders)
  162. mission = self.mission_func(*test_placeholders)
  163. other_mission = other.mission_func(*test_placeholders)
  164. return mission == other_mission
  165. else:
  166. # Check that other is also None
  167. if other.ordered_placeholders is None:
  168. # Check mission string is the same
  169. mission = self.mission_func()
  170. other_mission = other.mission_func()
  171. return mission == other_mission
  172. # If none of the statements above return then False
  173. return False