gen_gifs.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import os
  2. import re
  3. import gymnasium
  4. from PIL import Image
  5. from tqdm import tqdm
  6. # snake to camel case: https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case # noqa: E501
  7. pattern = re.compile(r"(?<!^)(?=[A-Z])")
  8. # how many steps to record an env for
  9. LENGTH = 300
  10. output_dir = os.path.join(os.path.dirname(__file__), "..", "_static", "videos")
  11. os.makedirs(output_dir, exist_ok=True)
  12. # Some environments have multiple versions
  13. # For example, KeyCorridorEnv -> KeyCorridorS3R1, KeyCorridorS3R2, KeyCorridorS3R3, etc
  14. # We only want one as an example
  15. envs_completed = []
  16. # iterate through all envspecs
  17. for env_spec in tqdm(gymnasium.envs.registry.values()):
  18. # minigrid.envs:Env or minigrid.envs.babyai:Env
  19. split = env_spec.entry_point.split(".")
  20. # ignore minigrid.envs.env_type:Env
  21. env_module = split[0]
  22. env_name = split[-1].split(":")[-1]
  23. env_type = env_module if len(split) == 2 else split[-1].split(":")[0]
  24. if env_module == "minigrid" and env_name not in envs_completed:
  25. os.makedirs(os.path.join(output_dir, env_type), exist_ok=True)
  26. path = os.path.join(output_dir, env_type, env_name + ".gif")
  27. envs_completed.append(env_name)
  28. # try catch in case missing some installs
  29. try:
  30. env = gymnasium.make(env_spec.id, render_mode="rgb_array")
  31. # the gymnasium needs to be rgb renderable
  32. if not ("rgb_array" in env.metadata["render_modes"]):
  33. continue
  34. # obtain and save LENGTH frames worth of steps
  35. frames = []
  36. t = 0
  37. while True:
  38. state, info = env.reset()
  39. terminated, truncated = False, False
  40. while not (terminated or truncated) and len(frames) <= LENGTH:
  41. frame = env.render()
  42. frames.append(Image.fromarray(frame))
  43. action = env.action_space.sample()
  44. # Avoid to much movement
  45. if t % 10 == 0:
  46. state_next, reward, terminated, truncated, info = env.step(
  47. action
  48. )
  49. t += 1
  50. if len(frames) > LENGTH:
  51. break
  52. env.close()
  53. frames[0].save(
  54. path,
  55. save_all=True,
  56. append_images=frames[1:],
  57. duration=50,
  58. loop=0,
  59. )
  60. print("Saved: " + env_name)
  61. except BaseException as e:
  62. print("ERROR", e)
  63. continue