# loader.py
  1. import os
  2. from abc import abstractmethod
  3. from timeit import default_timer as timer
  4. import cv2
  5. import lmdb
  6. import numpy as np
  7. import tensorflow as tf
  8. from PIL import Image
  9. from turbojpeg import TurboJPEG
  10. os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
  11. class ImageLoader:
  12. extensions: tuple = (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif", ".tfrecords")
  13. def __init__(self, path: str, mode: str = "BGR"):
  14. self.path = path
  15. self.mode = mode
  16. self.dataset = self.parse_input(self.path)
  17. self.sample_idx = 0
  18. def parse_input(self, path):
  19. # single image or tfrecords file
  20. if os.path.isfile(path):
  21. assert path.lower().endswith(
  22. self.extensions,
  23. ), f"Unsupportable extension, please, use one of {self.extensions}"
  24. return [path]
  25. if os.path.isdir(path):
  26. # lmdb environment
  27. if any([file.endswith(".mdb") for file in os.listdir(path)]):
  28. return path
  29. else:
  30. # folder with images
  31. paths = [os.path.join(path, image) for image in os.listdir(path)]
  32. return paths
  33. def __iter__(self):
  34. self.sample_idx = 0
  35. return self
  36. def __len__(self):
  37. return len(self.dataset)
  38. @abstractmethod
  39. def __next__(self):
  40. pass
  41. class CV2Loader(ImageLoader):
  42. def __next__(self):
  43. start = timer()
  44. path = self.dataset[self.sample_idx] # get image path by index from the dataset
  45. image = cv2.imread(path) # read the image
  46. full_time = timer() - start
  47. if self.mode == "RGB":
  48. start = timer()
  49. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # change color mode
  50. full_time += timer() - start
  51. self.sample_idx += 1
  52. return image, full_time
  53. class PILLoader(ImageLoader):
  54. def __next__(self):
  55. start = timer()
  56. path = self.dataset[self.sample_idx] # get image path by index from the dataset
  57. image = np.asarray(Image.open(path)) # read the image as numpy array
  58. full_time = timer() - start
  59. if self.mode == "BGR":
  60. start = timer()
  61. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # change color mode
  62. full_time += timer() - start
  63. self.sample_idx += 1
  64. return image, full_time
  65. class TurboJpegLoader(ImageLoader):
  66. def __init__(self, path, **kwargs):
  67. super(TurboJpegLoader, self).__init__(path, **kwargs)
  68. self.jpeg_reader = TurboJPEG() # create TurboJPEG object for image reading
  69. def __next__(self):
  70. start = timer()
  71. file = open(self.dataset[self.sample_idx], "rb") # open the input file as bytes
  72. full_time = timer() - start
  73. if self.mode == "RGB":
  74. mode = 0
  75. elif self.mode == "BGR":
  76. mode = 1
  77. start = timer()
  78. image = self.jpeg_reader.decode(file.read(), mode) # decode raw image
  79. full_time += timer() - start
  80. self.sample_idx += 1
  81. return image, full_time
  82. class LmdbLoader(ImageLoader):
  83. def __init__(self, path, **kwargs):
  84. super(LmdbLoader, self).__init__(path, **kwargs)
  85. self.path = path
  86. self._dataset_size = 0
  87. self.dataset = self.open_database()
  88. # we need to open the database to read images from it
  89. def open_database(self):
  90. lmdb_env = lmdb.open(self.path) # open the environment by path
  91. lmdb_txn = lmdb_env.begin() # start reading
  92. lmdb_cursor = lmdb_txn.cursor() # create cursor to iterate through the database
  93. self._dataset_size = lmdb_env.stat()[
  94. "entries"
  95. ] # get number of items in full dataset
  96. return lmdb_cursor
  97. def __iter__(self):
  98. self.dataset.first() # return the cursor to the first database element
  99. return self
  100. def __next__(self):
  101. start = timer()
  102. raw_image = self.dataset.value() # get raw image
  103. image = np.frombuffer(raw_image, dtype=np.uint8) # convert it to numpy
  104. image = cv2.imdecode(image, cv2.IMREAD_COLOR) # decode image
  105. full_time = timer() - start
  106. if self.mode == "RGB":
  107. start = timer()
  108. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  109. full_time += timer() - start
  110. start = timer()
  111. self.dataset.next() # step to the next element in database
  112. full_time += timer() - start
  113. return image, full_time
  114. def __len__(self):
  115. return self._dataset_size # get dataset length
  116. class TFRecordsLoader(ImageLoader):
  117. def __init__(self, path, **kwargs):
  118. super(TFRecordsLoader, self).__init__(path, **kwargs)
  119. self._dataset = self.open_database()
  120. def open_database(self):
  121. def _parse_image_function(example_proto):
  122. return tf.io.parse_single_example(example_proto, image_feature_description)
  123. # dataset structure description
  124. image_feature_description = {
  125. "label": tf.io.FixedLenFeature([], tf.int64),
  126. "image_raw": tf.io.FixedLenFeature([], tf.string),
  127. }
  128. raw_image_dataset = tf.data.TFRecordDataset(self.path) # open dataset by path
  129. parsed_image_dataset = raw_image_dataset.map(
  130. _parse_image_function,
  131. ) # parse dataset using structure description
  132. return parsed_image_dataset
  133. def __iter__(self):
  134. self.dataset = self._dataset.as_numpy_iterator()
  135. return self
  136. def __next__(self):
  137. start = timer()
  138. value = next(self.dataset)[
  139. "image_raw"
  140. ] # step to the next element in database and get new image
  141. image = tf.image.decode_jpeg(value).numpy() # decode raw image
  142. full_time = timer() - start
  143. if self.mode == "BGR":
  144. start = timer()
  145. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  146. full_time += timer() - start
  147. return image, full_time
  148. def __len__(self):
  149. return self._dataset.reduce(
  150. np.int64(0), lambda x, _: x + 1,
  151. ).numpy() # get dataset length
  152. methods = {
  153. "cv2": CV2Loader,
  154. "pil": PILLoader,
  155. "turbojpeg": TurboJpegLoader,
  156. "lmdb": LmdbLoader,
  157. "tfrecords": TFRecordsLoader,
  158. }