variables.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains convenience wrappers for creating variables in TF-Slim.
  16. The variables module is typically used for defining model variables from the
  17. ops routines (see slim.ops). Such variables are used for training, evaluation
  18. and inference of models.
  19. All the variables created through this module are added to the
  20. MODEL_VARIABLES collection. If you create a model variable outside slim, it can
  21. be added with slim.variables.add_variable(external_variable, restore).
  22. Usage:
  23. weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
  24. l2_regularizer = lambda t: losses.l2_loss(t, weight=0.0005)
  25. weights = variables.variable('weights',
  26. shape=[100, 100],
  27. initializer=weights_initializer,
  28. regularizer=l2_regularizer,
  29. device='/cpu:0')
  30. biases = variables.variable('biases',
  31. shape=[100],
  32. initializer=tf.zeros_initializer,
  33. device='/cpu:0')
  34. # More complex example.
  35. net = slim.ops.conv2d(input, 32, [3, 3], scope='conv1')
  36. net = slim.ops.conv2d(net, 64, [3, 3], scope='conv2')
  37. with slim.arg_scope([variables.variable], restore=False):
  38. net = slim.ops.conv2d(net, 64, [3, 3], scope='conv3')
  39. # Get all model variables from all the layers.
  40. model_variables = slim.variables.get_variables()
  41. # Get all model variables from a specific layer, e.g. 'conv1'.
  42. conv1_variables = slim.variables.get_variables('conv1')
  43. # Get all weights from all the layers.
  44. weights = slim.variables.get_variables_by_name('weights')
  45. # Get all biases from all the layers.
  46. biases = slim.variables.get_variables_by_name('biases')
  47. # Get all variables to restore.
  48. # (i.e. only those created by 'conv1' and 'conv2')
  49. variables_to_restore = slim.variables.get_variables_to_restore()
  50. ************************************************
  51. * Initializing model variables from a checkpoint
  52. ************************************************
  53. # Create some variables.
  54. v1 = slim.variables.variable(name="v1", ..., restore=False)
  55. v2 = slim.variables.variable(name="v2", ...) # By default restore=True
  56. ...
  57. # The list of variables to restore should only contain 'v2'.
  58. variables_to_restore = slim.variables.get_variables_to_restore()
  59. restorer = tf.train.Saver(variables_to_restore)
  60. with tf.Session() as sess:
  61. # Restore variables from disk.
  62. restorer.restore(sess, "/tmp/model.ckpt")
  63. print("Model restored.")
  64. # Do some work with the model
  65. ...
  66. """
  67. from __future__ import absolute_import
  68. from __future__ import division
  69. from __future__ import print_function
  70. import tensorflow as tf
  71. from inception.slim import scopes
# Name of the graph collection that holds every variable created through
# variable() or registered with add_variable() in this module.
MODEL_VARIABLES = '_model_variables_'
# Name of the graph collection that holds the variables created or registered
# with restore=True; global_step() also places its variable in it.
VARIABLES_TO_RESTORE = '_variables_to_restore_'
  76. def add_variable(var, restore=True):
  77. """Adds a variable to the MODEL_VARIABLES collection.
  78. Optionally it will add the variable to the VARIABLES_TO_RESTORE collection.
  79. Args:
  80. var: a variable.
  81. restore: whether the variable should be added to the
  82. VARIABLES_TO_RESTORE collection.
  83. """
  84. collections = [MODEL_VARIABLES]
  85. if restore:
  86. collections.append(VARIABLES_TO_RESTORE)
  87. for collection in collections:
  88. if var not in tf.get_collection(collection):
  89. tf.add_to_collection(collection, var)
  90. def get_variables(scope=None, suffix=None):
  91. """Gets the list of variables, filtered by scope and/or suffix.
  92. Args:
  93. scope: an optional scope for filtering the variables to return.
  94. suffix: an optional suffix for filtering the variables to return.
  95. Returns:
  96. a copied list of variables with scope and suffix.
  97. """
  98. candidates = tf.get_collection(MODEL_VARIABLES, scope)[:]
  99. if suffix is not None:
  100. candidates = [var for var in candidates if var.op.name.endswith(suffix)]
  101. return candidates
  102. def get_variables_to_restore():
  103. """Gets the list of variables to restore.
  104. Returns:
  105. a copied list of variables.
  106. """
  107. return tf.get_collection(VARIABLES_TO_RESTORE)[:]
  108. def get_variables_by_name(given_name, scope=None):
  109. """Gets the list of variables that were given that name.
  110. Args:
  111. given_name: name given to the variable without scope.
  112. scope: an optional scope for filtering the variables to return.
  113. Returns:
  114. a copied list of variables with the given name and prefix.
  115. """
  116. return get_variables(scope=scope, suffix=given_name)
  117. def get_unique_variable(name):
  118. """Gets the variable uniquely identified by that name.
  119. Args:
  120. name: a name that uniquely identifies the variable.
  121. Returns:
  122. a tensorflow variable.
  123. Raises:
  124. ValueError: if no variable uniquely identified by the name exists.
  125. """
  126. candidates = tf.get_collection(tf.GraphKeys.VARIABLES, name)
  127. if not candidates:
  128. raise ValueError('Couldnt find variable %s' % name)
  129. for candidate in candidates:
  130. if candidate.op.name == name:
  131. return candidate
  132. raise ValueError('Variable %s does not uniquely identify a variable', name)
  133. class VariableDeviceChooser(object):
  134. """Slim device chooser for variables.
  135. When using a parameter server it will assign them in a round-robin fashion.
  136. When not using a parameter server it allows GPU:0 placement otherwise CPU:0.
  137. """
  138. def __init__(self,
  139. num_parameter_servers=0,
  140. ps_device='/job:ps',
  141. placement='CPU:0'):
  142. """Initialize VariableDeviceChooser.
  143. Args:
  144. num_parameter_servers: number of parameter servers.
  145. ps_device: string representing the parameter server device.
  146. placement: string representing the placement of the variable either CPU:0
  147. or GPU:0. When using parameter servers forced to CPU:0.
  148. """
  149. self._num_ps = num_parameter_servers
  150. self._ps_device = ps_device
  151. self._placement = placement if num_parameter_servers == 0 else 'CPU:0'
  152. self._next_task_id = 0
  153. def __call__(self, op):
  154. device_string = ''
  155. if self._num_ps > 0:
  156. task_id = self._next_task_id
  157. self._next_task_id = (self._next_task_id + 1) % self._num_ps
  158. device_string = '%s/task:%d' % (self._ps_device, task_id)
  159. device_string += '/%s' % self._placement
  160. return device_string
  161. # TODO(sguada) Remove once get_variable is able to colocate op.devices.
  162. def variable_device(device, name):
  163. """Fix the variable device to colocate its ops."""
  164. if callable(device):
  165. var_name = tf.get_variable_scope().name + '/' + name
  166. var_def = tf.NodeDef(name=var_name, op='Variable')
  167. device = device(var_def)
  168. if device is None:
  169. device = ''
  170. return device
  171. @scopes.add_arg_scope
  172. def global_step(device=''):
  173. """Returns the global step variable.
  174. Args:
  175. device: Optional device to place the variable. It can be an string or a
  176. function that is called to get the device for the variable.
  177. Returns:
  178. the tensor representing the global step variable.
  179. """
  180. global_step_ref = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
  181. if global_step_ref:
  182. return global_step_ref[0]
  183. else:
  184. collections = [
  185. VARIABLES_TO_RESTORE,
  186. tf.GraphKeys.VARIABLES,
  187. tf.GraphKeys.GLOBAL_STEP,
  188. ]
  189. # Get the device for the variable.
  190. with tf.device(variable_device(device, 'global_step')):
  191. return tf.get_variable('global_step', shape=[], dtype=tf.int64,
  192. initializer=tf.zeros_initializer,
  193. trainable=False, collections=collections)
  194. @scopes.add_arg_scope
  195. def variable(name, shape=None, dtype=tf.float32, initializer=None,
  196. regularizer=None, trainable=True, collections=None, device='',
  197. restore=True):
  198. """Gets an existing variable with these parameters or creates a new one.
  199. It also add itself to a group with its name.
  200. Args:
  201. name: the name of the new or existing variable.
  202. shape: shape of the new or existing variable.
  203. dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
  204. initializer: initializer for the variable if one is created.
  205. regularizer: a (Tensor -> Tensor or None) function; the result of
  206. applying it on a newly created variable will be added to the collection
  207. GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
  208. trainable: If `True` also add the variable to the graph collection
  209. `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
  210. collections: A list of collection names to which the Variable will be added.
  211. Note that the variable is always also added to the tf.GraphKeys.VARIABLES
  212. and MODEL_VARIABLES collections.
  213. device: Optional device to place the variable. It can be an string or a
  214. function that is called to get the device for the variable.
  215. restore: whether the variable should be added to the
  216. VARIABLES_TO_RESTORE collection.
  217. Returns:
  218. The created or existing variable.
  219. """
  220. collections = list(collections or [])
  221. # Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
  222. collections += [tf.GraphKeys.VARIABLES, MODEL_VARIABLES]
  223. # Add to VARIABLES_TO_RESTORE if necessary
  224. if restore:
  225. collections.append(VARIABLES_TO_RESTORE)
  226. # Remove duplicates
  227. collections = set(collections)
  228. # Get the device for the variable.
  229. with tf.device(variable_device(device, name)):
  230. return tf.get_variable(name, shape=shape, dtype=dtype,
  231. initializer=initializer, regularizer=regularizer,
  232. trainable=trainable, collections=collections)