textDetection.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Import required modules
  2. import cv2 as cv
  3. import math
  4. import argparse
  5. parser = argparse.ArgumentParser(description='Use this script to run text detection deep learning networks using OpenCV.')
  6. # Input argument
  7. parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
  8. # Model argument
  9. parser.add_argument('--model', default="frozen_east_text_detection.pb",
  10. help='Path to a binary .pb file of model contains trained weights.'
  11. )
  12. # Width argument
  13. parser.add_argument('--width', type=int, default=320,
  14. help='Preprocess input image by resizing to a specific width. It should be multiple by 32.'
  15. )
  16. # Height argument
  17. parser.add_argument('--height',type=int, default=320,
  18. help='Preprocess input image by resizing to a specific height. It should be multiple by 32.'
  19. )
  20. # Confidence threshold
  21. parser.add_argument('--thr',type=float, default=0.5,
  22. help='Confidence threshold.'
  23. )
  24. # Non-maximum suppression threshold
  25. parser.add_argument('--nms',type=float, default=0.4,
  26. help='Non-maximum suppression threshold.'
  27. )
  28. parser.add_argument('--device', default="cpu", help="Device to inference on")
  29. args = parser.parse_args()
  30. ############ Utility functions ############
  31. def decode(scores, geometry, scoreThresh):
  32. detections = []
  33. confidences = []
  34. ############ CHECK DIMENSIONS AND SHAPES OF geometry AND scores ############
  35. assert len(scores.shape) == 4, "Incorrect dimensions of scores"
  36. assert len(geometry.shape) == 4, "Incorrect dimensions of geometry"
  37. assert scores.shape[0] == 1, "Invalid dimensions of scores"
  38. assert geometry.shape[0] == 1, "Invalid dimensions of geometry"
  39. assert scores.shape[1] == 1, "Invalid dimensions of scores"
  40. assert geometry.shape[1] == 5, "Invalid dimensions of geometry"
  41. assert scores.shape[2] == geometry.shape[2], "Invalid dimensions of scores and geometry"
  42. assert scores.shape[3] == geometry.shape[3], "Invalid dimensions of scores and geometry"
  43. height = scores.shape[2]
  44. width = scores.shape[3]
  45. for y in range(0, height):
  46. # Extract data from scores
  47. scoresData = scores[0][0][y]
  48. x0_data = geometry[0][0][y]
  49. x1_data = geometry[0][1][y]
  50. x2_data = geometry[0][2][y]
  51. x3_data = geometry[0][3][y]
  52. anglesData = geometry[0][4][y]
  53. for x in range(0, width):
  54. score = scoresData[x]
  55. # If score is lower than threshold score, move to next x
  56. if(score < scoreThresh):
  57. continue
  58. # Calculate offset
  59. offsetX = x * 4.0
  60. offsetY = y * 4.0
  61. angle = anglesData[x]
  62. # Calculate cos and sin of angle
  63. cosA = math.cos(angle)
  64. sinA = math.sin(angle)
  65. h = x0_data[x] + x2_data[x]
  66. w = x1_data[x] + x3_data[x]
  67. # Calculate offset
  68. offset = ([offsetX + cosA * x1_data[x] + sinA * x2_data[x], offsetY - sinA * x1_data[x] + cosA * x2_data[x]])
  69. # Find points for rectangle
  70. p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
  71. p3 = (-cosA * w + offset[0], sinA * w + offset[1])
  72. center = (0.5*(p1[0]+p3[0]), 0.5*(p1[1]+p3[1]))
  73. detections.append((center, (w,h), -1*angle * 180.0 / math.pi))
  74. confidences.append(float(score))
  75. # Return detections and confidences
  76. return [detections, confidences]
  77. if __name__ == "__main__":
  78. # Read and store arguments
  79. confThreshold = args.thr
  80. nmsThreshold = args.nms
  81. inpWidth = args.width
  82. inpHeight = args.height
  83. model = args.model
  84. # Load network
  85. net = cv.dnn.readNet(model)
  86. if args.device == "cpu":
  87. net.setPreferableBackend(cv.dnn.DNN_TARGET_CPU)
  88. print("Using CPU device")
  89. elif args.device == "gpu":
  90. net.setPreferableBackend(cv.dnn.DNN_BACKEND_CUDA)
  91. net.setPreferableTarget(cv.dnn.DNN_TARGET_CUDA)
  92. print("Using GPU device")
  93. # Create a new named window
  94. kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
  95. cv.namedWindow(kWinName, cv.WINDOW_NORMAL)
  96. outputLayers = []
  97. outputLayers.append("feature_fusion/Conv_7/Sigmoid")
  98. outputLayers.append("feature_fusion/concat_3")
  99. # Open a video file or an image file or a camera stream
  100. cap = cv.VideoCapture(args.input if args.input else 0)
  101. while cv.waitKey(1) < 0:
  102. # Read frame
  103. hasFrame, frame = cap.read()
  104. if not hasFrame:
  105. cv.waitKey()
  106. break
  107. # Get frame height and width
  108. height_ = frame.shape[0]
  109. width_ = frame.shape[1]
  110. rW = width_ / float(inpWidth)
  111. rH = height_ / float(inpHeight)
  112. # Create a 4D blob from frame.
  113. blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)
  114. # Run the model
  115. net.setInput(blob)
  116. output = net.forward(outputLayers)
  117. t, _ = net.getPerfProfile()
  118. label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
  119. # Get scores and geometry
  120. scores = output[0]
  121. geometry = output[1]
  122. [boxes, confidences] = decode(scores, geometry, confThreshold)
  123. # Apply NMS
  124. indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold,nmsThreshold)
  125. for i in indices:
  126. # get 4 corners of the rotated rect
  127. vertices = cv.boxPoints(boxes[i[0]])
  128. # scale the bounding box coordinates based on the respective ratios
  129. for j in range(4):
  130. vertices[j][0] *= rW
  131. vertices[j][1] *= rH
  132. for j in range(4):
  133. p1 = (vertices[j][0], vertices[j][1])
  134. p2 = (vertices[(j + 1) % 4][0], vertices[(j + 1) % 4][1])
  135. cv.line(frame, p1, p2, (0, 255, 0), 2, cv.LINE_AA)
  136. # cv.putText(frame, "{:.3f}".format(confidences[i[0]]), (vertices[0][0], vertices[0][1]), cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1, cv.LINE_AA)
  137. # Put efficiency information
  138. cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255))
  139. # Display the frame
  140. cv.imshow(kWinName,frame)
  141. cv.imwrite("output.png",frame)