msssim.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. #!/usr/bin/python
  2. #
  3. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ==============================================================================
"""Python implementation of MS-SSIM (Multi-Scale Structural Similarity).

Usage:
  python msssim.py --original_image=original.png --compared_image=distorted.png
"""
  21. import numpy as np
  22. from scipy import signal
  23. from scipy.ndimage.filters import convolve
  24. import tensorflow as tf
  25. tf.flags.DEFINE_string('original_image', None, 'Path to PNG image.')
  26. tf.flags.DEFINE_string('compared_image', None, 'Path to PNG image.')
  27. FLAGS = tf.flags.FLAGS
  28. def _FSpecialGauss(size, sigma):
  29. """Function to mimic the 'fspecial' gaussian MATLAB function."""
  30. radius = size // 2
  31. offset = 0.0
  32. start, stop = -radius, radius + 1
  33. if size % 2 == 0:
  34. offset = 0.5
  35. stop -= 1
  36. x, y = np.mgrid[offset + start:stop, offset + start:stop]
  37. assert len(x) == size
  38. g = np.exp(-((x**2 + y**2)/(2.0 * sigma**2)))
  39. return g / g.sum()
  40. def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11,
  41. filter_sigma=1.5, k1=0.01, k2=0.03):
  42. """Return the Structural Similarity Map between `img1` and `img2`.
  43. This function attempts to match the functionality of ssim_index_new.m by
  44. Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
  45. Arguments:
  46. img1: Numpy array holding the first RGB image batch.
  47. img2: Numpy array holding the second RGB image batch.
  48. max_val: the dynamic range of the images (i.e., the difference between the
  49. maximum the and minimum allowed values).
  50. filter_size: Size of blur kernel to use (will be reduced for small images).
  51. filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
  52. for small images).
  53. k1: Constant used to maintain stability in the SSIM calculation (0.01 in
  54. the original paper).
  55. k2: Constant used to maintain stability in the SSIM calculation (0.03 in
  56. the original paper).
  57. Returns:
  58. Pair containing the mean SSIM and contrast sensitivity between `img1` and
  59. `img2`.
  60. Raises:
  61. RuntimeError: If input images don't have the same shape or don't have four
  62. dimensions: [batch_size, height, width, depth].
  63. """
  64. if img1.shape != img2.shape:
  65. raise RuntimeError('Input images must have the same shape (%s vs. %s).',
  66. img1.shape, img2.shape)
  67. if img1.ndim != 4:
  68. raise RuntimeError('Input images must have four dimensions, not %d',
  69. img1.ndim)
  70. img1 = img1.astype(np.float64)
  71. img2 = img2.astype(np.float64)
  72. _, height, width, _ = img1.shape
  73. # Filter size can't be larger than height or width of images.
  74. size = min(filter_size, height, width)
  75. # Scale down sigma if a smaller filter size is used.
  76. sigma = size * filter_sigma / filter_size if filter_size else 0
  77. if filter_size:
  78. window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1))
  79. mu1 = signal.fftconvolve(img1, window, mode='valid')
  80. mu2 = signal.fftconvolve(img2, window, mode='valid')
  81. sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
  82. sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
  83. sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')
  84. else:
  85. # Empty blur kernel so no need to convolve.
  86. mu1, mu2 = img1, img2
  87. sigma11 = img1 * img1
  88. sigma22 = img2 * img2
  89. sigma12 = img1 * img2
  90. mu11 = mu1 * mu1
  91. mu22 = mu2 * mu2
  92. mu12 = mu1 * mu2
  93. sigma11 -= mu11
  94. sigma22 -= mu22
  95. sigma12 -= mu12
  96. # Calculate intermediate values used by both ssim and cs_map.
  97. c1 = (k1 * max_val) ** 2
  98. c2 = (k2 * max_val) ** 2
  99. v1 = 2.0 * sigma12 + c2
  100. v2 = sigma11 + sigma22 + c2
  101. ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)))
  102. cs = np.mean(v1 / v2)
  103. return ssim, cs
  104. def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5,
  105. k1=0.01, k2=0.03, weights=None):
  106. """Return the MS-SSIM score between `img1` and `img2`.
  107. This function implements Multi-Scale Structural Similarity (MS-SSIM) Image
  108. Quality Assessment according to Zhou Wang's paper, "Multi-scale structural
  109. similarity for image quality assessment" (2003).
  110. Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
  111. Author's MATLAB implementation:
  112. http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
  113. Arguments:
  114. img1: Numpy array holding the first RGB image batch.
  115. img2: Numpy array holding the second RGB image batch.
  116. max_val: the dynamic range of the images (i.e., the difference between the
  117. maximum the and minimum allowed values).
  118. filter_size: Size of blur kernel to use (will be reduced for small images).
  119. filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
  120. for small images).
  121. k1: Constant used to maintain stability in the SSIM calculation (0.01 in
  122. the original paper).
  123. k2: Constant used to maintain stability in the SSIM calculation (0.03 in
  124. the original paper).
  125. weights: List of weights for each level; if none, use five levels and the
  126. weights from the original paper.
  127. Returns:
  128. MS-SSIM score between `img1` and `img2`.
  129. Raises:
  130. RuntimeError: If input images don't have the same shape or don't have four
  131. dimensions: [batch_size, height, width, depth].
  132. """
  133. if img1.shape != img2.shape:
  134. raise RuntimeError('Input images must have the same shape (%s vs. %s).',
  135. img1.shape, img2.shape)
  136. if img1.ndim != 4:
  137. raise RuntimeError('Input images must have four dimensions, not %d',
  138. img1.ndim)
  139. # Note: default weights don't sum to 1.0 but do match the paper / matlab code.
  140. weights = np.array(weights if weights else
  141. [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
  142. levels = weights.size
  143. downsample_filter = np.ones((1, 2, 2, 1)) / 4.0
  144. im1, im2 = [x.astype(np.float64) for x in [img1, img2]]
  145. mssim = np.array([])
  146. mcs = np.array([])
  147. for _ in xrange(levels):
  148. ssim, cs = _SSIMForMultiScale(
  149. im1, im2, max_val=max_val, filter_size=filter_size,
  150. filter_sigma=filter_sigma, k1=k1, k2=k2)
  151. mssim = np.append(mssim, ssim)
  152. mcs = np.append(mcs, cs)
  153. filtered = [convolve(im, downsample_filter, mode='reflect')
  154. for im in [im1, im2]]
  155. im1, im2 = [x[:, ::2, ::2, :] for x in filtered]
  156. return (np.prod(mcs[0:levels-1] ** weights[0:levels-1]) *
  157. (mssim[levels-1] ** weights[levels-1]))
  158. def main(_):
  159. if FLAGS.original_image is None or FLAGS.compared_image is None:
  160. print ('\nUsage: python msssim.py --original_image=original.png '
  161. '--compared_image=distorted.png\n\n')
  162. return
  163. if not tf.gfile.Exists(FLAGS.original_image):
  164. print '\nCannot find --original_image.\n'
  165. return
  166. if not tf.gfile.Exists(FLAGS.compared_image):
  167. print '\nCannot find --compared_image.\n'
  168. return
  169. with tf.gfile.FastGFile(FLAGS.original_image) as image_file:
  170. img1_str = image_file.read()
  171. with tf.gfile.FastGFile(FLAGS.compared_image) as image_file:
  172. img2_str = image_file.read()
  173. input_img = tf.placeholder(tf.string)
  174. decoded_image = tf.expand_dims(tf.image.decode_png(input_img, channels=3), 0)
  175. with tf.Session() as sess:
  176. img1 = sess.run(decoded_image, feed_dict={input_img: img1_str})
  177. img2 = sess.run(decoded_image, feed_dict={input_img: img2_str})
  178. print MultiScaleSSIM(img1, img2, max_val=255)
  179. if __name__ == '__main__':
  180. tf.app.run()