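"""Streamlit document-scanner app: DeepLabV3 segments the page, OpenCV warps it flat.

A minimal usage sketch, assuming this file is saved as app.py (the filename is an
assumption) and the two checkpoint files referenced in load_model() sit in the
current working directory:

    streamlit run app.py
"""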
import os
import gc
import io
import cv2
import base64
import pathlib
import numpy as np
from PIL import Image
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import torch
import torchvision.transforms as torchvision_T
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_mobilenet_v3_large

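# Build the chosen DeepLabV3 backbone, load its fine-tuned checkpoint from the
# working directory, and cache the result so Streamlit reloads the model only
# when the arguments change rather than on every script rerun.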
@st.cache(allow_output_mutation=True)
def load_model(num_classes=2, model_name="mbv3", device=torch.device("cpu")):
    if model_name == "mbv3":
        model = deeplabv3_mobilenet_v3_large(num_classes=num_classes, aux_loss=True)
        checkpoint_path = os.path.join(os.getcwd(), "model_mbv3_iou_mix_2C049.pth")
    else:
        model = deeplabv3_resnet50(num_classes=num_classes, aux_loss=True)
        checkpoint_path = os.path.join(os.getcwd(), "model_r50_iou_mix_2C020.pth")

    model.to(device)
    checkpoints = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoints, strict=False)
    model.eval()

    # Warm-up forward pass (on the same device as the model) so the first real
    # request is not slowed down by lazy initialization.
    _ = model(torch.randn((1, 3, 384, 384), device=device))

    return model

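# The mean/std defaults below are presumably the per-channel statistics of the
# training data; inference inputs must be normalized the same way.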
def image_preprocess_transforms(mean=(0.4611, 0.4359, 0.3905), std=(0.2193, 0.2150, 0.2109)):
    common_transforms = torchvision_T.Compose(
        [
            torchvision_T.ToTensor(),
            torchvision_T.Normalize(mean, std),
        ]
    )
    return common_transforms

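# Worked example for the sum/diff heuristic used in order_points(): for the
# square [(0, 0), (10, 0), (10, 10), (0, 10)], the coordinate sums are
# [0, 10, 20, 10], so argmin/argmax pick (0, 0) as top-left and (10, 10) as
# bottom-right; the diffs (y - x) are [0, -10, 0, 10], so argmin/argmax pick
# (10, 0) as top-right and (0, 10) as bottom-left.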
def order_points(pts):
    """Rearrange coordinates to order:
    top-left, top-right, bottom-right, bottom-left."""
    rect = np.zeros((4, 2), dtype="float32")
    pts = np.array(pts)

    s = pts.sum(axis=1)
    # The top-left point has the smallest sum.
    rect[0] = pts[np.argmin(s)]
    # The bottom-right point has the largest sum.
    rect[2] = pts[np.argmax(s)]

    diff = np.diff(pts, axis=1)
    # The top-right point has the smallest difference.
    rect[1] = pts[np.argmin(diff)]
    # The bottom-left point has the largest difference.
    rect[3] = pts[np.argmax(diff)]

    # Return the ordered coordinates.
    return rect.astype("int").tolist()

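# Given the ordered source corners, compute the output rectangle's size from
# the longer of each pair of opposing edges, so the warp does not shrink the
# document.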
def find_dest(pts):
    (tl, tr, br, bl) = pts

    # Finding the maximum width.
    widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    maxWidth = max(int(widthA), int(widthB))

    # Finding the maximum height.
    heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    maxHeight = max(int(heightA), int(heightB))

    # Final destination coordinates.
    destination_corners = [[0, 0], [maxWidth, 0], [maxWidth, maxHeight], [0, maxHeight]]
    return order_points(destination_corners)

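# scan() pipeline:
#   1. Resize the photo to IMAGE_SIZE x IMAGE_SIZE and run the segmentation model.
#   2. Zero-pad the predicted mask on all sides so a page touching the frame
#      still yields a closed contour.
#   3. Detect edges on the mask and approximate the largest contour with a
#      quadrilateral (cv2.approxPolyDP).
#   4. Map the corners back to original-image coordinates; if any fall outside
#      the frame, pad the image so the full page fits.
#   5. Perspective-warp the quadrilateral into a flat, axis-aligned document.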
def scan(image_true=None, trained_model=None, image_size=384, BUFFER=10):
    global preprocess_transforms

    IMAGE_SIZE = image_size
    half = IMAGE_SIZE // 2

    imH, imW, C = image_true.shape

    image_model = cv2.resize(image_true, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_NEAREST)

    scale_x = imW / IMAGE_SIZE
    scale_y = imH / IMAGE_SIZE

    image_model = preprocess_transforms(image_model)
    image_model = torch.unsqueeze(image_model, dim=0)

    with torch.no_grad():
        out = trained_model(image_model)["out"].cpu()

    del image_model
    gc.collect()

    out = torch.argmax(out, dim=1, keepdim=True).permute(0, 2, 3, 1)[0].numpy().squeeze().astype(np.int32)
    r_H, r_W = out.shape

    # Pad the mask by 'half' on every side so border-touching pages form closed contours.
    _out_extended = np.zeros((IMAGE_SIZE + r_H, IMAGE_SIZE + r_W), dtype=out.dtype)
    _out_extended[half : half + IMAGE_SIZE, half : half + IMAGE_SIZE] = out * 255
    out = _out_extended.copy()

    del _out_extended
    gc.collect()

    # Edge detection.
    canny = cv2.Canny(out.astype(np.uint8), 225, 255)
    canny = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))
    contours, _ = cv2.findContours(canny, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    page = sorted(contours, key=cv2.contourArea, reverse=True)[0]

    # ==========================================
    epsilon = 0.02 * cv2.arcLength(page, True)
    corners = cv2.approxPolyDP(page, epsilon, True)

    corners = np.concatenate(corners).astype(np.float32)

    # Undo the padding offset, then rescale to original-image coordinates.
    corners[:, 0] -= half
    corners[:, 1] -= half
    corners[:, 0] *= scale_x
    corners[:, 1] *= scale_y

    # Check whether all corners lie inside the image.
    # If not, find the smallest enclosing box, pad the image, then extract the document;
    # otherwise, extract the document directly.
    if not (np.all(corners.min(axis=0) >= (0, 0)) and np.all(corners.max(axis=0) <= (imW, imH))):
        left_pad, top_pad, right_pad, bottom_pad = 0, 0, 0, 0

        rect = cv2.minAreaRect(corners.reshape((-1, 1, 2)))
        box = cv2.boxPoints(rect)
        box_corners = np.int32(box)

        box_x_min = np.min(box_corners[:, 0])
        box_x_max = np.max(box_corners[:, 0])
        box_y_min = np.min(box_corners[:, 1])
        box_y_max = np.max(box_corners[:, 1])

        # Find the box corners which don't satisfy the image bounds and record
        # the amount of padding required to bring each one back inside.
        if box_x_min <= 0:
            left_pad = abs(box_x_min) + BUFFER
        if box_x_max >= imW:
            right_pad = (box_x_max - imW) + BUFFER
        if box_y_min <= 0:
            top_pad = abs(box_y_min) + BUFFER
        if box_y_max >= imH:
            bottom_pad = (box_y_max - imH) + BUFFER

        # New image with additional zero-valued pixels.
        image_extended = np.zeros((top_pad + bottom_pad + imH, left_pad + right_pad + imW, C), dtype=image_true.dtype)

        # Place the original image within the new 'image_extended'.
        image_extended[top_pad : top_pad + imH, left_pad : left_pad + imW, :] = image_true
        image_extended = image_extended.astype(np.float32)

        # Shift 'box_corners' by the required amount.
        box_corners[:, 0] += left_pad
        box_corners[:, 1] += top_pad

        corners = box_corners
        image_true = image_extended

    corners = sorted(corners.tolist())
    corners = order_points(corners)
    destination_corners = find_dest(corners)

    M = cv2.getPerspectiveTransform(np.float32(corners), np.float32(destination_corners))
    final = cv2.warpPerspective(image_true, M, (destination_corners[2][0], destination_corners[2][1]), flags=cv2.INTER_LANCZOS4)
    final = np.clip(final, a_min=0, a_max=255)
    final = final.astype(np.uint8)

    return final

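# The helper below embeds the JPEG bytes directly in a base64 data URI, so the
# browser can download the result without the server writing a file to disk.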
# Generate an HTML link to download a particular image file.
def get_image_download_link(img, filename, text):
    buffered = io.BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    href = f'<a href="data:image/jpeg;base64,{img_str}" download="{filename}">{text}</a>'
    return href

# We create a downloads directory within the Streamlit static asset directory
# and write output files to it.
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / "static"
DOWNLOADS_PATH = STREAMLIT_STATIC_PATH / "downloads"
if not DOWNLOADS_PATH.is_dir():
    DOWNLOADS_PATH.mkdir()

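# Files placed under Streamlit's static directory can be served directly by its
# built-in web server, which is what makes a downloads folder reachable from
# the browser (behavior of the Streamlit releases this app targets; noted here
# as an assumption).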
IMAGE_SIZE = 384
preprocess_transforms = image_preprocess_transforms()

image = None
final = None
result = None

st.set_page_config(initial_sidebar_state="collapsed")
st.title("Document Scanner: Semantic Segmentation using DeepLabV3-PyTorch")

uploaded_file = st.file_uploader("Upload Document Image:", type=["png", "jpg", "jpeg"])
method = st.radio("Select Document Segmentation Model:", ("MobilenetV3-Large", "Resnet-50"), horizontal=True)

col1, col2 = st.columns((6, 5))

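# Main flow: decode the upload into a BGR OpenCV image, load the selected
# model, show the input and the scanned result side by side, then offer the
# result as a download.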
if uploaded_file is not None:
    # Convert the uploaded file to an OpenCV image.
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    image = cv2.imdecode(file_bytes, 1)

    h, w = image.shape[:2]

    if method == "MobilenetV3-Large":
        model = load_model(model_name="mbv3")
    else:
        model = load_model(model_name="r50")

    with col1:
        st.title("Input")
        st.image(image, channels="BGR", use_column_width=True)

    with col2:
        st.title("Scanned")
        final = scan(image_true=image, trained_model=model, image_size=IMAGE_SIZE)
        st.image(final, channels="BGR", use_column_width=True)

    if final is not None:
        # Display the download link.
        result = Image.fromarray(final[:, :, ::-1])  # BGR -> RGB for PIL
        st.markdown(get_image_download_link(result, "output.jpeg", "Download Output"), unsafe_allow_html=True)