gen_env_docs.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. from __future__ import annotations
  2. import os
  3. import re
  4. from itertools import chain
  5. import gymnasium
  6. from tqdm import tqdm
  7. import minigrid
  8. from utils import env_name_format, trim
  9. gymnasium.register_envs(minigrid)
  10. readme_path = os.path.join(
  11. os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
  12. "README.md",
  13. )
  14. all_envs = list(gymnasium.registry.values())
  15. filtered_envs_by_type = {}
  16. env_names = []
  17. babyai_envs = {}
  18. wfc_envs = {}
  19. # Obtain filtered list
  20. for env_spec in tqdm(all_envs):
  21. if isinstance(env_spec.entry_point, str):
  22. # minigrid.envs:Env
  23. split = env_spec.entry_point.split(".")
  24. # ignore gymnasium.envs.env_type:Env
  25. env_module = split[0]
  26. if len(split) > 2 and "babyai" in split[2]:
  27. curr_babyai_env = split[2]
  28. babyai_env_name = curr_babyai_env.split(":")[1]
  29. babyai_envs[babyai_env_name] = env_spec
  30. elif len(split) > 2 and "wfc" in split[2]:
  31. curr_wfc_env = env_spec.kwargs["wfc_config"]
  32. wfc_envs[curr_wfc_env] = env_spec
  33. elif env_module == "minigrid":
  34. env_name = split[1]
  35. filtered_envs_by_type[env_name] = env_spec
  36. # if env_module != "minigrid":
  37. else:
  38. continue
  39. filtered_envs = {
  40. env[0]: env[1]
  41. for env in sorted(
  42. filtered_envs_by_type.items(),
  43. key=lambda item: item[1].entry_point.split(".")[1],
  44. )
  45. }
  46. filtered_babyai_envs = {
  47. env[0]: env[1]
  48. for env in sorted(
  49. babyai_envs.items(),
  50. key=lambda item: item[1].entry_point.split(".")[1],
  51. )
  52. }
  53. # Because they share a class, only the default (MazeSimple) environment should be kept
  54. canonical_wfc_env_name = "MazeSimple"
  55. filtered_wfc_envs = {canonical_wfc_env_name: wfc_envs[canonical_wfc_env_name]}
  56. for env_name, env_spec in chain(
  57. filtered_envs.items(), filtered_babyai_envs.items(), filtered_wfc_envs.items()
  58. ):
  59. env = env_spec.make()
  60. docstring = trim(env.unwrapped.__doc__)
  61. # minigrid.envs:Env or minigrid.envs.babyai:Env
  62. split = env_spec.entry_point.split(".")
  63. # ignore minigrid.envs.env_type:Env
  64. env_module = split[0]
  65. env_name = split[-1].split(":")[-1]
  66. env_type = env_module if len(split) == 2 else split[-1].split(":")[0]
  67. path_name = ""
  68. os.makedirs(
  69. os.path.join(
  70. os.path.dirname(os.path.dirname(__file__)), "environments", env_type
  71. ),
  72. exist_ok=True,
  73. )
  74. v_path = os.path.join(
  75. os.path.dirname(os.path.dirname(__file__)),
  76. "environments",
  77. env_type,
  78. f"{env_name}.md",
  79. )
  80. formatted_env_name = env_name_format(env_name)
  81. # Front matter
  82. front_matter = f"""---
  83. autogenerated:
  84. title: {formatted_env_name}
  85. ---
  86. """
  87. # Title and gif
  88. title = f"# {formatted_env_name}"
  89. gif = (
  90. "```{figure} "
  91. + f"""/_static/videos/{env_type}/{env_name}.gif
  92. :alt: {formatted_env_name}
  93. :width: 200px
  94. ```
  95. """
  96. )
  97. # Environment attributes
  98. action_space_table = env.action_space.__repr__().replace("\n", "")
  99. observation_space_table = env.observation_space.__repr__().replace("\n", "")
  100. env_attributes = f"""
  101. | | |
  102. |---|---|
  103. | Action Space | `{re.sub(' +', ' ', action_space_table)}` |
  104. | Observation Space | `{re.sub(' +', ' ', observation_space_table)}` |
  105. | Creation | `gymnasium.make("{env_spec.id}")` |
  106. """
  107. # Create Markdown file content
  108. if docstring is None:
  109. docstring = "No information provided"
  110. all_text = f"""{front_matter}
  111. {title}
  112. {gif}
  113. {env_attributes}
  114. {docstring}
  115. """
  116. file = open(v_path, "w+", encoding="utf-8")
  117. file.write(all_text)
  118. file.close()