فهرست منبع

Initial commit of compression models. (#402)

nickj-google 9 سال پیش
والد
کامیت
76739168f6
5فایلهای تغییر یافته به همراه540 افزوده شده و 0 حذف شده
  1. 96 0
      compression/README.md
  2. 124 0
      compression/decoder.py
  3. 103 0
      compression/encoder.py
  4. BIN
      compression/example.png
  5. 217 0
      compression/msssim.py

+ 96 - 0
compression/README.md

@@ -0,0 +1,96 @@
+# Image Compression with Neural Networks
+
+This is a [TensorFlow](http://www.tensorflow.org/) model for compressing and
+decompressing images using an already trained  Residual GRU model as descibed
+in [Full Resolution Image Compression with Recurrent Neural Networks]
+(https://arxiv.org/abs/1608.05148). Please consult the paper for more details
+on the architecture and compression results.
+
+This code will allow you to perform the lossy compression on an model
+already trained on compression. This code doesn't not currently contain the
+Entropy Coding portions of our paper.
+
+
+## Prerequisites
+The only software requirements for running the encoder and decoder is having
+Tensorflow installed. You will also need to [download]
+(http://download.tensorflow.org/models/compression_residual_gru-2016-08-23.tar.gz)
+and extract the model residual_gru.pb.
+
+If you want to generate the perceptual similarity under MS-SSIM, you will also
+need to [Install SciPy](https://www.scipy.org/install.html).
+
+## Encoding
+The Residual GRU network is fully convolutional, but requires the images
+height and width in pixels by a multiple of 32. There is an image in this folder
+called example.png that is 768x1024 if one is needed for testing. We also
+rely on TensorFlow's built in decoding ops, which support only PNG and JPEG at
+time of release.
+
+To encode an image, simply run the following command:
+
+`python encoder.py --input_image=/your/image/here.png
+--output_codes=output_codes.npz --iteration=15
+--model=/path/to/model/residual_gru.pb
+`
+
+The iteration parameter specifies the lossy-quality to target for compression.
+The quality can be [0-15], where 0 corresponds to a target of 1/8 (bits per
+pixel) bpp and every increment results in an additional 1/8 bpp.
+
+| Iteration | BPP | Compression Ratio |
+|---: |---: |---: |
+|0 | 0.125 | 192:1|
+|1 | 0.250 | 96:1|
+|2 | 0.375 | 64:1|
+|3 | 0.500 | 48:1|
+|4 | 0.625 | 38.4:1|
+|5 | 0.750 | 32:1|
+|6 | 0.875 | 27.4:1|
+|7 | 1.000 | 24:1|
+|8 | 1.125 | 21.3:1|
+|9 | 1.250 | 19.2:1|
+|10 | 1.375 | 17.4:1|
+|11 | 1.500 | 16:1|
+|12 | 1.625 | 14.7:1|
+|13 | 1.750 | 13.7:1|
+|14 | 1.875 | 12.8:1|
+|15 | 2.000 | 12:1|
+
+The output_codes file contains the numpy shape and a flattened, bit-packed
+array of the codes. These can be inspected in python by using numpy.load().
+
+
+## Decoding
+After generating codes for an image, the lossy reconstructions for that image
+can be done as follows:
+
+`python decoder.py --input_codes=codes.npz --output_directory=/tmp/decoded/
+--model=residual_gru.pb`
+
+The output_directory will contain images decoded at each quality level.
+
+
+## Comparing Similarity
+One of our primary metrics for comparing how similar two images are
+is MS-SSIM.
+
+To generate these metrics on your images you can run:
+`python msssim.py --original_image=/path/to/your/image.png
+--compared_image=/tmp/decoded/image_15.png`
+
+
+## FAQ
+
+#### How do I train my own compression network?
+We currently don't provide the code to build and train a compression
+graph from scratch.
+
+#### I get an InvalidArgumentError: Incompatible shapes.
+This is usually due to the fact that our network only supports images that are
+both height and width divisible by 32 pixel. Try padding your images to 32
+pixel boundaries.
+
+
+## Contact Info
+Model repository maintained by Nick Johnston ([nickj-google](https://github.com/nickj-google)).

+ 124 - 0
compression/decoder.py

@@ -0,0 +1,124 @@
+#!/usr/bin/python
+#
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Neural Network Image Compression Decoder.
+
+Decompress an image from the numpy's npz format generated by the encoder.
+
+Example usage:
+python decoder.py --input_codes=output_codes.pkl --iteration=15 \
+--output_directory=/tmp/compression_output/ --model=residual_gru.pb
+"""
+import os
+
+import numpy as np
+import tensorflow as tf
+
+tf.flags.DEFINE_string('input_codes', None, 'Location of binary code file.')
+tf.flags.DEFINE_integer('iteration', -1, 'The max quality level of '
+                        'the images to output. Use -1 to infer from loaded '
+                        ' codes.')
+tf.flags.DEFINE_string('output_directory', None, 'Directory to save decoded '
+                       'images.')
+tf.flags.DEFINE_string('model', None, 'Location of compression model.')
+
+FLAGS = tf.flags.FLAGS
+
+
+def get_input_tensor_names():
+  name_list = ['GruBinarizer/SignBinarizer/Sign:0']
+  for i in xrange(1, 16):
+    name_list.append('GruBinarizer/SignBinarizer/Sign_{}:0'.format(i))
+  return name_list
+
+
+def get_output_tensor_names():
+  return ['loop_{0:02d}/add:0'.format(i) for i in xrange(0, 16)]
+
+
+def main(_):
+  if (FLAGS.input_codes is None or FLAGS.output_directory is None or
+      FLAGS.model is None):
+    print ('\nUsage: python decoder.py --input_codes=output_codes.pkl '
+           '--iteration=15 --output_directory=/tmp/compression_output/ '
+           '--model=residual_gru.pb\n\n')
+    return
+
+  if FLAGS.iteration < -1 or FLAGS.iteration > 15:
+    print ('\n--iteration must be between 0 and 15 inclusive, or -1 to infer '
+           'from file.\n')
+    return
+  iteration = FLAGS.iteration
+
+  if not tf.gfile.Exists(FLAGS.output_directory):
+    tf.gfile.MkDir(FLAGS.output_directory)
+
+  if not tf.gfile.Exists(FLAGS.input_codes):
+    print '\nInput codes not found.\n'
+    return
+
+  with tf.gfile.FastGFile(FLAGS.input_codes, 'rb') as code_file:
+    loaded_codes = np.load(code_file)
+    assert ['codes', 'shape'] not in loaded_codes.files
+    loaded_shape = loaded_codes['shape']
+    loaded_array = loaded_codes['codes']
+
+    # Unpack and recover code shapes.
+    unpacked_codes = np.reshape(np.unpackbits(loaded_array)
+                                [:np.prod(loaded_shape)],
+                                loaded_shape)
+
+    numpy_int_codes = np.split(unpacked_codes, len(unpacked_codes))
+    if iteration == -1:
+      iteration = len(unpacked_codes) - 1
+    # Convert back to float and recover scale.
+    numpy_codes = [np.squeeze(x.astype(np.float32), 0) * 2 - 1 for x in
+                   numpy_int_codes]
+
+  with tf.Graph().as_default() as graph:
+    # Load the inference model for decoding.
+    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
+      graph_def = tf.GraphDef()
+      graph_def.ParseFromString(model_file.read())
+    _ = tf.import_graph_def(graph_def, name='')
+
+    # For encoding the tensors into PNGs.
+    input_image = tf.placeholder(tf.uint8)
+    encoded_image = tf.image.encode_png(input_image)
+
+    input_tensors = [graph.get_tensor_by_name(name) for name in
+                     get_input_tensor_names()][0:iteration+1]
+    outputs = [graph.get_tensor_by_name(name) for name in
+               get_output_tensor_names()][0:iteration+1]
+
+  feed_dict = {key: value for (key, value) in zip(input_tensors,
+                                                  numpy_codes)}
+
+  with tf.Session(graph=graph) as sess:
+    results = sess.run(outputs, feed_dict=feed_dict)
+
+    for index, result in enumerate(results):
+      img = np.uint8(np.clip(result + 0.5, 0, 255))
+      img = img.squeeze()
+      png_img = sess.run(encoded_image, feed_dict={input_image: img})
+
+      with tf.gfile.FastGFile(os.path.join(FLAGS.output_directory,
+                                           'image_{0:02d}.png'.format(index)),
+                              'w') as output_image:
+        output_image.write(png_img)
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 103 - 0
compression/encoder.py

@@ -0,0 +1,103 @@
+#!/usr/bin/python
+#
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Neural Network Image Compression Encoder.
+
+Compresses an image to a binarized numpy array. The image must be padded to a
+multiple of 32 pixels in height and width.
+
+Example usage:
+python encoder.py --input_image=/your/image/here.png \
+--output_codes=output_codes.pkl --iteration=15 --model=residual_gru.pb
+"""
+import os
+
+import numpy as np
+import tensorflow as tf
+
+tf.flags.DEFINE_string('input_image', None, 'Location of input image. We rely '
+                       'on tf.image to decode the image, so only PNG and JPEG '
+                       'formats are currently supported.')
+tf.flags.DEFINE_integer('iteration', 15, 'Quality level for encoding image. '
+                        'Must be between 0 and 15 inclusive.')
+tf.flags.DEFINE_string('output_codes', None, 'Directory to save decoded '
+                       'images.')
+tf.flags.DEFINE_string('model', None, 'Location of compression model.')
+
+FLAGS = tf.flags.FLAGS
+
+
+def get_output_tensor_names():
+  name_list = ['GruBinarizer/SignBinarizer/Sign:0']
+  for i in xrange(1, 16):
+    name_list.append('GruBinarizer/SignBinarizer/Sign_{}:0'.format(i))
+  return name_list
+
+
+def main(_):
+  if (FLAGS.input_image is None or FLAGS.output_codes is None or
+      FLAGS.model is None):
+    print ('\nUsage: python encoder.py --input_image=/your/image/here.png '
+           '--output_codes=output_codes.pkl --iteration=15 '
+           '--model=residual_gru.pb\n\n')
+    return
+
+  if FLAGS.iteration < 0 or FLAGS.iteration > 15:
+    print '\n--iteration must be between 0 and 15 inclusive.\n'
+    return
+
+  with tf.gfile.FastGFile(FLAGS.input_image) as input_image:
+    input_image_str = input_image.read()
+
+  with tf.Graph().as_default() as graph:
+    # Load the inference model for encoding.
+    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
+      graph_def = tf.GraphDef()
+      graph_def.ParseFromString(model_file.read())
+    _ = tf.import_graph_def(graph_def, name='')
+
+    input_tensor = graph.get_tensor_by_name('Placeholder:0')
+    outputs = [graph.get_tensor_by_name(name) for name in
+               get_output_tensor_names()]
+
+    input_image = tf.placeholder(tf.string)
+    _, ext = os.path.splitext(FLAGS.input_image)
+    if ext == '.png':
+      decoded_image = tf.image.decode_png(input_image, channels=3)
+    elif ext == '.jpeg' or ext == '.jpg':
+      decoded_image = tf.image.decode_jpeg(input_image, channels=3)
+    else:
+      assert False, 'Unsupported file format {}'.format(ext)
+    decoded_image = tf.expand_dims(decoded_image, 0)
+
+  with tf.Session(graph=graph) as sess:
+    img_array = sess.run(decoded_image, feed_dict={input_image:
+                                                   input_image_str})
+    results = sess.run(outputs, feed_dict={input_tensor: img_array})
+
+  results = results[0:FLAGS.iteration + 1]
+  int_codes = np.asarray([x.astype(np.int8) for x in results])
+
+  # Convert int codes to binary.
+  int_codes = (int_codes + 1)/2
+  export = np.packbits(int_codes.reshape(-1))
+
+  with tf.gfile.FastGFile(FLAGS.output_codes, 'wb') as code_file:
+    np.savez_compressed(code_file, shape=int_codes.shape, codes=export)
+
+
+if __name__ == '__main__':
+  tf.app.run()

BIN
compression/example.png


+ 217 - 0
compression/msssim.py

@@ -0,0 +1,217 @@
+#!/usr/bin/python
+#
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Python implementation of MS-SSIM.
+
+Usage:
+
+python msssim.py --original_image=original.png --compared_image=distorted.png
+"""
+import numpy as np
+from scipy import signal
+from scipy.ndimage.filters import convolve
+import tensorflow as tf
+
+
+tf.flags.DEFINE_string('original_image', None, 'Path to PNG image.')
+tf.flags.DEFINE_string('compared_image', None, 'Path to PNG image.')
+FLAGS = tf.flags.FLAGS
+
+
+def _FSpecialGauss(size, sigma):
+  """Function to mimic the 'fspecial' gaussian MATLAB function."""
+  radius = size // 2
+  offset = 0.0
+  start, stop = -radius, radius + 1
+  if size % 2 == 0:
+    offset = 0.5
+    stop -= 1
+  x, y = np.mgrid[offset + start:stop, offset + start:stop]
+  assert len(x) == size
+  g = np.exp(-((x**2 + y**2)/(2.0 * sigma**2)))
+  return g / g.sum()
+
+
+def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11,
+                       filter_sigma=1.5, k1=0.01, k2=0.03):
+  """Return the Structural Similarity Map between `img1` and `img2`.
+
+  This function attempts to match the functionality of ssim_index_new.m by
+  Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
+
+  Arguments:
+    img1: Numpy array holding the first RGB image batch.
+    img2: Numpy array holding the second RGB image batch.
+    max_val: the dynamic range of the images (i.e., the difference between the
+      maximum the and minimum allowed values).
+    filter_size: Size of blur kernel to use (will be reduced for small images).
+    filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
+      for small images).
+    k1: Constant used to maintain stability in the SSIM calculation (0.01 in
+      the original paper).
+    k2: Constant used to maintain stability in the SSIM calculation (0.03 in
+      the original paper).
+
+  Returns:
+    Pair containing the mean SSIM and contrast sensitivity between `img1` and
+    `img2`.
+
+  Raises:
+    RuntimeError: If input images don't have the same shape or don't have four
+      dimensions: [batch_size, height, width, depth].
+  """
+  if img1.shape != img2.shape:
+    raise RuntimeError('Input images must have the same shape (%s vs. %s).',
+                       img1.shape, img2.shape)
+  if img1.ndim != 4:
+    raise RuntimeError('Input images must have four dimensions, not %d',
+                       img1.ndim)
+
+  img1 = img1.astype(np.float64)
+  img2 = img2.astype(np.float64)
+  _, height, width, _ = img1.shape
+
+  # Filter size can't be larger than height or width of images.
+  size = min(filter_size, height, width)
+
+  # Scale down sigma if a smaller filter size is used.
+  sigma = size * filter_sigma / filter_size if filter_size else 0
+
+  if filter_size:
+    window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1))
+    mu1 = signal.fftconvolve(img1, window, mode='valid')
+    mu2 = signal.fftconvolve(img2, window, mode='valid')
+    sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
+    sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
+    sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')
+  else:
+    # Empty blur kernel so no need to convolve.
+    mu1, mu2 = img1, img2
+    sigma11 = img1 * img1
+    sigma22 = img2 * img2
+    sigma12 = img1 * img2
+
+  mu11 = mu1 * mu1
+  mu22 = mu2 * mu2
+  mu12 = mu1 * mu2
+  sigma11 -= mu11
+  sigma22 -= mu22
+  sigma12 -= mu12
+
+  # Calculate intermediate values used by both ssim and cs_map.
+  c1 = (k1 * max_val) ** 2
+  c2 = (k2 * max_val) ** 2
+  v1 = 2.0 * sigma12 + c2
+  v2 = sigma11 + sigma22 + c2
+  ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)))
+  cs = np.mean(v1 / v2)
+  return ssim, cs
+
+
+def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5,
+                   k1=0.01, k2=0.03, weights=None):
+  """Return the MS-SSIM score between `img1` and `img2`.
+
+  This function implements Multi-Scale Structural Similarity (MS-SSIM) Image
+  Quality Assessment according to Zhou Wang's paper, "Multi-scale structural
+  similarity for image quality assessment" (2003).
+  Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
+
+  Author's MATLAB implementation:
+  http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
+
+  Arguments:
+    img1: Numpy array holding the first RGB image batch.
+    img2: Numpy array holding the second RGB image batch.
+    max_val: the dynamic range of the images (i.e., the difference between the
+      maximum the and minimum allowed values).
+    filter_size: Size of blur kernel to use (will be reduced for small images).
+    filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
+      for small images).
+    k1: Constant used to maintain stability in the SSIM calculation (0.01 in
+      the original paper).
+    k2: Constant used to maintain stability in the SSIM calculation (0.03 in
+      the original paper).
+    weights: List of weights for each level; if none, use five levels and the
+      weights from the original paper.
+
+  Returns:
+    MS-SSIM score between `img1` and `img2`.
+
+  Raises:
+    RuntimeError: If input images don't have the same shape or don't have four
+      dimensions: [batch_size, height, width, depth].
+  """
+  if img1.shape != img2.shape:
+    raise RuntimeError('Input images must have the same shape (%s vs. %s).',
+                       img1.shape, img2.shape)
+  if img1.ndim != 4:
+    raise RuntimeError('Input images must have four dimensions, not %d',
+                       img1.ndim)
+
+  # Note: default weights don't sum to 1.0 but do match the paper / matlab code.
+  weights = np.array(weights if weights else
+                     [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
+  levels = weights.size
+  downsample_filter = np.ones((1, 2, 2, 1)) / 4.0
+  im1, im2 = [x.astype(np.float64) for x in [img1, img2]]
+  mssim = np.array([])
+  mcs = np.array([])
+  for _ in xrange(levels):
+    ssim, cs = _SSIMForMultiScale(
+        im1, im2, max_val=max_val, filter_size=filter_size,
+        filter_sigma=filter_sigma, k1=k1, k2=k2)
+    mssim = np.append(mssim, ssim)
+    mcs = np.append(mcs, cs)
+    filtered = [convolve(im, downsample_filter, mode='reflect')
+                for im in [im1, im2]]
+    im1, im2 = [x[:, ::2, ::2, :] for x in filtered]
+  return (np.prod(mcs[0:levels-1] ** weights[0:levels-1]) *
+          (mssim[levels-1] ** weights[levels-1]))
+
+
+def main(_):
+  if FLAGS.original_image is None or FLAGS.compared_image is None:
+    print ('\nUsage: python msssim.py --original_image=original.png '
+           '--compared_image=distorted.png\n\n')
+    return
+
+  if not tf.gfile.Exists(FLAGS.original_image):
+    print '\nCannot find --original_image.\n'
+    return
+
+  if not tf.gfile.Exists(FLAGS.compared_image):
+    print '\nCannot find --compared_image.\n'
+    return
+
+  with tf.gfile.FastGFile(FLAGS.original_image) as image_file:
+    img1_str = image_file.read()
+  with tf.gfile.FastGFile(FLAGS.compared_image) as image_file:
+    img2_str = image_file.read()
+
+  input_img = tf.placeholder(tf.string)
+  decoded_image = tf.expand_dims(tf.image.decode_png(input_img, channels=3), 0)
+
+  with tf.Session() as sess:
+    img1 = sess.run(decoded_image, feed_dict={input_img: img1_str})
+    img2 = sess.run(decoded_image, feed_dict={input_img: img2_str})
+
+  print MultiScaleSSIM(img1, img2, max_val=255)
+
+
+if __name__ == '__main__':
+  tf.app.run()