#!/usr/bin/env python3
"""Benchmark MiniGrid env.reset, full-grid rendering and agent-view stepping."""
import argparse
import time


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``env_name``, ``num_resets`` and ``num_frames``.

    Exits via ``parser.error`` (SystemExit, status 2) when a count is not positive,
    because every count is later used as a divisor.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--env-name",
        dest="env_name",
        help="gym environment to load",
        default="MiniGrid-LavaGapS7-v0",
    )
    # type=int is required: without it a value supplied on the command line
    # arrives as a str and range() raises TypeError.
    parser.add_argument(
        "--num_resets", type=int, default=200, help="number of env.reset() calls to time"
    )
    parser.add_argument(
        "--num_frames", type=int, default=5000, help="number of frames/steps to time"
    )
    args = parser.parse_args(argv)
    if args.num_resets <= 0 or args.num_frames <= 0:
        parser.error("--num_resets and --num_frames must be positive integers")
    return args


def time_calls(func, count):
    """Call ``func()`` *count* times and return the elapsed seconds.

    Uses time.perf_counter(), the monotonic high-resolution clock intended for
    measuring short intervals (time.time() can jump and is coarser).
    """
    start = time.perf_counter()
    for _ in range(count):
        func()
    return time.perf_counter() - start


def per_second(count, seconds):
    """Return ``count / seconds``, or infinity if the clock measured no time."""
    return count / seconds if seconds > 0 else float("inf")


def main(argv=None):
    """Run the reset, render and agent-view benchmarks and print a summary."""
    args = parse_args(argv)

    # Imported lazily so argument parsing (e.g. --help) works without gym.
    # Importing gym_minigrid also registers the MiniGrid envs with gym.
    import gym
    from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper

    env = gym.make(args.env_name)
    try:
        # Benchmark env.reset
        reset_secs = time_calls(env.reset, args.num_resets)
        # Benchmark rendering (env has been reset by the loop above)
        render_secs = time_calls(lambda: env.render("rgb_array"), args.num_frames)
    finally:
        env.close()

    # Create an environment with an RGB agent observation
    agent_env = ImgObsWrapper(RGBImgPartialObsWrapper(gym.make(args.env_name)))
    try:
        agent_env.reset()

        def step_once():
            # 4-tuple unpack as in the original (pre-0.26 gym step API).
            _obs, _reward, done, _info = agent_env.step(0)
            # Stepping a finished episode is undefined in gym; start a new one.
            if done:
                agent_env.reset()

        # Benchmark stepping with the RGB agent-view observation
        step_secs = time_calls(step_once, args.num_frames)
    finally:
        agent_env.close()

    reset_time = (1000 * reset_secs) / args.num_resets
    frames_per_sec = per_second(args.num_frames, render_secs)
    agent_view_fps = per_second(args.num_frames, step_secs)

    print(f"Env reset time: {reset_time:.1f} ms")
    print(f"Rendering FPS : {frames_per_sec:.0f}")
    print(f"Agent view FPS: {agent_view_fps:.0f}")


if __name__ == "__main__":
    main()