gen_gifs.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
from __future__ import annotations

import os
import re
from contextlib import closing

import gymnasium
import minigrid
from PIL import Image
from tqdm import tqdm
  8. gymnasium.register_envs(minigrid)
  9. # snake to camel case: https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case # noqa: E501
  10. pattern = re.compile(r"(?<!^)(?=[A-Z])")
  11. # how many steps to record an env for
  12. LENGTH = 300
  13. output_dir = os.path.join(os.path.dirname(__file__), "..", "_static", "videos")
  14. os.makedirs(output_dir, exist_ok=True)
  15. # Some environments have multiple versions
  16. # For example, KeyCorridorEnv -> KeyCorridorS3R1, KeyCorridorS3R2, KeyCorridorS3R3, etc
  17. # We only want one as an example
  18. envs_completed = []
  19. # iterate through all envspecs
  20. for env_spec in tqdm(gymnasium.registry.values()):
  21. # minigrid.envs:Env or minigrid.envs.babyai:Env
  22. if not isinstance(env_spec.entry_point, str):
  23. continue
  24. split = env_spec.entry_point.split(".")
  25. # ignore minigrid.envs.env_type:Env
  26. env_module = split[0]
  27. env_name = split[-1].split(":")[-1]
  28. env_type = env_module if len(split) == 2 else split[-1].split(":")[0]
  29. # Override env_name for WFC to include the preset name
  30. if env_name == "WFCEnv":
  31. env_name = env_spec.kwargs["wfc_config"]
  32. if env_module == "minigrid" and env_name not in envs_completed:
  33. os.makedirs(os.path.join(output_dir, env_type), exist_ok=True)
  34. path = os.path.join(output_dir, env_type, env_name + ".gif")
  35. envs_completed.append(env_name)
  36. # try catch in case missing some installs
  37. try:
  38. env = gymnasium.make(env_spec.id, render_mode="rgb_array")
  39. env.reset(seed=123)
  40. env.action_space.seed(123)
  41. # the gymnasium needs to be rgb renderable
  42. if not ("rgb_array" in env.metadata["render_modes"]):
  43. continue
  44. # obtain and save LENGTH frames worth of steps
  45. frames = []
  46. t = 0
  47. while True:
  48. state, info = env.reset()
  49. terminated, truncated = False, False
  50. while not (terminated or truncated) and len(frames) <= LENGTH:
  51. frame = env.render()
  52. frames.append(Image.fromarray(frame))
  53. action = env.action_space.sample()
  54. # Avoid to much movement
  55. if t % 10 == 0:
  56. state_next, reward, terminated, truncated, info = env.step(
  57. action
  58. )
  59. t += 1
  60. if len(frames) > LENGTH:
  61. break
  62. env.close()
  63. frames[0].save(
  64. path,
  65. save_all=True,
  66. append_images=frames[1:],
  67. duration=50,
  68. loop=0,
  69. )
  70. print("Saved: " + env_name)
  71. except BaseException as e:
  72. print("ERROR", e)
  73. continue