create_tfrecords.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import os
  2. from argparse import ArgumentParser
  3. import tensorflow as tf
  4. from tools import get_images_paths
  5. def _byte_feature(value):
  6. """Convert string / byte into bytes_list."""
  7. if isinstance(value, type(tf.constant(0))):
  8. value = value.numpy() # BytesList can't unpack string from EagerTensor.
  9. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  10. def _int64_feature(value):
  11. """Convert bool / enum / int / uint into int64_list."""
  12. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  13. def image_example(image_string, label):
  14. feature = {
  15. "label": _int64_feature(label),
  16. "image_raw": _byte_feature(image_string),
  17. }
  18. return tf.train.Example(features=tf.train.Features(feature=feature))
  19. def store_many_tfrecords(images_list, save_file):
  20. assert save_file.endswith(
  21. ".tfrecords",
  22. ), 'File path is wrong, it should contain "*myname*.tfrecords"'
  23. directory = os.path.dirname(save_file)
  24. if not os.path.exists(directory):
  25. os.makedirs(directory)
  26. with tf.io.TFRecordWriter(save_file) as writer: # start writer
  27. for label, filename in enumerate(images_list): # cycle by each image path
  28. image_string = open(filename, "rb").read() # read the image as bytes string
  29. tf_example = image_example(
  30. image_string, label,
  31. ) # save the data as tf.Example object
  32. writer.write(tf_example.SerializeToString()) # and write it into database
  33. if __name__ == "__main__":
  34. parser = ArgumentParser()
  35. parser.add_argument(
  36. "--path",
  37. "-p",
  38. type=str,
  39. required=True,
  40. help="path to the images folder to collect",
  41. )
  42. parser.add_argument(
  43. "--output",
  44. "-o",
  45. type=str,
  46. required=True,
  47. help='path to the output tfrecords file i.e. "path/to/folder/myname.tfrecords"',
  48. )
  49. args = parser.parse_args()
  50. image_paths = get_images_paths(args.path)
  51. store_many_tfrecords(image_paths, args.output)