autogenerate_notebooks_table.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from __future__ import annotations
  2. import argparse
  3. from dataclasses import dataclass
  4. from typing import Optional, List
  5. from enum import Enum
  6. NOTEBOOKS_ROOT_PATH = "https://github.com/roboflow-ai/notebooks/blob/main/notebooks"
  7. NOTEBOOKS_COLAB_ROOT_PATH = "github/roboflow-ai/notebooks/blob/main/notebooks"
  8. WARNING_HEADER = [
  9. "<!---",
  10. " WARNING: DO NOT EDIT THIS TABLE MANUALLY. IT IS AUTOMATICALLY GENERATED.",
  11. " HEAD OVER TO CONTRIBUTING.MD FOR MORE DETAILS ON HOW TO MAKE CHANGES PROPERLY.",
  12. "-->"
  13. ]
  14. TABLE_HEADER = [
  15. "| **notebook** | **open in colab / kaggle / sagemaker studio lab** | **complementary materials** | **repository / paper** |",
  16. "|:------------:|:-------------------------------------------------:|:---------------------------:|:----------------------:|"
  17. ]
  18. MODELS_SECTION_HEADER = "## 🚀 model tutorials ({} notebooks)"
  19. SKILLS_SECTION_HEADER = "## 📸 computer vision skills ({} notebooks)"
  20. NOTEBOOK_LINK_PATTERN = "[{}]({}/{})"
  21. OPEN_IN_COLAB_BADGE_PATTERN = "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/{}/{})"
  22. OPEN_IN_KAGGLE_BADGE_PATTERN = "[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src={}/{})"
  23. OPEN_IN_SAGEMAKER_LAB_PATTERN = "[![SageMaker](https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/sage-maker.svg)](https://studiolab.sagemaker.aws/import/github/roboflow-ai/notebooks/blob/main/notebooks/{})"
  24. ROBOFLOW_BADGE_PATTERN = "[![Roboflow](https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg)]({})"
  25. YOUTUBE_BADGE_PATTERN = "[![YouTube](https://badges.aleen42.com/src/youtube.svg)]({})"
  26. GITHUB_BADGE_PATTERN = "[![GitHub](https://badges.aleen42.com/src/github.svg)]({})"
  27. ARXIV_BADGE_PATTERN = "[![arXiv](https://img.shields.io/badge/arXiv-{}-b31b1b.svg)](https://arxiv.org/abs/{})"
  28. AUTOGENERATED_NOTEBOOKS_TABLE_TOKEN = "<!--- AUTOGENERATED-NOTEBOOKS-TABLE -->"
  29. class READMESection(Enum):
  30. MODELS = "models"
  31. SKILLS = "skills"
  32. @classmethod
  33. def from_value(cls, value: str) -> READMESection:
  34. try:
  35. return cls(value=value.lower())
  36. except (AttributeError, ValueError):
  37. raise Exception(f"{cls.__name__} must be one of {READMESection.list()}, {value} given.")
  38. @staticmethod
  39. def list():
  40. return list(map(lambda entry: entry.value, READMESection))
  41. @dataclass(frozen=True)
  42. class TableEntry:
  43. display_name: str
  44. notebook_name: str
  45. roboflow_blogpost_path: Optional[str]
  46. youtube_video_path: Optional[str]
  47. github_repository_path: Optional[str]
  48. arxiv_index: Optional[str]
  49. should_open_in_sagemaker_labs: bool
  50. readme_section: READMESection
  51. @classmethod
  52. def from_csv_line(cls, csv_line: str) -> TableEntry:
  53. csv_fields = [
  54. field.strip()
  55. for field
  56. in csv_line.split(",")
  57. ]
  58. if len(csv_fields) != 8:
  59. raise Exception("Every csv line must contain 8 fields")
  60. return TableEntry(
  61. display_name=csv_fields[0],
  62. notebook_name=csv_fields[1],
  63. roboflow_blogpost_path=csv_fields[2],
  64. youtube_video_path=csv_fields[3],
  65. github_repository_path=csv_fields[4],
  66. arxiv_index=csv_fields[5],
  67. should_open_in_sagemaker_labs=csv_fields[6] == "True",
  68. readme_section=READMESection(csv_fields[7])
  69. )
  70. def format(self) -> str:
  71. notebook_link = NOTEBOOK_LINK_PATTERN.format(self.display_name, NOTEBOOKS_ROOT_PATH, self.notebook_name)
  72. open_in_colab_badge = OPEN_IN_COLAB_BADGE_PATTERN.format(NOTEBOOKS_COLAB_ROOT_PATH, self.notebook_name)
  73. open_in_kaggle_badge = OPEN_IN_KAGGLE_BADGE_PATTERN.format(NOTEBOOKS_ROOT_PATH, self.notebook_name)
  74. open_in_sagemaker_lab_badge = OPEN_IN_SAGEMAKER_LAB_PATTERN.format(self.notebook_name) if self.should_open_in_sagemaker_labs else ""
  75. roboflow_badge = ROBOFLOW_BADGE_PATTERN.format(self.roboflow_blogpost_path) if self.roboflow_blogpost_path else ""
  76. youtube_badge = YOUTUBE_BADGE_PATTERN.format(self.youtube_video_path) if self.youtube_video_path else ""
  77. github_badge = GITHUB_BADGE_PATTERN.format(self.github_repository_path) if self.github_repository_path else ""
  78. arxiv_badge = ARXIV_BADGE_PATTERN.format(self.arxiv_index, self.arxiv_index) if self.arxiv_index else ""
  79. return f"| {notebook_link} | {open_in_colab_badge} {open_in_kaggle_badge} {open_in_sagemaker_lab_badge} | {roboflow_badge} {youtube_badge} | {github_badge} {arxiv_badge}|"
  80. def read_lines_from_file(path: str) -> List[str]:
  81. with open(path) as file:
  82. return [line.rstrip() for line in file]
  83. def save_lines_to_file(path: str, lines: List[str]) -> None:
  84. with open(path, "w") as f:
  85. for line in lines:
  86. f.write("%s\n" % line)
  87. def parse_csv_lines(csv_lines: List[str]) -> List[TableEntry]:
  88. return [
  89. TableEntry.from_csv_line(csv_line=csv_line)
  90. for csv_line
  91. in csv_lines
  92. ]
  93. def search_lines_with_token(lines: List[str], token: str) -> List[int]:
  94. result = []
  95. for line_index, line in enumerate(lines):
  96. if token in line:
  97. result.append(line_index)
  98. return result
  99. def inject_markdown_table_into_readme(readme_lines: List[str], table_lines: List[str]) -> List[str]:
  100. lines_with_token_indexes = search_lines_with_token(lines=readme_lines, token=AUTOGENERATED_NOTEBOOKS_TABLE_TOKEN)
  101. if len(lines_with_token_indexes) != 2:
  102. raise Exception(f"Please inject two {AUTOGENERATED_NOTEBOOKS_TABLE_TOKEN} "
  103. f"tokens to signal start and end of autogenerated table.")
  104. [table_start_line_index, table_end_line_index] = lines_with_token_indexes
  105. return readme_lines[:table_start_line_index + 1] + table_lines + readme_lines[table_end_line_index:]
  106. if __name__ == "__main__":
  107. parser = argparse.ArgumentParser()
  108. parser.add_argument('-d', '--data_path', default='automation/notebooks-table-data.csv')
  109. parser.add_argument('-r', '--readme_path', default='README.md')
  110. args = parser.parse_args()
  111. csv_lines = read_lines_from_file(path=args.data_path)[1:]
  112. readme_lines = read_lines_from_file(path=args.readme_path)
  113. table_entries = parse_csv_lines(csv_lines=csv_lines)
  114. models_lines = [
  115. entry.format()
  116. for entry
  117. in table_entries
  118. if entry.readme_section == READMESection.MODELS
  119. ]
  120. skills_lines = [
  121. entry.format()
  122. for entry
  123. in table_entries
  124. if entry.readme_section == READMESection.SKILLS
  125. ]
  126. table_lines = WARNING_HEADER + \
  127. [MODELS_SECTION_HEADER.format(len(models_lines))] + TABLE_HEADER + models_lines + \
  128. [SKILLS_SECTION_HEADER.format(len(skills_lines))] + TABLE_HEADER + skills_lines
  129. readme_lines = inject_markdown_table_into_readme(readme_lines=readme_lines, table_lines=table_lines)
  130. save_lines_to_file(path=args.readme_path, lines=readme_lines)