# app.py

import os
import gc
import io
import base64
import pathlib

import cv2
import numpy as np
from PIL import Image
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import torch
import torchvision.transforms as torchvision_T
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_mobilenet_v3_large

@st.cache(allow_output_mutation=True)
def load_model(num_classes=2, model_name="mbv3", device=torch.device("cpu")):
    # Pick the segmentation backbone and its matching checkpoint file
    # (both .pth files are expected in the current working directory).
    if model_name == "mbv3":
        model = deeplabv3_mobilenet_v3_large(num_classes=num_classes, aux_loss=True)
        checkpoint_path = os.path.join(os.getcwd(), "model_mbv3_iou_mix_2C049.pth")
    else:
        model = deeplabv3_resnet50(num_classes=num_classes, aux_loss=True)
        checkpoint_path = os.path.join(os.getcwd(), "model_r50_iou_mix_2C020.pth")

    model.to(device)
    checkpoints = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoints, strict=False)
    model.eval()

    # Warm-up forward pass; the dummy input must live on the same device as the model.
    _ = model(torch.randn((1, 3, 384, 384), device=device))

    return model
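
# A minimal usage sketch (hedged; the device choice below is an assumption,
# the app itself only ever calls load_model on the default CPU device):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = load_model(num_classes=2, model_name="r50", device=device)
#
# load_model assumes the corresponding .pth checkpoint already sits in the
# working directory; torch.load raises an error otherwise.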

def image_preprocess_transforms(mean=(0.4611, 0.4359, 0.3905), std=(0.2193, 0.2150, 0.2109)):
    common_transforms = torchvision_T.Compose(
        [
            torchvision_T.ToTensor(),
            torchvision_T.Normalize(mean, std),
        ]
    )
    return common_transforms
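
# A quick sketch of what the transform does (illustrative values only):
#
#   img = np.random.randint(0, 256, (384, 384, 3), dtype=np.uint8)
#   tensor = image_preprocess_transforms()(img)   # float32 tensor, shape (3, 384, 384)
#
# ToTensor scales pixels to [0, 1] and moves channels first; Normalize then
# subtracts the per-channel mean and divides by the per-channel std.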

def order_points(pts):
    """Rearrange coordinates to order:
    top-left, top-right, bottom-right, bottom-left"""
    rect = np.zeros((4, 2), dtype="float32")
    pts = np.array(pts)
    s = pts.sum(axis=1)
    # Top-left point will have the smallest sum.
    rect[0] = pts[np.argmin(s)]
    # Bottom-right point will have the largest sum.
    rect[2] = pts[np.argmax(s)]

    diff = np.diff(pts, axis=1)
    # Top-right point will have the smallest difference (y - x).
    rect[1] = pts[np.argmin(diff)]
    # Bottom-left point will have the largest difference.
    rect[3] = pts[np.argmax(diff)]

    # Return the ordered coordinates.
    return rect.astype("int").tolist()
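
# Worked example (assumed inputs, for illustration): for the unordered points
#   [[190, 10], [10, 12], [8, 200], [195, 205]]
# the coordinate sums are smallest at (10, 12) (top-left) and largest at
# (195, 205) (bottom-right); the differences y - x are smallest at (190, 10)
# (top-right) and largest at (8, 200) (bottom-left), so order_points returns
#   [[10, 12], [190, 10], [195, 205], [8, 200]]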

def find_dest(pts):
    (tl, tr, br, bl) = pts
    # Finding the maximum width.
    widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    maxWidth = max(int(widthA), int(widthB))

    # Finding the maximum height.
    heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    maxHeight = max(int(heightA), int(heightB))

    # Final destination coordinates.
    destination_corners = [[0, 0], [maxWidth, 0], [maxWidth, maxHeight], [0, maxHeight]]
    return order_points(destination_corners)
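
# For instance (illustrative corners): with ordered corners tl=(0, 0),
# tr=(400, 10), br=(410, 300), bl=(5, 295), the candidate widths are ~405 px
# and ~400 px and the candidate heights are ~290 px and ~295 px, so the
# destination rectangle is 405 x 295 anchored at the origin:
#   [[0, 0], [405, 0], [405, 295], [0, 295]]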

def scan(image_true=None, trained_model=None, image_size=384, BUFFER=10):
    global preprocess_transforms

    IMAGE_SIZE = image_size
    half = IMAGE_SIZE // 2

    imH, imW, C = image_true.shape

    image_model = cv2.resize(image_true, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_NEAREST)

    scale_x = imW / IMAGE_SIZE
    scale_y = imH / IMAGE_SIZE

    image_model = preprocess_transforms(image_model)
    image_model = torch.unsqueeze(image_model, dim=0)

    with torch.no_grad():
        out = trained_model(image_model)["out"].cpu()

    del image_model
    gc.collect()

    # Per-pixel class prediction; note torch.argmax takes `keepdim`, not `keepdims`.
    out = torch.argmax(out, dim=1, keepdim=True).permute(0, 2, 3, 1)[0].numpy().squeeze().astype(np.int32)
    r_H, r_W = out.shape

    # Pad the mask on all sides so contours that touch the border stay closed.
    _out_extended = np.zeros((IMAGE_SIZE + r_H, IMAGE_SIZE + r_W), dtype=out.dtype)
    _out_extended[half : half + IMAGE_SIZE, half : half + IMAGE_SIZE] = out * 255
    out = _out_extended.copy()

    del _out_extended
    gc.collect()

    # Edge detection.
    canny = cv2.Canny(out.astype(np.uint8), 225, 255)
    canny = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))
    contours, _ = cv2.findContours(canny, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    # Keep the largest contour: the document page.
    page = sorted(contours, key=cv2.contourArea, reverse=True)[0]

    # ==========================================
    epsilon = 0.02 * cv2.arcLength(page, True)
    corners = cv2.approxPolyDP(page, epsilon, True)

    corners = np.concatenate(corners).astype(np.float32)

    # Undo the padding offset, then rescale to the original image resolution.
    corners[:, 0] -= half
    corners[:, 1] -= half
    corners[:, 0] *= scale_x
    corners[:, 1] *= scale_y

    # Check if all corners lie inside the image.
    # If not, find the smallest enclosing box, pad the image, then extract the document;
    # otherwise extract the document directly.
    if not (np.all(corners.min(axis=0) >= (0, 0)) and np.all(corners.max(axis=0) <= (imW, imH))):
        left_pad, top_pad, right_pad, bottom_pad = 0, 0, 0, 0

        rect = cv2.minAreaRect(corners.reshape((-1, 1, 2)))
        box = cv2.boxPoints(rect)
        box_corners = np.int32(box)
        # box_corners = minimum_bounding_rectangle(corners)

        box_x_min = np.min(box_corners[:, 0])
        box_x_max = np.max(box_corners[:, 0])
        box_y_min = np.min(box_corners[:, 1])
        box_y_max = np.max(box_corners[:, 1])

        # Find the corner points which don't satisfy the image constraint
        # and record the amount of shift required to make the box
        # corners satisfy the constraint.
        if box_x_min <= 0:
            left_pad = abs(box_x_min) + BUFFER
        if box_x_max >= imW:
            right_pad = (box_x_max - imW) + BUFFER
        if box_y_min <= 0:
            top_pad = abs(box_y_min) + BUFFER
        if box_y_max >= imH:
            bottom_pad = (box_y_max - imH) + BUFFER

        # New image with additional zero pixels.
        image_extended = np.zeros((top_pad + bottom_pad + imH, left_pad + right_pad + imW, C), dtype=image_true.dtype)

        # Place the original image within the new 'image_extended'.
        image_extended[top_pad : top_pad + imH, left_pad : left_pad + imW, :] = image_true
        image_extended = image_extended.astype(np.float32)

        # Shift 'box_corners' by the required amount.
        box_corners[:, 0] += left_pad
        box_corners[:, 1] += top_pad

        corners = box_corners
        image_true = image_extended

    corners = sorted(corners.tolist())
    corners = order_points(corners)
    destination_corners = find_dest(corners)

    M = cv2.getPerspectiveTransform(np.float32(corners), np.float32(destination_corners))
    final = cv2.warpPerspective(image_true, M, (destination_corners[2][0], destination_corners[2][1]), flags=cv2.INTER_LANCZOS4)
    final = np.clip(final, a_min=0, a_max=255)
    final = final.astype(np.uint8)

    return final
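
# A minimal end-to-end sketch outside Streamlit (file names are hypothetical):
#
#   image = cv2.imread("receipt.jpg")            # BGR, uint8
#   model = load_model(model_name="mbv3")
#   warped = scan(image_true=image, trained_model=model, image_size=IMAGE_SIZE)
#   cv2.imwrite("receipt_scanned.jpg", warped)
#
# scan() assumes a 3-channel input and that the module-level
# `preprocess_transforms` has already been created (it is, a few lines below).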

# Generate a link to download a particular image file.
def get_image_download_link(img, filename, text):
    buffered = io.BytesIO()
    # Save as PNG so the bytes match the .png filename passed in below.
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    href = f'<a href="data:image/png;base64,{img_str}" download="{filename}">{text}</a>'
    return href
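
# Example (illustrative): get_image_download_link(result, "output.png", "Download")
# returns an <a> tag whose href embeds the whole PNG as a base64 data URI, so
# no file needs to be written to disk for the browser to offer the download.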

# We create a downloads directory within the Streamlit static asset directory
# and write output files to it.
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / "static"
DOWNLOADS_PATH = STREAMLIT_STATIC_PATH / "downloads"
if not DOWNLOADS_PATH.is_dir():
    DOWNLOADS_PATH.mkdir()

IMAGE_SIZE = 384
preprocess_transforms = image_preprocess_transforms()
image = None
final = None
result = None

st.set_page_config(initial_sidebar_state="collapsed")

st.title("Document Scanner: Semantic Segmentation using DeepLabV3-PyTorch")

uploaded_file = st.file_uploader("Upload Document Image:", type=["png", "jpg", "jpeg"])

method = st.radio("Select Document Segmentation Model:", ("MobilenetV3-Large", "Resnet-50"), horizontal=True)

col1, col2 = st.columns((6, 5))

if uploaded_file is not None:
    # Convert the uploaded file to an OpenCV (BGR) image.
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)

    h, w = image.shape[:2]

    if method == "MobilenetV3-Large":
        model = load_model(model_name="mbv3")
    else:
        model = load_model(model_name="r50")

    with col1:
        st.title("Input")
        st.image(image, channels="BGR", use_column_width=True)

    with col2:
        st.title("Scanned")
        final = scan(image_true=image, trained_model=model, image_size=IMAGE_SIZE)
        st.image(final, channels="BGR", use_column_width=True)

    if final is not None:
        # Display the download link (convert BGR -> RGB for PIL).
        result = Image.fromarray(final[:, :, ::-1])
        st.markdown(get_image_download_link(result, "output.png", "Download Output"), unsafe_allow_html=True)
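
# Launch the app from the repository root with:
#   streamlit run app.py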