# split_dataset.py (original listing size: 3.1 KB)
import os
import random
import shutil

import cv2
import numpy as np
  6. np.random.seed(1)
  7. og_img_dir = r"final_set\\images"
  8. og_msk_dir = r"final_set\\masks"
  9. # Saving resized images to reduce file size
  10. MAX_DIM_SIZE = 480
  11. img_per_doc = 6
  12. train_img_dir = r"document_dataset_resized\\train\\images"
  13. train_msk_dir = r"document_dataset_resized\\train\\masks"
  14. valid_img_dir = r"document_dataset_resized\\valid\\images"
  15. valid_msk_dir = r"document_dataset_resized\\valid\\masks"
  16. os.makedirs(train_img_dir, exist_ok=True)
  17. os.makedirs(train_msk_dir, exist_ok=True)
  18. os.makedirs(valid_img_dir, exist_ok=True)
  19. os.makedirs(valid_msk_dir, exist_ok=True)
  20. all_img_paths = np.asarray(sorted([os.path.join(og_img_dir, i) for i in os.listdir(og_img_dir)]))
  21. all_msk_paths = np.asarray(sorted([os.path.join(og_msk_dir, i) for i in os.listdir(og_msk_dir)]))
  22. total_number_of_documents = len(all_img_paths) // img_per_doc
  23. all_img_paths = np.split(all_img_paths, total_number_of_documents)
  24. all_msk_paths = np.split(all_msk_paths, total_number_of_documents)
  25. print(len(all_img_paths))
  26. img_per_doc = list(range(img_per_doc))
  27. train_img_paths = []
  28. train_msk_paths = []
  29. valid_img_paths = []
  30. valid_msk_paths = []
  31. for doc_id_img_paths, doc_id_msk_paths in zip(all_img_paths, all_msk_paths):
  32. number = random.choice(img_per_doc)
  33. for i in img_per_doc:
  34. if i == number:
  35. valid_img_paths.append(doc_id_img_paths[i])
  36. valid_msk_paths.append(doc_id_msk_paths[i])
  37. else:
  38. train_img_paths.append(doc_id_img_paths[i])
  39. train_msk_paths.append(doc_id_msk_paths[i])
  40. print(len(train_img_paths), len(valid_img_paths))
  41. def ResizeWithAspectRatio(curr_dim, resize_to: int = 320):
  42. """returns new h and new w which maintains the aspect ratio"""
  43. h, w = curr_dim
  44. if h > w:
  45. r = resize_to / float(h)
  46. size = (int(w * r), resize_to)
  47. else:
  48. r = resize_to / float(w)
  49. size = (resize_to, int(h * r))
  50. return size[::-1]
  51. def copy(img_paths, msk_paths, dst_img_dir, dst_msk_dir):
  52. for idx, (image_path, mask_path) in enumerate(zip(img_paths, msk_paths)):
  53. img_name = os.path.split(image_path)[-1]
  54. msk_name = os.path.split(mask_path)[-1]
  55. dst_image_path = os.path.join(dst_img_dir, img_name)
  56. dst_mask_path = os.path.join(dst_msk_dir, msk_name)
  57. # shutil.copyfile(image_path, dst_image_path)
  58. # shutil.copyfile(mask_path, dst_mask_path)
  59. # Saving resized images to reduce file size
  60. image = cv2.imread(image_path)
  61. h, w = image.shape[:2]
  62. asp_h, asp_w = ResizeWithAspectRatio(curr_dim=(h, w), resize_to=MAX_DIM_SIZE)
  63. image = cv2.resize(image, (asp_w, asp_h), interpolation=cv2.INTER_NEAREST)
  64. mask = cv2.imread(mask_path)
  65. mask = cv2.resize(mask, (asp_w, asp_h), interpolation=cv2.INTER_NEAREST)
  66. cv2.imwrite(dst_image_path, image)
  67. cv2.imwrite(dst_mask_path, mask)
  68. return
  69. # Training set
  70. copy(train_img_paths, train_msk_paths, train_img_dir, train_msk_dir)
  71. # Validation set
  72. copy(valid_img_paths, valid_msk_paths, valid_img_dir, valid_msk_dir)