
Updated to the latest version of TF-Slim

Derek Murray, 9 years ago
parent commit 8bc9fe9401

+ 1 - 0
inception/inception/slim/collections_test.py

@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
 from inception.slim import slim

+ 1 - 0
inception/inception/slim/inception_model.py

@@ -43,6 +43,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
 from inception.slim import ops

+ 1 - 0
inception/inception/slim/inception_test.py

@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
 from inception.slim import inception_model as inception

+ 1 - 0
inception/inception/slim/losses.py

@@ -26,6 +26,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
 # In order to gather all losses in a network, the user should use this

+ 1 - 0
inception/inception/slim/losses_test.py

@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 
+
 import tensorflow as tf
 
 from inception.slim import losses

+ 29 - 22
inception/inception/slim/ops.py

@@ -27,6 +27,7 @@ from __future__ import division
 from __future__ import print_function
 
 
+
 import tensorflow as tf
 
 from tensorflow.python.training import moving_averages
@@ -42,6 +43,7 @@ UPDATE_OPS_COLLECTION = '_update_ops_'
 @scopes.add_arg_scope
 def batch_norm(inputs,
                decay=0.999,
+               center=True,
                scale=False,
                epsilon=0.001,
                moving_vars='moving_vars',
@@ -57,6 +59,7 @@ def batch_norm(inputs,
     inputs: a tensor of size [batch_size, height, width, channels]
             or [batch_size, channels].
     decay: decay for the moving average.
+    center: If True, subtract beta. If False, beta is neither created nor used.
     scale: If True, multiply by gamma. If False, gamma is
       not used. When the next layer is linear (also e.g. ReLU), this can be
       disabled since the scaling can be done by the next layer.
@@ -78,31 +81,35 @@ def batch_norm(inputs,
   with tf.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse):
     axis = list(range(len(inputs_shape) - 1))
     params_shape = inputs_shape[-1:]
-    with scopes.arg_scope([variables.variable], restore=restore):
-      # Allocate parameters for the beta and gamma of the normalization.
+    # Allocate parameters for the beta and gamma of the normalization.
+    beta, gamma = None, None
+    if center:
       beta = variables.variable('beta',
                                 params_shape,
                                 initializer=tf.zeros_initializer,
-                                trainable=trainable)
-      if scale:
-        gamma = variables.variable('gamma',
-                                   params_shape,
-                                   initializer=tf.ones,
-                                   trainable=trainable)
-      else:
-        gamma = None
-      # Create moving_mean and moving_variance add them to moving_vars and
-      # GraphKeys.MOVING_AVERAGE_VARIABLES collections.
-      with scopes.arg_scope([variables.variable], trainable=False,
-                            collections=[
-                                moving_vars,
-                                tf.GraphKeys.MOVING_AVERAGE_VARIABLES]):
-        moving_mean = variables.variable('moving_mean',
+                                trainable=trainable,
+                                restore=restore)
+    if scale:
+      gamma = variables.variable('gamma',
+                                 params_shape,
+                                 initializer=tf.ones_initializer,
+                                 trainable=trainable,
+                                 restore=restore)
+    # Create moving_mean and moving_variance and add them to the moving_vars and
+    # GraphKeys.MOVING_AVERAGE_VARIABLES collections.
+    moving_collections = [moving_vars, tf.GraphKeys.MOVING_AVERAGE_VARIABLES]
+    moving_mean = variables.variable('moving_mean',
+                                     params_shape,
+                                     initializer=tf.zeros_initializer,
+                                     trainable=False,
+                                     restore=restore,
+                                     collections=moving_collections)
+    moving_variance = variables.variable('moving_variance',
                                          params_shape,
-                                         initializer=tf.zeros_initializer)
-        moving_variance = variables.variable('moving_variance',
-                                             params_shape,
-                                             initializer=tf.ones)
+                                         initializer=tf.ones_initializer,
+                                         trainable=False,
+                                         restore=restore,
+                                         collections=moving_collections)
     if is_training:
       # Calculate the moments based on the individual batch.
       mean, variance = tf.nn.moments(inputs, axis)
@@ -400,7 +407,7 @@ def dropout(inputs, keep_prob=0.5, is_training=True, scope=None):
 
   Args:
     inputs: the tensor to pass to the Dropout layer.
-    keep_prob: the probability of dropping each input unit.
+    keep_prob: the probability of keeping each input unit.
     is_training: whether or not the model is in training mode. If so, dropout is
     applied and values scaled. Otherwise, inputs is returned.
     scope: Optional scope for op_scope.
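
As a quick orientation for the batch_norm and dropout changes above, here is a hypothetical usage sketch. It assumes the inception.slim ops module from this repo and a TensorFlow release contemporary with this commit; the tensor shapes are illustrative.

import tensorflow as tf

from inception.slim import ops

images = tf.random_uniform((5, 3, 3, 3), seed=1)
# center=False skips creating the beta (offset) variable; scale=True creates gamma.
net = ops.batch_norm(images, center=False, scale=True)
# keep_prob is the probability of *keeping* each unit, so 0.8 drops roughly 20%.
net = ops.dropout(net, keep_prob=0.8, is_training=True)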

+ 43 - 0
inception/inception/slim/ops_test.py

@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 
+
 import numpy as np
 import tensorflow as tf
 
@@ -479,6 +480,20 @@ class BatchNormTest(tf.test.TestCase):
     height, width = 3, 3
     with self.test_session():
       images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.batch_norm(images)
+      beta = variables.get_variables_by_name('beta')[0]
+      self.assertEquals(beta.op.name, 'BatchNorm/beta')
+      gamma = variables.get_variables_by_name('gamma')
+      self.assertEquals(gamma, [])
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
+      self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+      self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+  def testCreateVariablesWithScale(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
       ops.batch_norm(images, scale=True)
       beta = variables.get_variables_by_name('beta')[0]
       gamma = variables.get_variables_by_name('gamma')[0]
@@ -489,6 +504,34 @@ class BatchNormTest(tf.test.TestCase):
       self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
       self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
 
+  def testCreateVariablesWithoutCenterWithScale(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.batch_norm(images, center=False, scale=True)
+      beta = variables.get_variables_by_name('beta')
+      self.assertEquals(beta, [])
+      gamma = variables.get_variables_by_name('gamma')[0]
+      self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
+      self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+      self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
+  def testCreateVariablesWithoutCenterWithoutScale(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.batch_norm(images, center=False, scale=False)
+      beta = variables.get_variables_by_name('beta')
+      self.assertEquals(beta, [])
+      gamma = variables.get_variables_by_name('gamma')
+      self.assertEquals(gamma, [])
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
+      self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
+      self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
+
   def testMovingAverageVariables(self):
     height, width = 3, 3
     with self.test_session():

+ 1 - 0
inception/inception/slim/scopes.py

@@ -53,6 +53,7 @@ from __future__ import print_function
 import contextlib
 import functools
 
+
 from tensorflow.python.framework import ops
 
 _ARGSTACK_KEY = ("__arg_stack",)

+ 1 - 0
inception/inception/slim/scopes_test.py

@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 
+
 import tensorflow as tf
 from inception.slim import scopes
 

+ 77 - 4
inception/inception/slim/variables.py

@@ -82,8 +82,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
+from tensorflow.core.framework import graph_pb2
 from inception.slim import scopes
 
 # Collection containing all the variables created using slim.variables
@@ -171,6 +173,79 @@ def get_unique_variable(name):
   raise ValueError('Variable %s does not uniquely identify a variable', name)
 
 
+class VariableDeviceChooser(object):
+  """Slim device chooser for variables.
+
+  When using parameter servers, it assigns variables to them in a round-robin
+  fashion. When not using parameter servers, it allows GPU:0 or CPU:0 placement.
+  """
+
+  def __init__(self,
+               num_parameter_servers=0,
+               ps_device='/job:ps',
+               placement='CPU:0'):
+    """Initialize VariableDeviceChooser.
+
+    Args:
+      num_parameter_servers: number of parameter servers.
+      ps_device: string representing the parameter server device.
+      placement: string representing the placement of the variable, either CPU:0
+        or GPU:0. When using parameter servers, placement is forced to CPU:0.
+    """
+    self._num_ps = num_parameter_servers
+    self._ps_device = ps_device
+    self._placement = placement if num_parameter_servers == 0 else 'CPU:0'
+    self._next_task_id = 0
+
+  def __call__(self, op):
+    device_string = ''
+    if self._num_ps > 0:
+      task_id = self._next_task_id
+      self._next_task_id = (self._next_task_id + 1) % self._num_ps
+      device_string = '%s/task:%d' % (self._ps_device, task_id)
+    device_string += '/%s' % self._placement
+    return device_string
+
+
+# TODO(sguada) Remove once get_variable is able to colocate op.devices.
+def variable_device(device, name):
+  """Fix the variable device to colocate its ops."""
+  if callable(device):
+    var_name = tf.get_variable_scope().name + '/' + name
+    var_def = graph_pb2.NodeDef(name=var_name, op='Variable')
+    device = device(var_def)
+  if device is None:
+    device = ''
+  return device
+
+
+@scopes.add_arg_scope
+def global_step(device=''):
+  """Returns the global step variable.
+
+  Args:
+    device: Optional device to place the variable. It can be a string or a
+      function that is called to get the device for the variable.
+
+  Returns:
+    the tensor representing the global step variable.
+  """
+  global_step_ref = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
+  if global_step_ref:
+    return global_step_ref[0]
+  else:
+    collections = [
+        VARIABLES_TO_RESTORE,
+        tf.GraphKeys.VARIABLES,
+        tf.GraphKeys.GLOBAL_STEP,
+    ]
+    # Get the device for the variable.
+    with tf.device(variable_device(device, 'global_step')):
+      return tf.get_variable('global_step', shape=[], dtype=tf.int64,
+                             initializer=tf.zeros_initializer,
+                             trainable=False, collections=collections)
+
+
 @scopes.add_arg_scope
 def variable(name, shape=None, dtype=tf.float32, initializer=None,
              regularizer=None, trainable=True, collections=None, device='',
@@ -200,9 +275,6 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
   Returns:
     The created or existing variable.
   """
-  # Instantiate the device for this variable if it is passed as a function.
-  if device and callable(device):
-    device = device()
   collections = list(collections or [])
 
   # Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
@@ -212,7 +284,8 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
     collections.append(VARIABLES_TO_RESTORE)
   # Remove duplicates
   collections = set(collections)
-  with tf.device(device):
+  # Get the device for the variable.
+  with tf.device(variable_device(device, name)):
     return tf.get_variable(name, shape=shape, dtype=dtype,
                            initializer=initializer, regularizer=regularizer,
                            trainable=trainable, collections=collections)
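
Before the test changes below, a hypothetical sketch of how the new VariableDeviceChooser and global_step helpers might be combined in user code. It assumes the inception.slim package from this repo; the 'weights' variable name is illustrative, and the commented placements follow the round-robin order of variable creation shown in the diff.

import tensorflow as tf

from inception.slim import scopes
from inception.slim import variables

with tf.Graph().as_default():
  # Round-robin variable placement across two parameter-server tasks.
  device_fn = variables.VariableDeviceChooser(num_parameter_servers=2)
  with scopes.arg_scope([variables.variable, variables.global_step],
                        device=device_fn):
    step = variables.global_step()                 # '/job:ps/task:0/CPU:0'
    weights = variables.variable('weights', [10])  # '/job:ps/task:1/CPU:0'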

+ 164 - 1
inception/inception/slim/variables_test.py

@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 
 from inception.slim import scopes
@@ -134,6 +135,109 @@ class VariablesTest(tf.test.TestCase):
       self.assertDeviceEqual(a.device, 'cpu:0')
       self.assertDeviceEqual(b.device, 'cpu:1')
 
+  def testVariableWithDeviceFunction(self):
+    class DevFn(object):
+
+      def __init__(self):
+        self.counter = -1
+
+      def __call__(self, op):
+        self.counter += 1
+        return 'cpu:%d' % self.counter
+
+    with self.test_session():
+      with scopes.arg_scope([variables.variable], device=DevFn()):
+        a = variables.variable('a', [])
+        b = variables.variable('b', [])
+        c = variables.variable('c', [], device='cpu:12')
+        d = variables.variable('d', [])
+        with tf.device('cpu:99'):
+          e_init = tf.constant(12)
+        e = variables.variable('e', initializer=e_init)
+      self.assertDeviceEqual(a.device, 'cpu:0')
+      self.assertDeviceEqual(a.initial_value.device, 'cpu:0')
+      self.assertDeviceEqual(b.device, 'cpu:1')
+      self.assertDeviceEqual(b.initial_value.device, 'cpu:1')
+      self.assertDeviceEqual(c.device, 'cpu:12')
+      self.assertDeviceEqual(c.initial_value.device, 'cpu:12')
+      self.assertDeviceEqual(d.device, 'cpu:2')
+      self.assertDeviceEqual(d.initial_value.device, 'cpu:2')
+      self.assertDeviceEqual(e.device, 'cpu:3')
+      self.assertDeviceEqual(e.initial_value.device, 'cpu:99')
+
+  def testVariableWithReplicaDeviceSetter(self):
+    with self.test_session():
+      with tf.device(tf.train.replica_device_setter(ps_tasks=2)):
+        a = variables.variable('a', [])
+        b = variables.variable('b', [])
+        c = variables.variable('c', [], device='cpu:12')
+        d = variables.variable('d', [])
+        with tf.device('cpu:99'):
+          e_init = tf.constant(12)
+        e = variables.variable('e', initializer=e_init)
+      # The values below highlight how the replica_device_setter puts initial
+      # values on the worker job, and how it merges explicit devices.
+      self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
+      self.assertDeviceEqual(a.initial_value.device, '/job:worker/cpu:0')
+      self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
+      self.assertDeviceEqual(b.initial_value.device, '/job:worker/cpu:0')
+      self.assertDeviceEqual(c.device, '/job:ps/task:0/cpu:12')
+      self.assertDeviceEqual(c.initial_value.device, '/job:worker/cpu:12')
+      self.assertDeviceEqual(d.device, '/job:ps/task:1/cpu:0')
+      self.assertDeviceEqual(d.initial_value.device, '/job:worker/cpu:0')
+      self.assertDeviceEqual(e.device, '/job:ps/task:0/cpu:0')
+      self.assertDeviceEqual(e.initial_value.device, '/job:worker/cpu:99')
+
+  def testVariableWithVariableDeviceChooser(self):
+
+    with tf.Graph().as_default():
+      device_fn = variables.VariableDeviceChooser(num_parameter_servers=2)
+      with scopes.arg_scope([variables.variable], device=device_fn):
+        a = variables.variable('a', [])
+        b = variables.variable('b', [])
+        c = variables.variable('c', [], device='cpu:12')
+        d = variables.variable('d', [])
+        with tf.device('cpu:99'):
+          e_init = tf.constant(12)
+        e = variables.variable('e', initializer=e_init)
+      # The values below highlight how the VariableDeviceChooser puts initial
+      # values on the same device as the variable itself.
+      self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
+      self.assertDeviceEqual(a.initial_value.device, a.device)
+      self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
+      self.assertDeviceEqual(b.initial_value.device, b.device)
+      self.assertDeviceEqual(c.device, '/cpu:12')
+      self.assertDeviceEqual(c.initial_value.device, c.device)
+      self.assertDeviceEqual(d.device, '/job:ps/task:0/cpu:0')
+      self.assertDeviceEqual(d.initial_value.device, d.device)
+      self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
+      self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+
+  def testVariableGPUPlacement(self):
+
+    with tf.Graph().as_default():
+      device_fn = variables.VariableDeviceChooser(placement='gpu:0')
+      with scopes.arg_scope([variables.variable], device=device_fn):
+        a = variables.variable('a', [])
+        b = variables.variable('b', [])
+        c = variables.variable('c', [], device='cpu:12')
+        d = variables.variable('d', [])
+        with tf.device('cpu:99'):
+          e_init = tf.constant(12)
+        e = variables.variable('e', initializer=e_init)
+      # The values below highlight how the VariableDeviceChooser puts initial
+      # values on the same device as the variable itself.
+      self.assertDeviceEqual(a.device, '/gpu:0')
+      self.assertDeviceEqual(a.initial_value.device, a.device)
+      self.assertDeviceEqual(b.device, '/gpu:0')
+      self.assertDeviceEqual(b.initial_value.device, b.device)
+      self.assertDeviceEqual(c.device, '/cpu:12')
+      self.assertDeviceEqual(c.initial_value.device, c.device)
+      self.assertDeviceEqual(d.device, '/gpu:0')
+      self.assertDeviceEqual(d.initial_value.device, d.device)
+      self.assertDeviceEqual(e.device, '/gpu:0')
+      self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+
   def testVariableCollection(self):
     with self.test_session():
       a = variables.variable('a', [], collections='A')
@@ -178,7 +282,8 @@ class VariablesTest(tf.test.TestCase):
     with self.test_session():
       with scopes.arg_scope([variables.variable], restore=True):
         a = variables.variable('a', [])
-        with scopes.arg_scope([variables.variable], trainable=False,
+        with scopes.arg_scope([variables.variable],
+                              trainable=False,
                               collections=['A', 'B']):
           b = variables.variable('b', [])
         c = variables.variable('c', [])
@@ -226,5 +331,63 @@ class GetVariablesByNameTest(tf.test.TestCase):
       self.assertEquals([a], matched_variables)
 
 
+class GlobalStepTest(tf.test.TestCase):
+
+  def testStable(self):
+    with tf.Graph().as_default():
+      gs = variables.global_step()
+      gs2 = variables.global_step()
+      self.assertTrue(gs is gs2)
+
+  def testDevice(self):
+    with tf.Graph().as_default():
+      with scopes.arg_scope([variables.global_step], device='/gpu:0'):
+        gs = variables.global_step()
+      self.assertDeviceEqual(gs.device, '/gpu:0')
+
+  def testDeviceFn(self):
+    class DevFn(object):
+
+      def __init__(self):
+        self.counter = -1
+
+      def __call__(self, op):
+        self.counter += 1
+        return '/cpu:%d' % self.counter
+
+    with tf.Graph().as_default():
+      with scopes.arg_scope([variables.global_step], device=DevFn()):
+        gs = variables.global_step()
+        gs2 = variables.global_step()
+      self.assertDeviceEqual(gs.device, '/cpu:0')
+      self.assertEquals(gs, gs2)
+      self.assertDeviceEqual(gs2.device, '/cpu:0')
+
+  def testReplicaDeviceSetter(self):
+    device_fn = tf.train.replica_device_setter(2)
+    with tf.Graph().as_default():
+      with scopes.arg_scope([variables.global_step], device=device_fn):
+        gs = variables.global_step()
+        gs2 = variables.global_step()
+        self.assertEquals(gs, gs2)
+        self.assertDeviceEqual(gs.device, '/job:ps/task:0')
+        self.assertDeviceEqual(gs.initial_value.device, '/job:ps/task:0')
+        self.assertDeviceEqual(gs2.device, '/job:ps/task:0')
+        self.assertDeviceEqual(gs2.initial_value.device, '/job:ps/task:0')
+
+  def testVariableWithVariableDeviceChooser(self):
+
+    with tf.Graph().as_default():
+      device_fn = variables.VariableDeviceChooser()
+      with scopes.arg_scope([variables.global_step], device=device_fn):
+        gs = variables.global_step()
+        gs2 = variables.global_step()
+        self.assertEquals(gs, gs2)
+        self.assertDeviceEqual(gs.device, 'cpu:0')
+        self.assertDeviceEqual(gs.initial_value.device, gs.device)
+        self.assertDeviceEqual(gs2.device, 'cpu:0')
+        self.assertDeviceEqual(gs2.initial_value.device, gs2.device)
+
+
 if __name__ == '__main__':
   tf.test.main()