123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- from typing import Any, Callable, Optional, Union
- from gymnasium import spaces
- from gymnasium.utils import seeding
def check_if_no_duplicate(duplicate_list: list) -> bool:
    """Tell whether every element of ``duplicate_list`` is unique.

    Elements must be hashable, since they are collected into a ``set``.

    Returns:
        ``True`` when no value appears twice, ``False`` otherwise.
    """
    distinct_values = set(duplicate_list)
    has_repeats = len(distinct_values) != len(duplicate_list)
    return not has_repeats
class MissionSpace(spaces.Space[str]):
    r"""A space representing a mission for the Gym-Minigrid environments.

    The space allows generating random mission strings constructed with an input placeholder list.

    Example Usage::

        >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
        ...                                  ordered_placeholders=[["green", "blue"]])
        >>> observation_space.sample()
        "Get the green ball."
        >>> observation_space = MissionSpace(mission_func=lambda: "Get the ball.",
        ...                                  ordered_placeholders=None)
        >>> observation_space.sample()
        "Get the ball."
    """

    def __init__(
        self,
        mission_func: Callable[..., str],
        ordered_placeholders: Optional["list[list[str]]"] = None,
        seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
    ):
        r"""Constructor of :class:`MissionSpace` space.

        Args:
            mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
            ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
            seed: The seed for sampling from the space.

        Raises:
            AssertionError: if the number of placeholder lists differs from the number of
                positional parameters of ``mission_func``, if a placeholder list contains
                duplicates, or if ``mission_func`` does not return a ``str``.
        """
        # Check that the ordered placeholders and mission function are well defined.
        # ``co_argcount`` is the number of positional parameters of mission_func.
        if ordered_placeholders is not None:
            assert (
                len(ordered_placeholders) == mission_func.__code__.co_argcount
            ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
            for placeholder_list in ordered_placeholders:
                assert check_if_no_duplicate(
                    placeholder_list
                ), "Make sure that the placeholders don't have any duplicate values."
        else:
            assert (
                mission_func.__code__.co_argcount == 0
            ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."

        self.ordered_placeholders = ordered_placeholders
        self.mission_func = mission_func

        super().__init__(dtype=str, seed=seed)

        # Check that mission_func returns a string (note: this consumes one RNG draw per list).
        sampled_mission = self.sample()
        assert isinstance(
            sampled_mission, str
        ), f"mission_func must return type str not {type(sampled_mission)}"

    def sample(self) -> str:
        """Sample a random mission string.

        One placeholder is drawn at random (via ``self.np_random.integers``) from each list
        in ``ordered_placeholders``, and the draws are passed to ``mission_func`` in order.
        Without placeholders, ``mission_func()`` is returned.
        """
        if self.ordered_placeholders is not None:
            placeholders = []
            for rand_var_list in self.ordered_placeholders:
                idx = self.np_random.integers(0, len(rand_var_list))
                placeholders.append(rand_var_list[idx])
            return self.mission_func(*placeholders)
        else:
            return self.mission_func()

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Without placeholders, ``x`` must equal ``mission_func()``. Otherwise the
        placeholders occurring in ``x`` are located, and overlapping matches are resolved
        by keeping the longest one. ``x`` is accepted only if the surviving placeholders,
        taken in order of appearance, belong to the corresponding ``ordered_placeholders``
        lists and ``mission_func`` applied to them reproduces ``x`` exactly.
        Non-string inputs are never members.
        """
        if not isinstance(x, str):
            # Only strings can be missions; the substring search below needs a str.
            return False

        if self.ordered_placeholders is None:
            return bool(self.mission_func() == x)

        final_placeholders = self._extract_placeholders(x)

        # Check that the identified final placeholders are in the same order as the original placeholders.
        for ordered_placeholder, final_placeholder in zip(
            self.ordered_placeholders, final_placeholders
        ):
            if final_placeholder not in ordered_placeholder:
                return False

        try:
            mission_string_with_placeholders = self.mission_func(*final_placeholders)
        except Exception as e:
            # A wrong number of extracted placeholders surfaces here (e.g. TypeError).
            print(
                f"{x} is not contained in MissionSpace due to the following exception: {e}"
            )
            return False
        return bool(mission_string_with_placeholders == x)

    def _extract_placeholders(self, x: str) -> "list[str]":
        """Return the placeholders found in ``x``, ordered by position, with overlaps resolved.

        Must only be called when ``self.ordered_placeholders`` is not ``None``.
        """
        # Distinct placeholders (from any list) that occur somewhere in x.
        candidates = {
            placeholder
            for placeholder_list in self.ordered_placeholders
            for placeholder in placeholder_list
            if placeholder in x
        }

        # (start index, inclusive end index, placeholder) for every occurrence,
        # sorted by starting index.
        occurrences = sorted(
            (start, start + len(placeholder) - 1, placeholder)
            for placeholder in candidates
            for start in range(len(x))
            if x.startswith(placeholder, start)
        )

        # When two occurrences overlap, drop the shorter one (e.g. "green" inside
        # "dark green"). Collect ids in a set so each is removed exactly once.
        remove_ids = set()
        for i, (start_1, end_1, placeholder_1) in enumerate(occurrences):
            for j in range(i + 1, len(occurrences)):
                start_2, end_2, placeholder_2 = occurrences[j]
                # Ends are inclusive, so intervals overlap iff max(starts) <= min(ends).
                if max(start_1, start_2) <= min(end_1, end_2):
                    shorter = min(placeholder_1, placeholder_2, key=len)
                    remove_ids.add(i if shorter == placeholder_1 else j)

        # Filtering by original index avoids the index shifting caused by
        # deleting elements one at a time.
        return [
            placeholder
            for k, (_, _, placeholder) in enumerate(occurrences)
            if k not in remove_ids
        ]

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance.

        There are two cases:

        * Neither space has placeholders: they are equal when their missions match.
        * Both spaces have placeholders: they are equal when they have the same number of
          placeholder lists, each pair of lists holds the same set of values, and the
          mission functions agree when every placeholder is replaced by an empty string.

        Any other combination, including a non-``MissionSpace`` ``other``, is unequal.
        """
        if not isinstance(other, MissionSpace):
            return False

        if self.ordered_placeholders is None or other.ordered_placeholders is None:
            # Equal only if *both* are placeholder-free and produce the same mission.
            if not (
                self.ordered_placeholders is None
                and other.ordered_placeholders is None
            ):
                return False
            return self.mission_func() == other.mission_func()

        if len(self.ordered_placeholders) != len(other.ordered_placeholders):
            return False
        if not all(
            set(i) == set(j)
            for i, j in zip(self.ordered_placeholders, other.ordered_placeholders)
        ):
            return False

        # Check mission string is the same with dummy (empty) placeholders.
        test_placeholders = [""] * len(self.ordered_placeholders)
        mission = self.mission_func(*test_placeholders)
        other_mission = other.mission_func(*test_placeholders)
        return mission == other_mission
|