mission.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. from typing import Any, Callable, Optional, Union
  2. from gymnasium import spaces
  3. from gymnasium.utils import seeding
  4. def check_if_no_duplicate(duplicate_list: list) -> bool:
  5. """Check if given list contains any duplicates"""
  6. return len(set(duplicate_list)) == len(duplicate_list)
  7. class MissionSpace(spaces.Space[str]):
  8. r"""A space representing a mission for the Gym-Minigrid environments.
  9. The space allows generating random mission strings constructed with an input placeholder list.
  10. Example Usage::
  11. >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
  12. ordered_placeholders=[["green", "blue"]])
  13. >>> observation_space.sample()
  14. "Get the green ball."
  15. >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
  16. ordered_placeholders=None)
  17. >>> observation_space.sample()
  18. "Get the ball."
  19. """
  20. def __init__(
  21. self,
  22. mission_func: Callable[..., str],
  23. ordered_placeholders: Optional["list[list[str]]"] = None,
  24. seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
  25. ):
  26. r"""Constructor of :class:`MissionSpace` space.
  27. Args:
  28. mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
  29. ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
  30. seed: seed: The seed for sampling from the space.
  31. """
  32. # Check that the ordered placeholders and mission function are well defined.
  33. if ordered_placeholders is not None:
  34. assert (
  35. len(ordered_placeholders) == mission_func.__code__.co_argcount
  36. ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
  37. for placeholder_list in ordered_placeholders:
  38. assert check_if_no_duplicate(
  39. placeholder_list
  40. ), "Make sure that the placeholders don't have any duplicate values."
  41. else:
  42. assert (
  43. mission_func.__code__.co_argcount == 0
  44. ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
  45. self.ordered_placeholders = ordered_placeholders
  46. self.mission_func = mission_func
  47. super().__init__(dtype=str, seed=seed)
  48. # Check that mission_func returns a string
  49. sampled_mission = self.sample()
  50. assert isinstance(
  51. sampled_mission, str
  52. ), f"mission_func must return type str not {type(sampled_mission)}"
  53. def sample(self) -> str:
  54. """Sample a random mission string."""
  55. if self.ordered_placeholders is not None:
  56. placeholders = []
  57. for rand_var_list in self.ordered_placeholders:
  58. idx = self.np_random.integers(0, len(rand_var_list))
  59. placeholders.append(rand_var_list[idx])
  60. return self.mission_func(*placeholders)
  61. else:
  62. return self.mission_func()
  63. def contains(self, x: Any) -> bool:
  64. """Return boolean specifying if x is a valid member of this space."""
  65. # Store a list of all the placeholders from self.ordered_placeholders that appear in x
  66. if self.ordered_placeholders is not None:
  67. check_placeholder_list = []
  68. for placeholder_list in self.ordered_placeholders:
  69. for placeholder in placeholder_list:
  70. if placeholder in x:
  71. check_placeholder_list.append(placeholder)
  72. # Remove duplicates from the list
  73. check_placeholder_list = list(set(check_placeholder_list))
  74. start_id_placeholder = []
  75. end_id_placeholder = []
  76. # Get the starting and ending id of the identified placeholders with possible duplicates
  77. new_check_placeholder_list = []
  78. for placeholder in check_placeholder_list:
  79. new_start_id_placeholder = [
  80. i for i in range(len(x)) if x.startswith(placeholder, i)
  81. ]
  82. new_check_placeholder_list += [placeholder] * len(
  83. new_start_id_placeholder
  84. )
  85. end_id_placeholder += [
  86. start_id + len(placeholder) - 1
  87. for start_id in new_start_id_placeholder
  88. ]
  89. start_id_placeholder += new_start_id_placeholder
  90. # Order by starting id the placeholders
  91. ordered_placeholder_list = sorted(
  92. zip(
  93. start_id_placeholder, end_id_placeholder, new_check_placeholder_list
  94. )
  95. )
  96. # Check for repeated placeholders contained in each other
  97. remove_placeholder_id = []
  98. for i, placeholder_1 in enumerate(ordered_placeholder_list):
  99. starting_id = i + 1
  100. for j, placeholder_2 in enumerate(
  101. ordered_placeholder_list[starting_id:]
  102. ):
  103. # Check if place holder ids overlap and keep the longest
  104. if max(placeholder_1[0], placeholder_2[0]) < min(
  105. placeholder_1[1], placeholder_2[1]
  106. ):
  107. remove_placeholder = min(
  108. placeholder_1[2], placeholder_2[2], key=len
  109. )
  110. if remove_placeholder == placeholder_1[2]:
  111. remove_placeholder_id.append(i)
  112. else:
  113. remove_placeholder_id.append(i + j + 1)
  114. for id in remove_placeholder_id:
  115. del ordered_placeholder_list[id]
  116. final_placeholders = [
  117. placeholder[2] for placeholder in ordered_placeholder_list
  118. ]
  119. # Check that the identified final placeholders are in the same order as the original placeholders.
  120. for orered_placeholder, final_placeholder in zip(
  121. self.ordered_placeholders, final_placeholders
  122. ):
  123. if final_placeholder in orered_placeholder:
  124. continue
  125. else:
  126. return False
  127. try:
  128. mission_string_with_placeholders = self.mission_func(
  129. *final_placeholders
  130. )
  131. except Exception as e:
  132. print(
  133. f"{x} is not contained in MissionSpace due to the following exception: {e}"
  134. )
  135. return False
  136. return bool(mission_string_with_placeholders == x)
  137. else:
  138. return bool(self.mission_func() == x)
  139. def __repr__(self) -> str:
  140. """Gives a string representation of this space."""
  141. return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
  142. def __eq__(self, other) -> bool:
  143. """Check whether ``other`` is equivalent to this instance."""
  144. if isinstance(other, MissionSpace):
  145. # Check that place holder lists are the same
  146. if self.ordered_placeholders is not None:
  147. # Check length
  148. if (len(self.order_placeholder) == len(other.order_placeholder)) and (
  149. all(
  150. set(i) == set(j)
  151. for i, j in zip(self.order_placeholder, other.order_placeholder)
  152. )
  153. ):
  154. # Check mission string is the same with dummy space placeholders
  155. test_placeholders = [""] * len(self.order_placeholder)
  156. mission = self.mission_func(*test_placeholders)
  157. other_mission = other.mission_func(*test_placeholders)
  158. return mission == other_mission
  159. else:
  160. # Check that other is also None
  161. if other.ordered_placeholders is None:
  162. # Check mission string is the same
  163. mission = self.mission_func()
  164. other_mission = other.mission_func()
  165. return mission == other_mission
  166. # If none of the statements above return then False
  167. return False