mission.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. from __future__ import annotations
  2. from typing import Any, Callable
  3. from gymnasium import spaces
  4. from gymnasium.utils import seeding
  5. def check_if_no_duplicate(duplicate_list: list) -> bool:
  6. """Check if given list contains any duplicates"""
  7. return len(set(duplicate_list)) == len(duplicate_list)
  8. class MissionSpace(spaces.Space[str]):
  9. r"""A space representing a mission for the Gym-Minigrid environments.
  10. The space allows generating random mission strings constructed with an input placeholder list.
  11. Example Usage::
  12. >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
  13. ordered_placeholders=[["green", "blue"]])
  14. >>> observation_space.sample()
  15. "Get the green ball."
  16. >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
  17. ordered_placeholders=None)
  18. >>> observation_space.sample()
  19. "Get the ball."
  20. """
  21. def __init__(
  22. self,
  23. mission_func: Callable[..., str],
  24. ordered_placeholders: list[list[str]] | None = None,
  25. seed: int | seeding.RandomNumberGenerator | None = None,
  26. ):
  27. r"""Constructor of :class:`MissionSpace` space.
  28. Args:
  29. mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
  30. ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
  31. seed: seed: The seed for sampling from the space.
  32. """
  33. # Check that the ordered placeholders and mission function are well defined.
  34. if ordered_placeholders is not None:
  35. assert (
  36. len(ordered_placeholders) == mission_func.__code__.co_argcount
  37. ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
  38. for placeholder_list in ordered_placeholders:
  39. assert check_if_no_duplicate(
  40. placeholder_list
  41. ), "Make sure that the placeholders don't have any duplicate values."
  42. else:
  43. assert (
  44. mission_func.__code__.co_argcount == 0
  45. ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
  46. self.ordered_placeholders = ordered_placeholders
  47. self.mission_func = mission_func
  48. super().__init__(dtype=str, seed=seed)
  49. # Check that mission_func returns a string
  50. sampled_mission = self.sample()
  51. assert isinstance(
  52. sampled_mission, str
  53. ), f"mission_func must return type str not {type(sampled_mission)}"
  54. def sample(self) -> str:
  55. """Sample a random mission string."""
  56. if self.ordered_placeholders is not None:
  57. placeholders = []
  58. for rand_var_list in self.ordered_placeholders:
  59. idx = self.np_random.integers(0, len(rand_var_list))
  60. placeholders.append(rand_var_list[idx])
  61. return self.mission_func(*placeholders)
  62. else:
  63. return self.mission_func()
  64. def contains(self, x: Any) -> bool:
  65. """Return boolean specifying if x is a valid member of this space."""
  66. # Store a list of all the placeholders from self.ordered_placeholders that appear in x
  67. if self.ordered_placeholders is not None:
  68. check_placeholder_list = []
  69. for placeholder_list in self.ordered_placeholders:
  70. for placeholder in placeholder_list:
  71. if placeholder in x:
  72. check_placeholder_list.append(placeholder)
  73. # Remove duplicates from the list
  74. check_placeholder_list = list(set(check_placeholder_list))
  75. start_id_placeholder = []
  76. end_id_placeholder = []
  77. # Get the starting and ending id of the identified placeholders with possible duplicates
  78. new_check_placeholder_list = []
  79. for placeholder in check_placeholder_list:
  80. new_start_id_placeholder = [
  81. i for i in range(len(x)) if x.startswith(placeholder, i)
  82. ]
  83. new_check_placeholder_list += [placeholder] * len(
  84. new_start_id_placeholder
  85. )
  86. end_id_placeholder += [
  87. start_id + len(placeholder) - 1
  88. for start_id in new_start_id_placeholder
  89. ]
  90. start_id_placeholder += new_start_id_placeholder
  91. # Order by starting id the placeholders
  92. ordered_placeholder_list = sorted(
  93. zip(
  94. start_id_placeholder, end_id_placeholder, new_check_placeholder_list
  95. )
  96. )
  97. # Check for repeated placeholders contained in each other
  98. remove_placeholder_id = []
  99. for i, placeholder_1 in enumerate(ordered_placeholder_list):
  100. starting_id = i + 1
  101. for j, placeholder_2 in enumerate(
  102. ordered_placeholder_list[starting_id:]
  103. ):
  104. # Check if place holder ids overlap and keep the longest
  105. if max(placeholder_1[0], placeholder_2[0]) < min(
  106. placeholder_1[1], placeholder_2[1]
  107. ):
  108. remove_placeholder = min(
  109. placeholder_1[2], placeholder_2[2], key=len
  110. )
  111. if remove_placeholder == placeholder_1[2]:
  112. remove_placeholder_id.append(i)
  113. else:
  114. remove_placeholder_id.append(i + j + 1)
  115. for id in remove_placeholder_id:
  116. del ordered_placeholder_list[id]
  117. final_placeholders = [
  118. placeholder[2] for placeholder in ordered_placeholder_list
  119. ]
  120. # Check that the identified final placeholders are in the same order as the original placeholders.
  121. for orered_placeholder, final_placeholder in zip(
  122. self.ordered_placeholders, final_placeholders
  123. ):
  124. if final_placeholder in orered_placeholder:
  125. continue
  126. else:
  127. return False
  128. try:
  129. mission_string_with_placeholders = self.mission_func(
  130. *final_placeholders
  131. )
  132. except Exception as e:
  133. print(
  134. f"{x} is not contained in MissionSpace due to the following exception: {e}"
  135. )
  136. return False
  137. return bool(mission_string_with_placeholders == x)
  138. else:
  139. return bool(self.mission_func() == x)
  140. def __repr__(self) -> str:
  141. """Gives a string representation of this space."""
  142. return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
  143. def __eq__(self, other) -> bool:
  144. """Check whether ``other`` is equivalent to this instance."""
  145. if isinstance(other, MissionSpace):
  146. # Check that place holder lists are the same
  147. if self.ordered_placeholders is not None:
  148. # Check length
  149. if (len(self.order_placeholder) == len(other.order_placeholder)) and (
  150. all(
  151. set(i) == set(j)
  152. for i, j in zip(self.order_placeholder, other.order_placeholder)
  153. )
  154. ):
  155. # Check mission string is the same with dummy space placeholders
  156. test_placeholders = [""] * len(self.order_placeholder)
  157. mission = self.mission_func(*test_placeholders)
  158. other_mission = other.mission_func(*test_placeholders)
  159. return mission == other_mission
  160. else:
  161. # Check that other is also None
  162. if other.ordered_placeholders is None:
  163. # Check mission string is the same
  164. mission = self.mission_func()
  165. other_mission = other.mission_func()
  166. return mission == other_mission
  167. # If none of the statements above return then False
  168. return False