123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- # Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
- """Script to generate val dataset for SSD/DSSD tutorial."""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import argparse
- import os
def parse_args(args=None):
    """Parse the command-line arguments for the val-split script.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input_image_dir``, ``input_label_dir``,
        ``output_dir`` (all str) and ``val_split`` (int, default 10).

    Raises:
        SystemExit: if a required argument is missing, or if ``--val_split``
            is outside the range 0-100 (reported via ``parser.error``).
    """
    parser = argparse.ArgumentParser(description='Generate val dataset for SSD/DSSD tutorial')
    parser.add_argument(
        "--input_image_dir",
        type=str,
        required=True,
        help="Input directory to KITTI training dataset images."
    )
    parser.add_argument(
        "--input_label_dir",
        type=str,
        required=True,
        help="Input directory to KITTI training dataset labels."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory to TLT val dataset."
    )
    parser.add_argument(
        "--val_split",
        type=int,
        required=False,
        default=10,
        help="Percentage of training dataset for generating val dataset"
    )
    parsed = parser.parse_args(args)
    # A negative percentage would make the val count negative, and slicing
    # img_files[0:negative] would then select almost every image. Values above
    # 100 produce a negative train count. Reject both up front.
    if not 0 <= parsed.val_split <= 100:
        parser.error(
            "--val_split must be between 0 and 100, got {}".format(parsed.val_split)
        )
    return parsed
def main(args=None):
    """Move a percentage of KITTI training samples into a validation split.

    The ``.png`` images in ``--input_image_dir`` are sorted by file name. The
    first ``--val_split`` percent of them, together with their same-stem
    ``.txt`` label files from ``--input_label_dir``, are moved into
    ``<output_dir>/image`` and ``<output_dir>/label``. Moved files no longer
    exist in the input directories.

    If either output subdirectory already exists, a message is printed and
    nothing is done.

    Args:
        args: Optional list of argument strings forwarded to ``parse_args``;
            ``None`` means the process command line is used.

    Raises:
        FileNotFoundError: if any selected image lacks its label file. The
            check happens before anything is created or moved.
    """
    # Local import keeps this script entry point self-contained. shutil.move
    # falls back to copy+delete when source and target are on different
    # filesystems, where os.rename fails.
    import shutil

    args = parse_args(args)

    # Sort so the same files land in the val split on every run/filesystem;
    # os.listdir returns entries in arbitrary order.
    img_files = sorted(
        file_name for file_name in os.listdir(args.input_image_dir)
        if os.path.splitext(file_name)[1] == ".png"
    )
    total_cnt = len(img_files)
    val_ratio = float(args.val_split) / 100.0
    val_cnt = int(total_cnt * val_ratio)
    train_cnt = total_cnt - val_cnt
    val_img_list = img_files[:val_cnt]

    target_img_path = os.path.join(args.output_dir, "image")
    target_label_path = os.path.join(args.output_dir, "label")
    # Check both targets before creating either. Otherwise a refused run
    # leaves an empty image/ dir behind that blocks every later run.
    if os.path.exists(target_img_path):
        print("This script will not run as output image path already exists.")
        return
    if os.path.exists(target_label_path):
        print("This script will not run as output label path already exists.")
        return

    # splitext strips only the final extension: "a.b.png" -> "a.b.txt".
    # The previous split(".")[0] produced "a.txt".
    pairs = [(img_name, os.path.splitext(img_name)[0] + ".txt")
             for img_name in val_img_list]

    # Validate every label before moving anything, so a missing label cannot
    # leave the dataset half-moved.
    missing = [label_name for _, label_name in pairs
               if not os.path.isfile(os.path.join(args.input_label_dir, label_name))]
    if missing:
        raise FileNotFoundError(
            "Missing label files in {}: {}".format(args.input_label_dir, ", ".join(missing))
        )

    os.makedirs(target_img_path)
    os.makedirs(target_label_path)

    print("Total {} samples in KITTI training dataset".format(total_cnt))
    print("{} for train and {} for val".format(train_cnt, val_cnt))
    for img_name, label_name in pairs:
        shutil.move(os.path.join(args.input_image_dir, img_name),
                    os.path.join(target_img_path, img_name))
        shutil.move(os.path.join(args.input_label_dir, label_name),
                    os.path.join(target_label_path, label_name))
if __name__ == "__main__":
    # Script entry point: main() parses sys.argv and returns None, so this
    # exits with status 0 on success; exceptions propagate unchanged.
    raise SystemExit(main())
|