# benchmark.py — benchmark script comparing image-loading methods
  1. from argparse import ArgumentParser
  2. import numpy as np
  3. from prettytable import PrettyTable
  4. from create_lmdb import store_many_lmdb
  5. from create_tfrecords import store_many_tfrecords
  6. from loader import (
  7. CV2Loader,
  8. LmdbLoader,
  9. PILLoader,
  10. TFRecordsLoader,
  11. TurboJpegLoader,
  12. methods,
  13. )
  14. from tools import get_images_paths
  15. def count_time(loader, iters):
  16. time_list = []
  17. num_images = len(loader)
  18. for i in range(iters):
  19. loader = iter(loader)
  20. for idx in range(num_images):
  21. image, time = next(loader)
  22. time_list.append(time)
  23. time_list = np.asarray(time_list)
  24. print_stats(time_list, type(loader).__name__)
  25. return np.asarray(time_list)
  26. def print_stats(time, name):
  27. print("Time measures for {}:".format(name))
  28. print("{} mean time - {:.8f} seconds".format(name, time.mean()))
  29. print("{} median time - {:.8f} seconds".format(name, np.median(time)))
  30. print("{} std time - {:.8f} seconds".format(name, time.std()))
  31. print("{} min time - {:.8f} seconds".format(name, time.min()))
  32. print("{} max time - {:.8f} seconds".format(name, time.max()))
  33. print("\n")
  34. def benchmark(method, path, iters=100, **kwargs):
  35. image_loader = methods[method](path, **kwargs) # get image loader
  36. time = count_time(image_loader, iters) # measure the time for loading
  37. return time
  38. if __name__ == "__main__":
  39. parser = ArgumentParser()
  40. parser.add_argument(
  41. "--path", "-p", type=str, help="path to image folder",
  42. )
  43. parser.add_argument(
  44. "--method",
  45. nargs="+",
  46. required=True,
  47. choices=["cv2", "pil", "turbojpeg", "lmdb", "tfrecords"],
  48. help="Image loading methods to use in benchmark",
  49. )
  50. parser.add_argument(
  51. "--mode",
  52. "-m",
  53. type=str,
  54. required=True,
  55. choices=["BGR", "RGB"],
  56. help="Image color mode",
  57. )
  58. parser.add_argument(
  59. "--iters", type=int, help="Number of iterations to average the results",
  60. )
  61. args = parser.parse_args()
  62. benchmark_methods = args.method
  63. image_paths = get_images_paths(args.path)
  64. results = {}
  65. for method in benchmark_methods:
  66. if method == "lmdb":
  67. path = "./lmdb/images"
  68. store_many_lmdb(image_paths, path)
  69. elif method == "tfrecords":
  70. path = "./tfrecords/images.tfrecords"
  71. store_many_tfrecords(image_paths, path)
  72. else:
  73. path = args.path
  74. time = benchmark(method, path, mode=args.mode, iters=args.iters)
  75. results.update({method: time})
  76. table = PrettyTable(["Loader", "Mean time", "Median time"])
  77. print(
  78. f"Benchmark on {len(image_paths)} {args.mode} images with {args.iters} averaging iteration results:\n",
  79. )
  80. for method, time in results.items():
  81. table.add_row([method, time.mean(), np.median(time)])
  82. print(table)