// textDetection.cpp
#include <cmath>
#include <iostream>

#include <opencv2/dnn.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
  5. using namespace cv;
  6. using namespace cv::dnn;
// Command-line option specification consumed by cv::CommandLineParser.
// Entry format: "{ long-name short-alias | default-value | help text }".
// NOTE: this string is runtime data — its exact bytes define the CLI.
const char* keys =
"{ help h | | Print help message. }"
"{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
"{ model m | frozen_east_text_detection.pb | Path to a binary .pb file contains trained network.}"
"{ width | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }"
"{ height | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }"
"{ thr | 0.5 | Confidence threshold. }"
"{ nms | 0.4 | Non-maximum suppression threshold. }"
"{ device | cpu | Device to run Deep Learning inference. }";

// Forward declaration: converts the raw EAST outputs (score map + geometry
// map) into rotated boxes and per-box confidences; defined after main().
void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
            std::vector<RotatedRect>& detections, std::vector<float>& confidences);
  18. int main(int argc, char** argv)
  19. {
  20. // Parse command line arguments.
  21. CommandLineParser parser(argc, argv, keys);
  22. parser.about("Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
  23. "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)");
  24. if (argc == 1 || parser.has("help"))
  25. {
  26. parser.printMessage();
  27. return 0;
  28. }
  29. float confThreshold = parser.get<float>("thr");
  30. float nmsThreshold = parser.get<float>("nms");
  31. int inpWidth = parser.get<int>("width");
  32. int inpHeight = parser.get<int>("height");
  33. String model = parser.get<String>("model");
  34. if (!parser.check())
  35. {
  36. parser.printErrors();
  37. return 1;
  38. }
  39. CV_Assert(!model.empty());
  40. String device = parser.get<String>("device");
  41. // Load network.
  42. Net net = readNet(model);
  43. if (device == "cpu")
  44. {
  45. std::cout << "Using CPU device" << std::endl;
  46. net.setPreferableBackend(DNN_TARGET_CPU);
  47. }
  48. else if (device == "gpu")
  49. {
  50. std::cout << "Using GPU device" << std::endl;
  51. net.setPreferableBackend(DNN_BACKEND_CUDA);
  52. net.setPreferableTarget(DNN_TARGET_CUDA);
  53. }
  54. // Open a video file or an image file or a camera stream.
  55. VideoCapture cap;
  56. if (parser.has("input"))
  57. cap.open(parser.get<String>("input"));
  58. else
  59. cap.open(0);
  60. static const std::string kWinName = "EAST: An Efficient and Accurate Scene Text Detector";
  61. namedWindow(kWinName, WINDOW_NORMAL);
  62. std::vector<Mat> output;
  63. std::vector<String> outputLayers(2);
  64. outputLayers[0] = "feature_fusion/Conv_7/Sigmoid";
  65. outputLayers[1] = "feature_fusion/concat_3";
  66. Mat frame, blob;
  67. while (waitKey(1) < 0)
  68. {
  69. cap >> frame;
  70. if (frame.empty())
  71. {
  72. waitKey();
  73. break;
  74. }
  75. blobFromImage(frame, blob, 1.0, Size(inpWidth, inpHeight), Scalar(123.68, 116.78, 103.94), true, false);
  76. net.setInput(blob);
  77. net.forward(output, outputLayers);
  78. Mat scores = output[0];
  79. Mat geometry = output[1];
  80. // Decode predicted bounding boxes.
  81. std::vector<RotatedRect> boxes;
  82. std::vector<float> confidences;
  83. decode(scores, geometry, confThreshold, boxes, confidences);
  84. // Apply non-maximum suppression procedure.
  85. std::vector<int> indices;
  86. NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
  87. // Render detections.
  88. Point2f ratio((float)frame.cols / inpWidth, (float)frame.rows / inpHeight);
  89. for (size_t i = 0; i < indices.size(); ++i)
  90. {
  91. RotatedRect& box = boxes[indices[i]];
  92. Point2f vertices[4];
  93. box.points(vertices);
  94. for (int j = 0; j < 4; ++j)
  95. {
  96. vertices[j].x *= ratio.x;
  97. vertices[j].y *= ratio.y;
  98. }
  99. for (int j = 0; j < 4; ++j)
  100. line(frame, vertices[j], vertices[(j + 1) % 4], Scalar(0, 255, 0), 2, LINE_AA);
  101. }
  102. // Put efficiency information.
  103. std::vector<double> layersTimes;
  104. double freq = getTickFrequency() / 1000;
  105. double t = net.getPerfProfile(layersTimes) / freq;
  106. std::string label = format("Inference time: %.2f ms", t);
  107. putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  108. imshow(kWinName, frame);
  109. }
  110. return 0;
  111. }
  112. void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
  113. std::vector<RotatedRect>& detections, std::vector<float>& confidences)
  114. {
  115. detections.clear();
  116. CV_Assert(scores.dims == 4); CV_Assert(geometry.dims == 4); CV_Assert(scores.size[0] == 1);
  117. CV_Assert(geometry.size[0] == 1); CV_Assert(scores.size[1] == 1); CV_Assert(geometry.size[1] == 5);
  118. CV_Assert(scores.size[2] == geometry.size[2]); CV_Assert(scores.size[3] == geometry.size[3]);
  119. const int height = scores.size[2];
  120. const int width = scores.size[3];
  121. for (int y = 0; y < height; ++y)
  122. {
  123. const float* scoresData = scores.ptr<float>(0, 0, y);
  124. const float* x0_data = geometry.ptr<float>(0, 0, y);
  125. const float* x1_data = geometry.ptr<float>(0, 1, y);
  126. const float* x2_data = geometry.ptr<float>(0, 2, y);
  127. const float* x3_data = geometry.ptr<float>(0, 3, y);
  128. const float* anglesData = geometry.ptr<float>(0, 4, y);
  129. for (int x = 0; x < width; ++x)
  130. {
  131. float score = scoresData[x];
  132. if (score < scoreThresh)
  133. continue;
  134. // Decode a prediction.
  135. // Multiple by 4 because feature maps are 4 time less than input image.
  136. float offsetX = x * 4.0f, offsetY = y * 4.0f;
  137. float angle = anglesData[x];
  138. float cosA = std::cos(angle);
  139. float sinA = std::sin(angle);
  140. float h = x0_data[x] + x2_data[x];
  141. float w = x1_data[x] + x3_data[x];
  142. Point2f offset(offsetX + cosA * x1_data[x] + sinA * x2_data[x],
  143. offsetY - sinA * x1_data[x] + cosA * x2_data[x]);
  144. Point2f p1 = Point2f(-sinA * h, -cosA * h) + offset;
  145. Point2f p3 = Point2f(-cosA * w, sinA * w) + offset;
  146. RotatedRect r(0.5f * (p1 + p3), Size2f(w, h), -angle * 180.0f / (float)CV_PI);
  147. detections.push_back(r);
  148. confidences.push_back(score);
  149. }
  150. }
  151. }