浏览代码

Continuation of: Add babyai bot and test (#381)

Co-authored-by: GilgameshD <wenhaod@andrew.cmu.edu>
Giulio Starace 1 年之前
父节点
当前提交
9dbdf61235
共有 4 个文件被更改,包括 1090 次插入和 0 次删除
  1. 3 0
      minigrid/envs/babyai/putnext.py
  2. 3 0
      minigrid/envs/babyai/unlock.py
  3. 1026 0
      minigrid/utils/baby_ai_bot.py
  4. 58 0
      tests/test_baby_ai_bot.py

+ 3 - 0
minigrid/envs/babyai/putnext.py

@@ -139,6 +139,9 @@ class PutNext(RoomGridLevel):
     - `BabyAI-PutNextS6N3Carrying-v0`
     - `BabyAI-PutNextS7N4Carrying-v0`
 
+    ## Additional Notes
+
+    The BabyAI bot is unable to solve the bonus PutNextCarrying configurations.
     """
 
     def __init__(

+ 3 - 0
minigrid/envs/babyai/unlock.py

@@ -220,6 +220,9 @@ class KeyInBox(RoomGridLevel):
 
     - `BabyAI-KeyInBox-v0`
 
+    ## Additional Notes
+
+    The BabyAI bot is unable to solve this level.
     """
 
     def __init__(self, **kwargs):

文件差异内容过多而无法显示
+ 1026 - 0
minigrid/utils/baby_ai_bot.py


+ 58 - 0
tests/test_baby_ai_bot.py

@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import gymnasium as gym
+import pytest
+
+from minigrid.utils.baby_ai_bot import BabyAIBot
+
+# Environments the BabyAI bot is known to be unable to solve; they are
+# excluded from the parametrized test below. See the discussion starting here:
+# https://github.com/Farama-Foundation/Minigrid/pull/381#issuecomment-1646800992
+broken_bonus_envs = {
+    "BabyAI-PutNextS5N2Carrying-v0",
+    "BabyAI-PutNextS6N3Carrying-v0",
+    "BabyAI-PutNextS7N4Carrying-v0",
+    "BabyAI-KeyInBox-v0",
+}
+
+# Every registered env id whose prefix (the text before the first "-") is
+# "BabyAI", minus the known-broken ones, kept in registry order.
+babyai_envs = [
+    env_name
+    for env_name in gym.envs.registry.keys()
+    if env_name.split("-")[0] == "BabyAI" and env_name not in broken_bonus_envs
+]
+
+
+@pytest.mark.parametrize("env_id", babyai_envs)
+def test_bot(env_id):
+    """
+    The BabyAI Bot should be able to solve all BabyAI environments,
+    allowing us therefore to generate demonstrations.
+
+    One episode is attempted per seed (0, 1, 2, ...), each capped at
+    ``num_steps`` bot actions. The test passes as soon as one episode
+    reports ``terminated`` and fails if none of the first ``max_seeds``
+    seeds does, rather than retrying forever.
+    """
+    # Use the parameter env_id to make the environment
+    env = gym.make(env_id)
+    # env = gym.make(env_id, render_mode="human") # for visual debugging
+
+    num_steps = 240
+    # Upper bound on reseeded attempts so an unsolvable environment fails
+    # the test instead of hanging the whole test run.
+    max_seeds = 100
+    terminated = False
+
+    try:
+        for curr_seed in range(max_seeds):
+            env.reset(seed=curr_seed)
+
+            # create expert bot (a fresh one for every freshly reset episode)
+            expert = BabyAIBot(env)
+
+            last_action = None
+            for _step in range(num_steps):
+                # the bot is told which action was actually taken last step
+                action = expert.replan(last_action)
+                _obs, _reward, terminated, _truncated, _info = env.step(action)
+                last_action = action
+                env.render()
+
+                if terminated:
+                    break
+
+            if terminated:
+                break
+            # otherwise try again with a different seed
+    finally:
+        # release the environment even if the bot or the env raised
+        env.close()
+
+    assert terminated, (
+        f"BabyAIBot did not finish {env_id} within {num_steps} steps "
+        f"for any of the first {max_seeds} seeds"
+    )