# gen_mds.py
  1. __author__ = "Feng Gu"
  2. __email__ = "contact@fenggu.me"
  3. """
  4. isort:skip_file
  5. """
  6. import os
  7. import re
  8. from gymnasium.envs.registration import registry
  9. from tqdm import tqdm
  10. from utils import trim
  11. from itertools import chain
  12. readme_path = os.path.join(
  13. os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
  14. "README.md",
  15. )
  16. LAYOUT = "env"
  17. pattern = re.compile(r"(?<!^)(?=[A-Z])")
  18. all_envs = list(registry.values())
  19. filtered_envs_by_type = {}
  20. env_names = []
  21. babyai_envs = {}
  22. # Obtain filtered list
  23. for env_spec in tqdm(all_envs):
  24. # minigrid.envs:Env
  25. split = env_spec.entry_point.split(".")
  26. # ignore gymnasium.envs.env_type:Env
  27. env_module = split[0]
  28. if len(split) > 2 and "babyai" in split[2]:
  29. curr_babyai_env = split[2]
  30. babyai_env_name = curr_babyai_env.split(":")[1]
  31. babyai_envs[babyai_env_name] = env_spec
  32. elif env_module == "minigrid":
  33. env_name = split[1]
  34. filtered_envs_by_type[env_name] = env_spec
  35. # if env_module != "minigrid":
  36. else:
  37. continue
  38. filtered_envs = {
  39. env[0]: env[1]
  40. for env in sorted(
  41. filtered_envs_by_type.items(),
  42. key=lambda item: item[1].entry_point.split(".")[1],
  43. )
  44. }
  45. filtered_babyai_envs = {
  46. env[0]: env[1]
  47. for env in sorted(
  48. babyai_envs.items(),
  49. key=lambda item: item[1].entry_point.split(".")[1],
  50. )
  51. }
  52. for env_name, env_spec in chain(filtered_envs.items(), filtered_babyai_envs.items()):
  53. made = env_spec.make()
  54. docstring = trim(made.unwrapped.__doc__)
  55. pascal_env_name = env_spec.id.split("-")[1]
  56. # remove suffix
  57. p = re.compile(r"([A-Z][a-z]+)*")
  58. name = p.search(pascal_env_name).group()
  59. snake_env_name = pattern.sub("_", name).lower()
  60. env_names.append(snake_env_name)
  61. title_env_name = snake_env_name.replace("_", " ").title()
  62. v_path = os.path.join(
  63. os.path.dirname(os.path.dirname(__file__)),
  64. "environments",
  65. snake_env_name + ".md",
  66. )
  67. front_matter = f"""---
  68. autogenerated:
  69. title: {title_env_name}
  70. ---
  71. """
  72. title = f"# {title_env_name}"
  73. if docstring is None:
  74. docstring = "No information provided"
  75. all_text = f"""{front_matter}
  76. {title}
  77. {docstring}
  78. """
  79. file = open(v_path, "w+", encoding="utf-8")
  80. file.write(all_text)
  81. file.close()