123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- #!/usr/bin/env python3
- from __future__ import annotations
- import time
- import gymnasium as gym
- from minigrid.manual_control import ManualControl
- from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
def benchmark(env_id: str, num_resets: int, num_frames: int) -> None:
    """Print reset latency and rendering throughput for a Gymnasium env.

    Three timings are taken and printed once all of them are done:

    * mean wall-clock time of ``env.reset()`` over *num_resets* calls (ms);
    * calls per second of ``env.render()`` with ``render_mode="rgb_array"``;
    * calls per second of ``env.step(0)`` on a second env wrapped with
      ``RGBImgPartialObsWrapper`` and ``ImgObsWrapper``, i.e. an env whose
      observation is an RGB image of the agent's view.

    Args:
        env_id: Registered Gymnasium environment id
            (e.g. ``"MiniGrid-LavaGapS7-v0"``).
        num_resets: Number of ``reset()`` calls to time; must be >= 1.
        num_frames: Number of calls to time for each of the two rendering
            measurements; must be >= 1.

    Raises:
        ValueError: If *num_resets* or *num_frames* is less than 1.
    """
    # Validate up front: a zero count would otherwise surface as a
    # ZeroDivisionError only after an environment had been created.
    if num_resets < 1:
        raise ValueError(f"num_resets must be >= 1, got {num_resets}")
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # adjustable system clock and is the wrong clock for benchmarking.
    clock = time.perf_counter

    env = gym.make(env_id, render_mode="rgb_array")
    try:
        # Benchmark env.reset
        t0 = clock()
        for _ in range(num_resets):
            env.reset()
        reset_time = 1000 * (clock() - t0) / num_resets

        # Benchmark rendering (the env is already reset by the loop above)
        t0 = clock()
        for _ in range(num_frames):
            env.render()
        frames_per_sec = num_frames / (clock() - t0)
    finally:
        # Close this env before building the agent-view one so it is not leaked.
        env.close()

    # Create an environment with an RGB agent observation. The raw env is
    # bound first so it is still closed if wrapping it fails.
    agent_env = gym.make(env_id, render_mode="rgb_array")
    try:
        agent_env = ImgObsWrapper(RGBImgPartialObsWrapper(agent_env))
        agent_env.reset()

        # Benchmark rendering in agent view
        t0 = clock()
        for _ in range(num_frames):
            _, _, terminated, truncated, _ = agent_env.step(0)
            if terminated or truncated:
                # Gymnasium requires reset() once an episode has ended;
                # stepping a finished episode is undefined behavior.
                agent_env.reset()
        agent_view_fps = num_frames / (clock() - t0)
    finally:
        agent_env.close()

    print(f"Env reset time: {reset_time:.1f} ms")
    print(f"Rendering FPS : {frames_per_sec:.0f}")
    print(f"Agent view FPS: {agent_view_fps:.0f}")
def benchmark_manual_control(
    env_id: str,
    num_resets: int,
    num_frames: int,
    tile_size: int,
    *,
    seed: int | None = None,
) -> None:
    """Print the same three timings as ``benchmark()``, driven through ManualControl.

    The environment is built with ``tile_size=tile_size`` and wrapped in
    minigrid's ``ManualControl``. The timed calls are ``ManualControl.reset()``
    and ``ManualControl.redraw()`` on the plain env, then
    ``ManualControl.step(0)`` on an env wrapped with
    ``RGBImgPartialObsWrapper`` and ``ImgObsWrapper``.

    Args:
        env_id: Registered Gymnasium environment id.
        num_resets: Number of ``reset()`` calls to time; must be >= 1.
        num_frames: Number of calls to time for each rendering measurement;
            must be >= 1.
        tile_size: Tile size forwarded to ``gym.make`` and to
            ``RGBImgPartialObsWrapper``.
        seed: Seed passed to the ``ManualControl`` constructor. Taken as a
            parameter rather than read from the script-level ``args`` global,
            which only exists when this file runs as ``__main__``.

    Raises:
        ValueError: If *num_resets* or *num_frames* is less than 1.
    """
    if num_resets < 1:
        raise ValueError(f"num_resets must be >= 1, got {num_resets}")
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    # Monotonic, high-resolution clock suited to benchmarking.
    clock = time.perf_counter

    # Keep a handle on the Gymnasium env itself and close that directly,
    # rather than relying on the ManualControl object wrapped around it.
    env = gym.make(env_id, tile_size=tile_size)
    try:
        control = ManualControl(env, seed=seed)

        # Benchmark env.reset
        t0 = clock()
        for _ in range(num_resets):
            control.reset()
        reset_time = 1000 * (clock() - t0) / num_resets

        # Benchmark rendering
        # NOTE(review): redraw() is not present on ManualControl in every
        # minigrid release — confirm against the installed version.
        t0 = clock()
        for _ in range(num_frames):
            control.redraw()
        frames_per_sec = num_frames / (clock() - t0)
    finally:
        env.close()

    # Create an environment with an RGB agent observation. The requested
    # tile_size is passed straight through instead of reading env.tile_size
    # back through the wrapper stack (attribute forwarding through Gymnasium
    # wrappers is deprecated and was removed in Gymnasium 1.0).
    agent_env = gym.make(env_id, tile_size=tile_size)
    try:
        agent_env = ImgObsWrapper(RGBImgPartialObsWrapper(agent_env, tile_size))
        control = ManualControl(agent_env, seed=seed)
        control.reset()

        # Benchmark rendering in agent view
        t0 = clock()
        for _ in range(num_frames):
            control.step(0)
        agent_view_fps = num_frames / (clock() - t0)
    finally:
        agent_env.close()

    print(f"Env reset time: {reset_time:.1f} ms")
    print(f"Rendering FPS : {frames_per_sec:.0f}")
    print(f"Agent view FPS: {agent_view_fps:.0f}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env-id",
        dest="env_id",
        help="gym environment to load",
        default="MiniGrid-LavaGapS7-v0",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="random seed to generate the environment with",
        default=None,
    )
    parser.add_argument(
        "--num-resets",
        type=int,
        help="number of times to reset the environment for benchmarking",
        default=200,
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        help="number of frames to test rendering for",
        default=5000,
    )
    parser.add_argument(
        "--tile-size", type=int, help="size at which to render tiles", default=32
    )
    args = parser.parse_args()

    # Report bad counts as a usage error instead of a traceback.
    if args.num_resets < 1:
        parser.error("--num-resets must be >= 1")
    if args.num_frames < 1:
        parser.error("--num-frames must be >= 1")

    # Run the plain Gymnasium benchmark, then the ManualControl-driven one.
    benchmark(args.env_id, args.num_resets, args.num_frames)
    benchmark_manual_control(
        args.env_id,
        args.num_resets,
        args.num_frames,
        args.tile_size,
        seed=args.seed,
    )
|