benchmark.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import time
  4. import gymnasium as gym
  5. from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
  6. def benchmark(env_id, num_resets, num_frames):
  7. env = gym.make(env_id, render_mode="rgb_array")
  8. # Benchmark env.reset
  9. t0 = time.time()
  10. for i in range(num_resets):
  11. env.reset()
  12. t1 = time.time()
  13. dt = t1 - t0
  14. reset_time = (1000 * dt) / num_resets
  15. # Benchmark rendering
  16. t0 = time.time()
  17. for i in range(num_frames):
  18. env.render()
  19. t1 = time.time()
  20. dt = t1 - t0
  21. frames_per_sec = num_frames / dt
  22. # Create an environment with an RGB agent observation
  23. env = gym.make(env_id, render_mode="rgb_array")
  24. env = RGBImgPartialObsWrapper(env)
  25. env = ImgObsWrapper(env)
  26. env.reset()
  27. # Benchmark rendering
  28. t0 = time.time()
  29. for i in range(num_frames):
  30. obs, reward, terminated, truncated, info = env.step(0)
  31. t1 = time.time()
  32. dt = t1 - t0
  33. agent_view_fps = num_frames / dt
  34. print(f"Env reset time: {reset_time:.1f} ms")
  35. print(f"Rendering FPS : {frames_per_sec:.0f}")
  36. print(f"Agent view FPS: {agent_view_fps:.0f}")
  37. env.close()
  38. if __name__ == "__main__":
  39. import argparse
  40. parser = argparse.ArgumentParser()
  41. parser.add_argument(
  42. "--env-id",
  43. dest="env_id",
  44. help="gym environment to load",
  45. default="MiniGrid-LavaGapS7-v0",
  46. )
  47. parser.add_argument("--num_resets", default=200)
  48. parser.add_argument("--num_frames", default=5000)
  49. args = parser.parse_args()
  50. benchmark(args.env_id, args.num_resets, args.num_frames)