// mask_rcnn.cpp
// Copyright (C) 2018-2019, BigVision LLC (LearnOpenCV.com), All Rights Reserved.
// Author : Sunita Nayak
// Article : https://www.learnopencv.com/deep-learning-based-object-detection-and-instance-segmentation-using-mask-r-cnn-in-opencv-python-c/
// License: BSD-3-Clause-Attribution (Please read the license file.)
// Usage example: ./mask_rcnn.out --video=run.mp4
// ./mask_rcnn.out --image=bird.jpg
  7. #include <fstream>
  8. #include <sstream>
  9. #include <iostream>
  10. #include <string.h>
  11. #include <opencv2/dnn.hpp>
  12. #include <opencv2/imgproc.hpp>
  13. #include <opencv2/highgui.hpp>
  14. const char* keys =
  15. "{help h usage ? | | Usage examples: \n\t\t./mask-rcnn.out --image=traffic.jpg \n\t\t./mask-rcnn.out --video=sample.mp4}"
  16. "{image i |<none>| input image }"
  17. "{video v |<none>| input video }"
  18. "{device d |<none>| device }"
  19. ;
  20. using namespace cv;
  21. using namespace dnn;
  22. using namespace std;
  23. // Initialize the parameters
  24. float confThreshold = 0.5; // Confidence threshold
  25. float maskThreshold = 0.3; // Mask threshold
  26. vector<string> classes;
  27. vector<Scalar> colors;
  28. // Draw the predicted bounding box
  29. void drawBox(Mat& frame, int classId, float conf, Rect box, Mat& objectMask);
  30. // Postprocess the neural network's output for each frame
  31. void postprocess(Mat& frame, const vector<Mat>& outs);
  32. int main(int argc, char** argv)
  33. {
  34. CommandLineParser parser(argc, argv, keys);
  35. parser.about("Use this script to run object detection using YOLO3 in OpenCV.");
  36. if (parser.has("help"))
  37. {
  38. parser.printMessage();
  39. return 0;
  40. }
  41. // Load names of classes
  42. string classesFile = "mscoco_labels.names";
  43. ifstream ifs(classesFile.c_str());
  44. string line;
  45. while (getline(ifs, line)) classes.push_back(line);
  46. string device = parser.get<String>("device");
  47. // Load the colors
  48. string colorsFile = "colors.txt";
  49. ifstream colorFptr(colorsFile.c_str());
  50. while (getline(colorFptr, line)) {
  51. char* pEnd;
  52. double r, g, b;
  53. r = strtod (line.c_str(), &pEnd);
  54. g = strtod (pEnd, NULL);
  55. b = strtod (pEnd, NULL);
  56. Scalar color = Scalar(r, g, b, 255.0);
  57. colors.push_back(Scalar(r, g, b, 255.0));
  58. }
  59. // Give the configuration and weight files for the model
  60. String textGraph = "./mask_rcnn_inception_v2_coco_2018_01_28.pbtxt";
  61. String modelWeights = "./mask_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb";
  62. // Load the network
  63. Net net = readNetFromTensorflow(modelWeights, textGraph);
  64. if (device == "cpu")
  65. {
  66. cout << "Using CPU device" << endl;
  67. net.setPreferableBackend(DNN_TARGET_CPU);
  68. }
  69. else if (device == "gpu")
  70. {
  71. cout << "Using GPU device" << endl;
  72. net.setPreferableBackend(DNN_BACKEND_CUDA);
  73. net.setPreferableTarget(DNN_TARGET_CUDA);
  74. }
  75. // Open a video file or an image file or a camera stream.
  76. string str, outputFile;
  77. VideoCapture cap;
  78. VideoWriter video;
  79. Mat frame, blob;
  80. try {
  81. outputFile = "mask_rcnn_out_cpp.avi";
  82. if (parser.has("image"))
  83. {
  84. // Open the image file
  85. str = parser.get<String>("image");
  86. //cout << "Image file input : " << str << endl;
  87. ifstream ifile(str);
  88. if (!ifile) throw("error");
  89. cap.open(str);
  90. str.replace(str.end()-4, str.end(), "_mask_rcnn_out.jpg");
  91. outputFile = str;
  92. }
  93. else if (parser.has("video"))
  94. {
  95. // Open the video file
  96. str = parser.get<String>("video");
  97. ifstream ifile(str);
  98. if (!ifile) throw("error");
  99. cap.open(str);
  100. str.replace(str.end()-4, str.end(), "_mask_rcnn_out.avi");
  101. outputFile = str;
  102. }
  103. // Open the webcam
  104. else cap.open(parser.get<int>("webcam"));
  105. }
  106. catch(...) {
  107. cout << "Could not open the input image/video stream" << endl;
  108. return 0;
  109. }
  110. // Get the video writer initialized to save the output video
  111. if (!parser.has("image")) {
  112. video.open(outputFile, VideoWriter::fourcc('M','J','P','G'), 28, Size(cap.get(CAP_PROP_FRAME_WIDTH), cap.get(CAP_PROP_FRAME_HEIGHT)));
  113. }
  114. // Create a window
  115. static const string kWinName = "Deep learning object detection in OpenCV";
  116. namedWindow(kWinName, WINDOW_NORMAL);
  117. // Process frames.
  118. while (waitKey(1) < 0)
  119. {
  120. // get frame from the video
  121. cap >> frame;
  122. // Stop the program if reached end of video
  123. if (frame.empty()) {
  124. cout << "Done processing !!!" << endl;
  125. cout << "Output file is stored as " << outputFile << endl;
  126. waitKey(3000);
  127. break;
  128. }
  129. // Create a 4D blob from a frame.
  130. blobFromImage(frame, blob, 1.0, Size(frame.cols, frame.rows), Scalar(), true, false);
  131. //blobFromImage(frame, blob);
  132. //Sets the input to the network
  133. net.setInput(blob);
  134. // Runs the forward pass to get output from the output layers
  135. std::vector<String> outNames(2);
  136. outNames[0] = "detection_out_final";
  137. outNames[1] = "detection_masks";
  138. vector<Mat> outs;
  139. net.forward(outs, outNames);
  140. // Extract the bounding box and mask for each of the detected objects
  141. postprocess(frame, outs);
  142. // Put efficiency information. The function getPerfProfile returns the overall time for inference(t) and the timings for each of the layers(in layersTimes)
  143. vector<double> layersTimes;
  144. double freq = getTickFrequency() / 1000;
  145. double t = net.getPerfProfile(layersTimes) / freq;
  146. string label = format("Mask-RCNN on 2.5 GHz Intel Core i7 CPU, Inference time for a frame : %0.0f ms", t);
  147. putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0));
  148. // Write the frame with the detection boxes
  149. Mat detectedFrame;
  150. frame.convertTo(detectedFrame, CV_8U);
  151. if (parser.has("image")) imwrite(outputFile, detectedFrame);
  152. else video.write(detectedFrame);
  153. imshow(kWinName, frame);
  154. }
  155. cap.release();
  156. if (!parser.has("image")) video.release();
  157. return 0;
  158. }
  159. // For each frame, extract the bounding box and mask for each detected object
  160. void postprocess(Mat& frame, const vector<Mat>& outs)
  161. {
  162. Mat outDetections = outs[0];
  163. Mat outMasks = outs[1];
  164. // Output size of masks is NxCxHxW where
  165. // N - number of detected boxes
  166. // C - number of classes (excluding background)
  167. // HxW - segmentation shape
  168. const int numDetections = outDetections.size[2];
  169. const int numClasses = outMasks.size[1];
  170. outDetections = outDetections.reshape(1, outDetections.total() / 7);
  171. for (int i = 0; i < numDetections; ++i)
  172. {
  173. float score = outDetections.at<float>(i, 2);
  174. if (score > confThreshold)
  175. {
  176. // Extract the bounding box
  177. int classId = static_cast<int>(outDetections.at<float>(i, 1));
  178. int left = static_cast<int>(frame.cols * outDetections.at<float>(i, 3));
  179. int top = static_cast<int>(frame.rows * outDetections.at<float>(i, 4));
  180. int right = static_cast<int>(frame.cols * outDetections.at<float>(i, 5));
  181. int bottom = static_cast<int>(frame.rows * outDetections.at<float>(i, 6));
  182. left = max(0, min(left, frame.cols - 1));
  183. top = max(0, min(top, frame.rows - 1));
  184. right = max(0, min(right, frame.cols - 1));
  185. bottom = max(0, min(bottom, frame.rows - 1));
  186. Rect box = Rect(left, top, right - left + 1, bottom - top + 1);
  187. // Extract the mask for the object
  188. Mat objectMask(outMasks.size[2], outMasks.size[3],CV_32F, outMasks.ptr<float>(i,classId));
  189. // Draw bounding box, colorize and show the mask on the image
  190. drawBox(frame, classId, score, box, objectMask);
  191. }
  192. }
  193. }
  194. // Draw the predicted bounding box, colorize and show the mask on the image
  195. void drawBox(Mat& frame, int classId, float conf, Rect box, Mat& objectMask)
  196. {
  197. //Draw a rectangle displaying the bounding box
  198. rectangle(frame, Point(box.x, box.y), Point(box.x+box.width, box.y+box.height), Scalar(255, 178, 50), 3);
  199. //Get the label for the class name and its confidence
  200. string label = format("%.2f", conf);
  201. if (!classes.empty())
  202. {
  203. CV_Assert(classId < (int)classes.size());
  204. label = classes[classId] + ":" + label;
  205. }
  206. //Display the label at the top of the bounding box
  207. int baseLine;
  208. Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  209. box.y = max(box.y, labelSize.height);
  210. rectangle(frame, Point(box.x, box.y - round(1.5*labelSize.height)), Point(box.x + round(1.5*labelSize.width), box.y + baseLine), Scalar(255, 255, 255), FILLED);
  211. putText(frame, label, Point(box.x, box.y), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0,0,0),1);
  212. Scalar color = colors[classId%colors.size()];
  213. // Resize the mask, threshold, color and apply it on the image
  214. resize(objectMask, objectMask, Size(box.width, box.height));
  215. Mat mask = (objectMask > maskThreshold);
  216. Mat coloredRoi = (0.3 * color + 0.7 * frame(box));
  217. coloredRoi.convertTo(coloredRoi, CV_8UC3);
  218. // Draw the contours on the image
  219. vector<Mat> contours;
  220. Mat hierarchy;
  221. mask.convertTo(mask, CV_8U);
  222. findContours(mask, contours, hierarchy, RETR_CCOMP, CHAIN_APPROX_SIMPLE);
  223. drawContours(coloredRoi, contours, -1, color, 5, LINE_8, hierarchy, 100);
  224. coloredRoi.copyTo(frame(box), mask);
  225. }