colorizeImage.cpp 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. // This code is written by Sunita Nayak at BigVision LLC. It is based on the OpenCV project. It is subject to the license terms in the LICENSE file found in this distribution and at http://opencv.org/license.html
  2. // Usage example: ./colorizeImage.out greyscaleImage.png
  3. #include <opencv2/dnn.hpp>
  4. #include <opencv2/imgproc.hpp>
  5. #include <opencv2/highgui.hpp>
  6. #include <iostream>
  7. using namespace cv;
  8. using namespace cv::dnn;
  9. using namespace std;
  10. // the 313 ab cluster centers from pts_in_hull.npy (already transposed)
  11. static float hull_pts[] = {
  12. -90., -90., -90., -90., -90., -80., -80., -80., -80., -80., -80., -80., -80., -70., -70., -70., -70., -70., -70., -70., -70.,
  13. -70., -70., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -50., -50., -50., -50., -50., -50., -50., -50.,
  14. -50., -50., -50., -50., -50., -50., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -30.,
  15. -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -20., -20., -20., -20., -20., -20., -20.,
  16. -20., -20., -20., -20., -20., -20., -20., -20., -20., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
  17. -10., -10., -10., -10., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 10., 10., 10., 10., 10., 10., 10.,
  18. 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20.,
  19. 20., 20., 20., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 40., 40., 40., 40.,
  20. 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.,
  21. 50., 50., 50., 50., 50., 50., 50., 50., 50., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60.,
  22. 60., 60., 60., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 80., 80., 80.,
  23. 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 90., 90., 90., 90., 90., 90., 90., 90., 90., 90.,
  24. 90., 90., 90., 90., 90., 90., 90., 90., 90., 100., 100., 100., 100., 100., 100., 100., 100., 100., 100., 50., 60., 70., 80., 90.,
  25. 20., 30., 40., 50., 60., 70., 80., 90., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -20., -10., 0., 10., 20., 30., 40., 50.,
  26. 60., 70., 80., 90., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -40., -30., -20., -10., 0., 10., 20.,
  27. 30., 40., 50., 60., 70., 80., 90., 100., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -50.,
  28. -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -60., -50., -40., -30., -20., -10., 0., 10., 20.,
  29. 30., 40., 50., 60., 70., 80., 90., 100., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90.,
  30. 100., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -80., -70., -60., -50.,
  31. -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -90., -80., -70., -60., -50., -40., -30., -20., -10.,
  32. 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30.,
  33. 40., 50., 60., 70., 80., 90., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70.,
  34. 80., -110., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., -110., -100.,
  35. -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., -110., -100., -90., -80., -70.,
  36. -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., -110., -100., -90., -80., -70., -60., -50., -40., -30.,
  37. -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0.
  38. };
  39. int main(int argc, char **argv)
  40. {
  41. string imageFileName;
  42. string device;
  43. // Take arguments from command line
  44. if (argc == 3)
  45. {
  46. device = argv[2];
  47. }
  48. else if (argc == 2)
  49. device = "cpu";
  50. else
  51. {
  52. cout << "Please input the greyscale image filename." << endl;
  53. cout << "Usage example: ./colorizeImage.out greyscaleImage.png" << endl;
  54. cout << "If you want to use GPU device instead of CPU, add one more argument." << endl;
  55. cout << "Usage example:./colorizeImage.out greyscaleImage.png gpu" << endl;
  56. return 1;
  57. }
  58. imageFileName = argv[1];
  59. Mat img = imread(imageFileName);
  60. if (img.empty())
  61. {
  62. cout << "Can't read image from file: " << imageFileName << endl;
  63. return 1;
  64. }
  65. cout << "Input image file: " << imageFileName << endl;
  66. string protoFile = "./models/colorization_deploy_v2.prototxt";
  67. string weightsFile = "./models/colorization_release_v2.caffemodel";
  68. // fixed input size for the pre-trained network
  69. const int W_in = 224;
  70. const int H_in = 224;
  71. Net net = dnn::readNetFromCaffe(protoFile, weightsFile);
  72. if (device != "gpu")
  73. {
  74. cout << "Using CPU device" << endl;
  75. net.setPreferableBackend(DNN_TARGET_CPU);
  76. }
  77. else
  78. {
  79. cout << "Using GPU device" << endl;
  80. net.setPreferableBackend(DNN_BACKEND_CUDA);
  81. net.setPreferableTarget(DNN_TARGET_CUDA);
  82. }
  83. // setup additional layers:
  84. int sz[] = {2, 313, 1, 1};
  85. const Mat pts_in_hull(4, sz, CV_32F, hull_pts);
  86. Ptr<dnn::Layer> class8_ab = net.getLayer("class8_ab");
  87. class8_ab->blobs.push_back(pts_in_hull);
  88. Ptr<dnn::Layer> conv8_313_rh = net.getLayer("conv8_313_rh");
  89. conv8_313_rh->blobs.push_back(Mat(1, 313, CV_32F, Scalar(2.606)));
  90. double t = (double) cv::getTickCount();
  91. // extract L channel and subtract mean
  92. Mat lab, L, input;
  93. img.convertTo(img, CV_32F, 1.0/255);
  94. cvtColor(img, lab, COLOR_BGR2Lab);
  95. extractChannel(lab, L, 0);
  96. resize(L, input, Size(W_in, H_in));
  97. input -= 50;
  98. // run the L channel through the network
  99. Mat inputBlob = blobFromImage(input);
  100. net.setInput(inputBlob);
  101. Mat result = net.forward();
  102. // retrieve the calculated a,b channels from the network output
  103. Size out_size(result.size[2], result.size[3]);
  104. Mat a = Mat(out_size, CV_32F, result.ptr(0, 0));
  105. Mat b = Mat(out_size, CV_32F, result.ptr(0, 1));
  106. resize(a, a, img.size());
  107. resize(b, b, img.size());
  108. // merge, and convert back to BGR
  109. Mat color, chn[] = {L, a, b};
  110. merge(chn, 3, lab);
  111. cvtColor(lab, color, COLOR_Lab2BGR);
  112. t = ((double)cv::getTickCount() - t)/cv::getTickFrequency();
  113. cout << "Time taken : " << t << " secs" << endl;
  114. string str = imageFileName;
  115. str.replace(str.end() - 4, str.end(), "");
  116. str = str + "_colorized.png";
  117. color = color.mul(255);
  118. color.convertTo(color, CV_8U);
  119. imwrite(str, color);
  120. cout << "Colorized image saved as " << str << endl;
  121. return 0;
  122. }