generate_val_dataset.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
  2. """Script to generate val dataset for SSD/DSSD tutorial."""
  3. from __future__ import absolute_import
  4. from __future__ import division
  5. from __future__ import print_function
  6. import argparse
  7. import os
  8. def parse_args(args=None):
  9. """parse the arguments."""
  10. parser = argparse.ArgumentParser(description='Generate val dataset for SSD/DSSD tutorial')
  11. parser.add_argument(
  12. "--input_image_dir",
  13. type=str,
  14. required=True,
  15. help="Input directory to KITTI training dataset images."
  16. )
  17. parser.add_argument(
  18. "--input_label_dir",
  19. type=str,
  20. required=True,
  21. help="Input directory to KITTI training dataset labels."
  22. )
  23. parser.add_argument(
  24. "--output_dir",
  25. type=str,
  26. required=True,
  27. help="Ouput directory to TLT val dataset."
  28. )
  29. parser.add_argument(
  30. "--val_split",
  31. type=int,
  32. required=False,
  33. default=10,
  34. help="Percentage of training dataset for generating val dataset"
  35. )
  36. return parser.parse_args(args)
  37. def main(args=None):
  38. """Main function for data preparation."""
  39. args = parse_args(args)
  40. img_files = []
  41. for file_name in os.listdir(args.input_image_dir):
  42. if file_name.split(".")[-1] == "png":
  43. img_files.append(file_name)
  44. total_cnt = len(img_files)
  45. val_ratio = float(args.val_split) / 100.0
  46. val_cnt = int(total_cnt * val_ratio)
  47. train_cnt = total_cnt - val_cnt
  48. val_img_list = img_files[0:val_cnt]
  49. target_img_path = os.path.join(args.output_dir, "image")
  50. target_label_path = os.path.join(args.output_dir, "label")
  51. if not os.path.exists(target_img_path):
  52. os.makedirs(target_img_path)
  53. else:
  54. print("This script will not run as output image path already exists.")
  55. return
  56. if not os.path.exists(target_label_path):
  57. os.makedirs(target_label_path)
  58. else:
  59. print("This script will not run as output label path already exists.")
  60. return
  61. print("Total {} samples in KITTI training dataset".format(total_cnt))
  62. print("{} for train and {} for val".format(train_cnt, val_cnt))
  63. for img_name in val_img_list:
  64. label_name = img_name.split(".")[0] + ".txt"
  65. os.rename(os.path.join(args.input_image_dir, img_name),
  66. os.path.join(target_img_path, img_name))
  67. os.rename(os.path.join(args.input_label_dir, label_name),
  68. os.path.join(target_label_path, label_name))
  69. if __name__ == "__main__":
  70. main()