瀏覽代碼

add mission space pytest

Rodrigo Perez-Vicente 2 年之前
父節點
當前提交
d74d116572
共有 2 個文件被更改,包括 57 次插入1 次删除
  1. 9 0
      gym_minigrid/minigrid.py
  2. 48 1
      tests/test_envs.py

+ 9 - 0
gym_minigrid/minigrid.py

@@ -161,6 +161,7 @@ class MissionSpace(spaces.Space[str]):
                 for placeholder in placeholder_list:
                     if placeholder in x:
                         check_placeholder_list.append(placeholder)
+
             # Remove duplicates from the list
             check_placeholder_list = list(set(check_placeholder_list))
 
@@ -213,6 +214,14 @@ class MissionSpace(spaces.Space[str]):
                 placeholder[2] for placeholder in ordered_placeholder_list
             ]
 
+            # Check that the identified final placeholders are in the same order as the original placeholders.
+            for orered_placeholder, final_placeholder in zip(
+                self.ordered_placeholders, final_placeholders
+            ):
+                if final_placeholder in orered_placeholder:
+                    continue
+                else:
+                    return False
             try:
                 mission_string_with_placeholders = self.mission_func(
                     *final_placeholders

+ 48 - 1
tests/test_envs.py

@@ -4,7 +4,7 @@ import pytest
 from gym.envs.registration import EnvSpec
 from gym.utils.env_checker import check_env
 
-from gym_minigrid.minigrid import Grid
+from gym_minigrid.minigrid import Grid, MissionSpace
 from tests.utils import all_testing_env_specs, assert_equals
 
 CHECK_ENV_IGNORE_WARNINGS = [
@@ -205,3 +205,50 @@ def test_interactive_mode(env_id):
 
     # Test the close method
     env.close()
+
+
+def test_mission_space():
+
+    # Test placeholders
+    mission_space = MissionSpace(
+        mission_func=lambda color, obj_type: f"Get the {color} {obj_type}.",
+        ordered_placeholders=[["green", "red"], ["ball", "key"]],
+    )
+
+    assert mission_space.contains("Get the green ball.")
+    assert mission_space.contains("Get the red key.")
+    assert not mission_space.contains("Get the purple box.")
+
+    # Test passing inverted placeholders
+    assert not mission_space.contains("Get the key red.")
+
+    # Test passing extra repeated placeholders
+    assert not mission_space.contains("Get the key red key.")
+
+    # Test contained placeholders like "get the" and "go get the". "get the" string is contained in both placeholders.
+    mission_space = MissionSpace(
+        mission_func=lambda get_syntax, obj_type: f"{get_syntax} {obj_type}.",
+        ordered_placeholders=[
+            ["go get the", "get the", "go fetch the", "fetch the"],
+            ["ball", "key"],
+        ],
+    )
+
+    assert mission_space.contains("get the ball.")
+    assert mission_space.contains("go get the key.")
+    assert mission_space.contains("go fetch the ball.")
+
+    # Test repeated placeholders
+    mission_space = MissionSpace(
+        mission_func=lambda get_syntax, color_1, obj_type_1, color_2, obj_type_2: f"{get_syntax} {color_1} {obj_type_1} and the {color_2} {obj_type_2}.",
+        ordered_placeholders=[
+            ["go get the", "get the", "go fetch the", "fetch the"],
+            ["green", "red"],
+            ["ball", "key"],
+            ["green", "red"],
+            ["ball", "key"],
+        ],
+    )
+
+    assert mission_space.contains("get the green key and the green key.")
+    assert mission_space.contains("go fetch the red ball and the green key.")