// AgeGender.cpp (6.1 KB)
//
// Detects faces in camera or video frames with the OpenCV DNN face detector and
// estimates the gender and age bracket of each face with two classification
// networks, drawing the results on the frame.
  #include <algorithm>
  #include <cctype>
  #include <iostream>
  #include <iterator>
  #include <string>
  #include <tuple>
  #include <vector>

  #include <opencv2/dnn.hpp>
  #include <opencv2/highgui.hpp>
  #include <opencv2/imgproc.hpp>
  #include <opencv2/opencv.hpp>

  using namespace cv;
  using namespace cv::dnn;
  using namespace std;
  11. tuple<Mat, vector<vector<int>>> getFaceBox(Net net, Mat &frame, double conf_threshold)
  12. {
  13. Mat frameOpenCVDNN = frame.clone();
  14. int frameHeight = frameOpenCVDNN.rows;
  15. int frameWidth = frameOpenCVDNN.cols;
  16. double inScaleFactor = 1.0;
  17. Size size = Size(300, 300);
  18. // std::vector<int> meanVal = {104, 117, 123};
  19. Scalar meanVal = Scalar(104, 117, 123);
  20. cv::Mat inputBlob;
  21. inputBlob = cv::dnn::blobFromImage(frameOpenCVDNN, inScaleFactor, size, meanVal, true, false);
  22. net.setInput(inputBlob, "data");
  23. cv::Mat detection = net.forward("detection_out");
  24. cv::Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
  25. vector<vector<int>> bboxes;
  26. for(int i = 0; i < detectionMat.rows; i++)
  27. {
  28. float confidence = detectionMat.at<float>(i, 2);
  29. if(confidence > conf_threshold)
  30. {
  31. int x1 = static_cast<int>(detectionMat.at<float>(i, 3) * frameWidth);
  32. int y1 = static_cast<int>(detectionMat.at<float>(i, 4) * frameHeight);
  33. int x2 = static_cast<int>(detectionMat.at<float>(i, 5) * frameWidth);
  34. int y2 = static_cast<int>(detectionMat.at<float>(i, 6) * frameHeight);
  35. vector<int> box = {x1, y1, x2, y2};
  36. bboxes.push_back(box);
  37. cv::rectangle(frameOpenCVDNN, cv::Point(x1, y1), cv::Point(x2, y2), cv::Scalar(0, 255, 0),2, 4);
  38. }
  39. }
  40. return make_tuple(frameOpenCVDNN, bboxes);
  41. }
  42. int main(int argc, char** argv)
  43. {
  44. string faceProto = "opencv_face_detector.pbtxt";
  45. string faceModel = "opencv_face_detector_uint8.pb";
  46. string ageProto = "age_deploy.prototxt";
  47. string ageModel = "age_net.caffemodel";
  48. string genderProto = "gender_deploy.prototxt";
  49. string genderModel = "gender_net.caffemodel";
  50. Scalar MODEL_MEAN_VALUES = Scalar(78.4263377603, 87.7689143744, 114.895847746);
  51. vector<string> ageList = {"(0-2)", "(4-6)", "(8-12)", "(15-20)", "(25-32)",
  52. "(38-43)", "(48-53)", "(60-100)"};
  53. vector<string> genderList = {"Male", "Female"};
  54. cout << "USAGE : ./AgeGender <videoFile> " << endl;
  55. cout << "USAGE : ./AgeGender <device> " << endl;
  56. cout << "USAGE : ./AgeGender <videoFile> <device>" << endl;
  57. string device = "cpu";
  58. string videoFile = "0";
  59. // Take arguments from commmand line
  60. if (argc == 2)
  61. {
  62. if((string)argv[1] == "gpu")
  63. device = "gpu";
  64. else if((string)argv[1] == "cpu")
  65. device = "cpu";
  66. else
  67. videoFile = argv[1];
  68. }
  69. else if (argc == 3)
  70. {
  71. videoFile = argv[1];
  72. if((string)argv[2] == "gpu")
  73. device = "gpu";
  74. }
  75. // Load Network
  76. Net ageNet = readNet(ageModel, ageProto);
  77. Net genderNet = readNet(genderModel, genderProto);
  78. Net faceNet = readNet(faceModel, faceProto);
  79. if (device == "cpu")
  80. {
  81. cout << "Using CPU device" << endl;
  82. ageNet.setPreferableBackend(DNN_TARGET_CPU);
  83. genderNet.setPreferableBackend(DNN_TARGET_CPU);
  84. faceNet.setPreferableBackend(DNN_TARGET_CPU);
  85. }
  86. else if (device == "gpu")
  87. {
  88. cout << "Using GPU device" << endl;
  89. ageNet.setPreferableBackend(DNN_BACKEND_CUDA);
  90. ageNet.setPreferableTarget(DNN_TARGET_CUDA);
  91. genderNet.setPreferableBackend(DNN_BACKEND_CUDA);
  92. genderNet.setPreferableTarget(DNN_TARGET_CUDA);
  93. faceNet.setPreferableBackend(DNN_BACKEND_CUDA);
  94. faceNet.setPreferableTarget(DNN_TARGET_CUDA);
  95. }
  96. VideoCapture cap;
  97. if (videoFile.length() > 1)
  98. cap.open(videoFile);
  99. else
  100. cap.open(0);
  101. int padding = 20;
  102. while(waitKey(1) < 0) {
  103. // read frame
  104. Mat frame;
  105. cap.read(frame);
  106. if (frame.empty())
  107. {
  108. waitKey();
  109. break;
  110. }
  111. vector<vector<int>> bboxes;
  112. Mat frameFace;
  113. tie(frameFace, bboxes) = getFaceBox(faceNet, frame, 0.7);
  114. if(bboxes.size() == 0) {
  115. cout << "No face detected, checking next frame." << endl;
  116. continue;
  117. }
  118. for (auto it = begin(bboxes); it != end(bboxes); ++it) {
  119. Rect rec(it->at(0) - padding, it->at(1) - padding, it->at(2) - it->at(0) + 2*padding, it->at(3) - it->at(1) + 2*padding);
  120. Mat face = frame(rec); // take the ROI of box on the frame
  121. Mat blob;
  122. blob = blobFromImage(face, 1, Size(227, 227), MODEL_MEAN_VALUES, false);
  123. genderNet.setInput(blob);
  124. // string gender_preds;
  125. vector<float> genderPreds = genderNet.forward();
  126. // printing gender here
  127. // find max element index
  128. // distance function does the argmax() work in C++
  129. int max_index_gender = std::distance(genderPreds.begin(), max_element(genderPreds.begin(), genderPreds.end()));
  130. string gender = genderList[max_index_gender];
  131. cout << "Gender: " << gender << endl;
  132. /* // Uncomment if you want to iterate through the gender_preds vector
  133. for(auto it=begin(gender_preds); it != end(gender_preds); ++it) {
  134. cout << *it << endl;
  135. }
  136. */
  137. ageNet.setInput(blob);
  138. vector<float> agePreds = ageNet.forward();
  139. /* // uncomment below code if you want to iterate through the age_preds
  140. * vector
  141. cout << "PRINTING AGE_PREDS" << endl;
  142. for(auto it = age_preds.begin(); it != age_preds.end(); ++it) {
  143. cout << *it << endl;
  144. }
  145. */
  146. // finding maximum indicd in the age_preds vector
  147. int max_indice_age = std::distance(agePreds.begin(), max_element(agePreds.begin(), agePreds.end()));
  148. string age = ageList[max_indice_age];
  149. cout << "Age: " << age << endl;
  150. string label = gender + ", " + age; // label
  151. cv::putText(frameFace, label, Point(it->at(0), it->at(1) -15), cv::FONT_HERSHEY_SIMPLEX, 0.9, Scalar(0, 255, 255), 2, cv::LINE_AA);
  152. imshow("Frame", frameFace);
  153. imwrite("out.jpg",frameFace);
  154. }
  155. }
  156. }