benchmark.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. #!/usr/bin/env python3
  2. import time
  3. import gym
  4. from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
  5. def benchmark(env_id, num_resets, num_frames):
  6. env = gym.make(env_id, render_mode="rgb_array")
  7. # Benchmark env.reset
  8. t0 = time.time()
  9. for i in range(num_resets):
  10. env.reset()
  11. t1 = time.time()
  12. dt = t1 - t0
  13. reset_time = (1000 * dt) / num_resets
  14. # Benchmark rendering
  15. t0 = time.time()
  16. for i in range(num_frames):
  17. env.render()
  18. t1 = time.time()
  19. dt = t1 - t0
  20. frames_per_sec = num_frames / dt
  21. # Create an environment with an RGB agent observation
  22. env = gym.make(env_id, render_mode="rgb_array")
  23. env = RGBImgPartialObsWrapper(env)
  24. env = ImgObsWrapper(env)
  25. env.reset()
  26. # Benchmark rendering
  27. t0 = time.time()
  28. for i in range(num_frames):
  29. obs, reward, terminated, truncated, info = env.step(0)
  30. t1 = time.time()
  31. dt = t1 - t0
  32. agent_view_fps = num_frames / dt
  33. print(f"Env reset time: {reset_time:.1f} ms")
  34. print(f"Rendering FPS : {frames_per_sec:.0f}")
  35. print(f"Agent view FPS: {agent_view_fps:.0f}")
  36. env.close()
  37. if __name__ == "__main__":
  38. import argparse
  39. parser = argparse.ArgumentParser()
  40. parser.add_argument(
  41. "--env-id",
  42. dest="env_id",
  43. help="gym environment to load",
  44. default="MiniGrid-LavaGapS7-v0",
  45. )
  46. parser.add_argument("--num_resets", default=200)
  47. parser.add_argument("--num_frames", default=5000)
  48. args = parser.parse_args()
  49. benchmark(args.env_id, args.num_resets, args.num_frames)