# split_food-101.py
import argparse
import os
import os.path as osp
from shutil import copyfile
from typing import Optional, Sequence

from tqdm import tqdm
  6. def main():
  7. parser = argparse.ArgumentParser(
  8. description="Separate Food-101 into train/test folders",
  9. )
  10. parser.add_argument(
  11. "--data-root",
  12. default="./data",
  13. type=str,
  14. help="Path to root folder of the dataset",
  15. )
  16. args = parser.parse_args()
  17. classes = [
  18. "apple_pie",
  19. "bruschetta",
  20. "caesar_salad",
  21. "steak",
  22. "spring_rolls",
  23. "spaghetti_carbonara",
  24. "frozen_yogurt",
  25. "falafel",
  26. "mussels",
  27. "ramen",
  28. "onion_rings",
  29. "oysters",
  30. "risotto",
  31. "waffles",
  32. "cup_cakes",
  33. "grilled_cheese_sandwich",
  34. "fried_calamari",
  35. "huevos_rancheros",
  36. "croque_madame",
  37. "bread_pudding",
  38. "dumplings",
  39. ]
  40. assert osp.isdir(args.data_root)
  41. assert "images" in os.listdir(args.data_root)
  42. assert "meta" in os.listdir(args.data_root)
  43. os.makedirs(osp.join(args.data_root, "train"), exist_ok=True)
  44. os.makedirs(osp.join(args.data_root, "test"), exist_ok=True)
  45. for cls_name in classes:
  46. os.makedirs(osp.join(args.data_root, "train", cls_name), exist_ok=True)
  47. os.makedirs(osp.join(args.data_root, "test", cls_name), exist_ok=True)
  48. with open(osp.join(args.data_root, "meta", "train.txt"), "r") as file:
  49. for image in tqdm(file):
  50. image = image.rstrip()
  51. if image.split("/")[0] in classes:
  52. copyfile(
  53. osp.join(args.data_root, "images", image + ".jpg"),
  54. osp.join(args.data_root, "train", image + ".jpg"),
  55. )
  56. with open(osp.join(args.data_root, "meta", "test.txt"), "r") as file:
  57. for image in tqdm(file):
  58. image = image.rstrip()
  59. if image.split("/")[0] in classes:
  60. copyfile(
  61. osp.join(args.data_root, "images", image + ".jpg"),
  62. osp.join(args.data_root, "test", image + ".jpg"),
  63. )
  64. if __name__ == "__main__":
  65. main()