variables.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains convenience wrappers for creating Variables in TensorFlow.
  16. Usage:
  17. weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
  18. l2_regularizer = lambda t: losses.l2_loss(t, weight=0.0005)
  19. weights = variables.variable('weights',
  20. shape=[100, 100],
  21. initializer=weights_initializer,
  22. regularizer=l2_regularizer,
  23. device='/cpu:0')
  24. biases = variables.variable('biases',
  25. shape=[100],
  26. initializer=tf.zeros_initializer,
  27. device='/cpu:0')
  28. # More complex example.
  29. net = slim.ops.conv2d(input, 32, [3, 3], scope='conv1')
  30. net = slim.ops.conv2d(net, 64, [3, 3], scope='conv2')
  31. with slim.arg_scope([variables.variable], restore=False):
  32. net = slim.ops.conv2d(net, 64, [3, 3], scope='conv3')
  33. # Get all model variables from all the layers.
  34. model_variables = slim.variables.get_variables()
  35. # Get all model variables from a specific layer, i.e. 'conv1'.
  36. conv1_variables = slim.variables.get_variables('conv1')
  37. # Get all weights from all the layers.
  38. weights = slim.variables.get_variables_by_name('weights')
  39. # Get all biases from all the layers.
  40. biases = slim.variables.get_variables_by_name('biases')
  41. # Get all variables in the VARIABLES_TO_RESTORE collection
  42. # (i.e. only those created by 'conv1' and 'conv2')
  43. variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
  44. ************************************************
  45. * Initializing model variables from a checkpoint
  46. ************************************************
  47. # Create some variables.
  48. v1 = slim.variables.variable(name="v1", ..., restore=False)
  49. v2 = slim.variables.variable(name="v2", ...) # By default restore=True
  50. ...
  51. # The list of variables to restore should only contain 'v2'.
  52. variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
  53. restorer = tf.train.Saver(variables_to_restore)
  54. with tf.Session() as sess:
  55. # Restore variables from disk.
  56. restorer.restore(sess, "/tmp/model.ckpt")
  57. print("Model restored.")
  58. # Do some work with the model
  59. ...
  60. """
  61. from __future__ import absolute_import
  62. from __future__ import division
  63. from __future__ import print_function
  64. import tensorflow as tf
  65. from inception.slim import scopes
  66. # Collection containing all the variables created using slim.variables
  67. VARIABLES_COLLECTION = '_variables_'
  68. # Collection containing all the slim.variables that are marked to_restore
  69. VARIABLES_TO_RESTORE = '_variables_to_restore_'
  70. def get_variable_given_name(var):
  71. """Gets the variable given name without the scope.
  72. Args:
  73. var: a variable.
  74. Returns:
  75. the given name of the variable without the scope.
  76. """
  77. name = var.op.name
  78. if '/' in name:
  79. name = name.split('/')[-1]
  80. return name
  81. def default_collections(given_name, restore):
  82. """Define the set of default collections that variables should be added.
  83. Args:
  84. given_name: the given name of the variable.
  85. restore: whether the variable should be added to the VARIABLES_TO_RESTORE
  86. collection.
  87. Returns:
  88. a list of default collections.
  89. """
  90. defaults = [tf.GraphKeys.VARIABLES, VARIABLES_COLLECTION]
  91. defaults += [VARIABLES_COLLECTION + given_name]
  92. if restore:
  93. defaults += [VARIABLES_TO_RESTORE]
  94. return defaults
  95. def add_variable(var, restore=True):
  96. """Adds a variable to the default set of collections.
  97. Args:
  98. var: a variable.
  99. restore: whether the variable should be added to the
  100. VARIABLES_TO_RESTORE collection.
  101. """
  102. given_name = get_variable_given_name(var)
  103. for collection in default_collections(given_name, restore):
  104. if var not in tf.get_collection(collection):
  105. tf.add_to_collection(collection, var)
  106. def get_variables(prefix=None, suffix=None):
  107. """Gets the list of variables, filtered by prefix and/or suffix.
  108. Args:
  109. prefix: an optional prefix for filtering the variables to return.
  110. suffix: an optional suffix for filtering the variables to return.
  111. Returns:
  112. a list of variables with prefix and suffix.
  113. """
  114. candidates = tf.get_collection(VARIABLES_COLLECTION, prefix)
  115. if suffix is not None:
  116. candidates = [var for var in candidates if var.op.name.endswith(suffix)]
  117. return candidates
  118. def get_variables_by_name(given_name, prefix=None):
  119. """Gets the list of variables were given that name.
  120. Args:
  121. given_name: name given to the variable without scope.
  122. prefix: an optional prefix for filtering the variables to return.
  123. Returns:
  124. a list of variables with prefix and suffix.
  125. """
  126. return tf.get_collection(VARIABLES_COLLECTION + given_name, prefix)
  127. def get_unique_variable(name):
  128. """Gets the variable uniquely identified by that name.
  129. Args:
  130. name: a name that uniquely identifies the variable.
  131. Returns:
  132. a tensorflow variable.
  133. Raises:
  134. ValueError: if no variable uniquely identified by the name exists.
  135. """
  136. candidates = tf.get_collection(tf.GraphKeys.VARIABLES, name)
  137. if not candidates:
  138. raise ValueError('Couldnt find variable %s' % name)
  139. for candidate in candidates:
  140. if candidate.op.name == name:
  141. return candidate
  142. raise ValueError('Variable %s does not uniquely identify a variable', name)
  143. @scopes.add_arg_scope
  144. def variable(name, shape=None, dtype=tf.float32, initializer=None,
  145. regularizer=None, trainable=True, collections=None, device='',
  146. restore=True):
  147. """Gets an existing variable with these parameters or creates a new one.
  148. It also add itself to a group with its name.
  149. Args:
  150. name: the name of the new or existing variable.
  151. shape: shape of the new or existing variable.
  152. dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
  153. initializer: initializer for the variable if one is created.
  154. regularizer: a (Tensor -> Tensor or None) function; the result of
  155. applying it on a newly created variable will be added to the collection
  156. GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
  157. trainable: If `True` also add the variable to the graph collection
  158. `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
  159. collections: A list of collection names to which the Variable will be added.
  160. Note that the variable is always also added to the tf.GraphKeys.VARIABLES
  161. collection.
  162. device: Optional device to place the variable. It can be an string or a
  163. function that is called to get the device for the variable.
  164. restore: whether the variable should be added to the
  165. VARIABLES_TO_RESTORE collection.
  166. Returns:
  167. The created or existing variable.
  168. """
  169. # Instantiate the device for this variable if it is passed as a function.
  170. if device and callable(device):
  171. device = device()
  172. collections = set(list(collections or []) + default_collections(name,
  173. restore))
  174. with tf.device(device):
  175. return tf.get_variable(name, shape=shape, dtype=dtype,
  176. initializer=initializer, regularizer=regularizer,
  177. trainable=trainable, collections=collections)