test_baby_ai_bot.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from __future__ import annotations
  2. import gymnasium as gym
  3. import pytest
  4. from minigrid.utils.baby_ai_bot import BabyAIBot
  5. # see discussion starting here: https://github.com/Farama-Foundation/Minigrid/pull/381#issuecomment-1646800992
  6. broken_bonus_envs = {
  7. "BabyAI-PutNextS5N2Carrying-v0",
  8. "BabyAI-PutNextS6N3Carrying-v0",
  9. "BabyAI-PutNextS7N4Carrying-v0",
  10. "BabyAI-KeyInBox-v0",
  11. }
  12. # get all babyai envs (except the broken ones)
  13. babyai_envs = []
  14. for k_i in gym.envs.registry.keys():
  15. if k_i.split("-")[0] == "BabyAI":
  16. if k_i not in broken_bonus_envs:
  17. babyai_envs.append(k_i)
  18. @pytest.mark.parametrize("env_id", babyai_envs)
  19. def test_bot(env_id):
  20. """
  21. The BabyAI Bot should be able to solve all BabyAI environments,
  22. allowing us therefore to generate demonstrations.
  23. """
  24. # Use the parameter env_id to make the environment
  25. env = gym.make(env_id)
  26. # env = gym.make(env_id, render_mode="human") # for visual debugging
  27. # reset env
  28. curr_seed = 0
  29. num_steps = 240
  30. terminated = False
  31. while not terminated:
  32. env.reset(seed=curr_seed)
  33. # create expert bot
  34. expert = BabyAIBot(env)
  35. last_action = None
  36. for _step in range(num_steps):
  37. action = expert.replan(last_action)
  38. obs, reward, terminated, truncated, info = env.step(action)
  39. last_action = action
  40. env.render()
  41. if terminated:
  42. break
  43. # try again with a different seed
  44. curr_seed += 1
  45. env.close()