colorizeImage.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # This code is written by Sunita Nayak at BigVision LLC. It is based on the OpenCV project.
  2. # It is subject to the license terms in the LICENSE file found in this distribution and at http://opencv.org/license.html
  3. # Usage example: python3 colorizeImage.py --input greyscaleImage.png
import argparse
import os.path
import sys
import time

import cv2 as cv
import numpy as np
  9. parser = argparse.ArgumentParser(description='Colorize GreyScale Image')
  10. parser.add_argument('--input', help='Path to image.')
  11. parser.add_argument("--device", default="cpu", help="Device to inference on")
  12. args = parser.parse_args()
  13. if args.input is None:
  14. print('Please give the input greyscale image name.')
  15. print('Usage example: python3 colorizeImage.py --input greyscaleImage.png')
  16. exit()
  17. if not os.path.isfile(args.input):
  18. print('Input file does not exist')
  19. exit()
  20. print("Input image file: ", args.input)
  21. # Read the input image
  22. frame = cv.imread(args.input)
  23. # Specify the paths for the 2 model files
  24. protoFile = "./models/colorization_deploy_v2.prototxt"
  25. weightsFile = "./models/colorization_release_v2.caffemodel"
  26. # Load the cluster centers
  27. pts_in_hull = np.load('./pts_in_hull.npy')
  28. # Read the network into Memory
  29. net = cv.dnn.readNetFromCaffe(protoFile, weightsFile)
  30. if args.device == "cpu":
  31. net.setPreferableBackend(cv.dnn.DNN_TARGET_CPU)
  32. print("Using CPU device")
  33. elif args.device == "gpu":
  34. net.setPreferableBackend(cv.dnn.DNN_BACKEND_CUDA)
  35. net.setPreferableTarget(cv.dnn.DNN_TARGET_CUDA)
  36. print("Using GPU device")
  37. # populate cluster centers as 1x1 convolution kernel
  38. pts_in_hull = pts_in_hull.transpose().reshape(2, 313, 1, 1)
  39. net.getLayer(net.getLayerId('class8_ab')).blobs = [pts_in_hull.astype(np.float32)]
  40. net.getLayer(net.getLayerId('conv8_313_rh')).blobs = [np.full([1, 313], 2.606, np.float32)]
  41. # from opencv sample
  42. W_in = 224
  43. H_in = 224
  44. start = time.time()
  45. img_rgb = (frame[:, :, [2, 1, 0]] * 1.0 / 255).astype(np.float32)
  46. img_lab = cv.cvtColor(img_rgb, cv.COLOR_RGB2Lab)
  47. img_l = img_lab[:, :, 0] # pull out L channel
  48. # resize lightness channel to network input size
  49. img_l_rs = cv.resize(img_l, (W_in, H_in))
  50. img_l_rs -= 50 # subtract 50 for mean-centering
  51. net.setInput(cv.dnn.blobFromImage(img_l_rs))
  52. ab_dec = net.forward()[0, :, :, :].transpose((1, 2, 0)) # this is our result
  53. (H_orig, W_orig) = img_rgb.shape[:2] # original image size
  54. ab_dec_us = cv.resize(ab_dec, (W_orig, H_orig))
  55. img_lab_out = np.concatenate((img_l[:, :, np.newaxis],ab_dec_us), axis=2) # concatenate with original image L
  56. img_bgr_out = np.clip(cv.cvtColor(img_lab_out, cv.COLOR_Lab2BGR), 0, 1)
  57. end = time.time()
  58. print("Time taken : {:0.5f} secs".format(end - start))
  59. outputFile = args.input[:-4] + '_colorized.png'
  60. cv.imwrite(outputFile, (img_bgr_out * 255).astype(np.uint8))
  61. print('Colorized image saved as ' + outputFile)
  62. print('Done !!!')