scopes.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Contains the new arg_scope used for TF-Slim ops.

Allows one to define models much more compactly by eliminating boilerplate
code. This is accomplished through the use of argument scoping (arg_scope).

Example of how to use scopes.arg_scope:

  with scopes.arg_scope([ops.conv2d], padding='SAME',
                        stddev=0.01, weight_decay=0.0005):
    net = ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
    net = ops.conv2d(net, 256, [5, 5], scope='conv2')

The first call to conv2d passes its own padding, which overrides the scoped
padding='SAME', so it is equivalent to:

  ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
             stddev=0.01, weight_decay=0.0005, scope='conv1')

The second call to conv2d does not pass padding, so it picks up
padding='SAME' from the scope:

  ops.conv2d(net, 256, [5, 5], padding='SAME',
             stddev=0.01, weight_decay=0.0005, scope='conv2')

Example of how to reuse an arg_scope:

  with scopes.arg_scope([ops.conv2d], padding='SAME',
                        stddev=0.01, weight_decay=0.0005) as conv2d_arg_scope:
    net = ops.conv2d(net, 256, [5, 5], scope='conv1')
  ....

  with scopes.arg_scope(conv2d_arg_scope):
    net = ops.conv2d(net, 256, [5, 5], scope='conv2')

Example of how to use scopes.add_arg_scope:

  @scopes.add_arg_scope
  def conv2d(*args, **kwargs):
    ...
"""
  40. from __future__ import absolute_import
  41. from __future__ import division
  42. from __future__ import print_function
  43. import contextlib
  44. import functools
  45. from tensorflow.python.framework import ops
# Collection key under which the arg-scope stack (a list of scope dicts) is
# stored and looked up via `ops`; see _get_arg_stack.
_ARGSTACK_KEY = ("__arg_stack",)
# Registry of (module, name) keys for every function decorated with
# @add_arg_scope; consulted by has_arg_scope.
_DECORATED_OPS = set()
  48. def _get_arg_stack():
  49. stack = ops.get_collection(_ARGSTACK_KEY)
  50. if stack:
  51. return stack[0]
  52. else:
  53. stack = [{}]
  54. ops.add_to_collection(_ARGSTACK_KEY, stack)
  55. return stack
  56. def _current_arg_scope():
  57. stack = _get_arg_stack()
  58. return stack[-1]
  59. def _add_op(op):
  60. key_op = (op.__module__, op.__name__)
  61. if key_op not in _DECORATED_OPS:
  62. _DECORATED_OPS.add(key_op)
  63. @contextlib.contextmanager
  64. def arg_scope(list_ops_or_scope, **kwargs):
  65. """Stores the default arguments for the given set of list_ops.
  66. For usage, please see examples at top of the file.
  67. Args:
  68. list_ops_or_scope: List or tuple of operations to set argument scope for or
  69. a dictionary containg the current scope. When list_ops_or_scope is a dict,
  70. kwargs must be empty. When list_ops_or_scope is a list or tuple, then
  71. every op in it need to be decorated with @add_arg_scope to work.
  72. **kwargs: keyword=value that will define the defaults for each op in
  73. list_ops. All the ops need to accept the given set of arguments.
  74. Yields:
  75. the current_scope, which is a dictionary of {op: {arg: value}}
  76. Raises:
  77. TypeError: if list_ops is not a list or a tuple.
  78. ValueError: if any op in list_ops has not be decorated with @add_arg_scope.
  79. """
  80. if isinstance(list_ops_or_scope, dict):
  81. # Assumes that list_ops_or_scope is a scope that is being reused.
  82. if kwargs:
  83. raise ValueError("When attempting to re-use a scope by suppling a"
  84. "dictionary, kwargs must be empty.")
  85. current_scope = list_ops_or_scope.copy()
  86. try:
  87. _get_arg_stack().append(current_scope)
  88. yield current_scope
  89. finally:
  90. _get_arg_stack().pop()
  91. else:
  92. # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
  93. if not isinstance(list_ops_or_scope, (list, tuple)):
  94. raise TypeError("list_ops_or_scope must either be a list/tuple or reused"
  95. "scope (i.e. dict)")
  96. try:
  97. current_scope = _current_arg_scope().copy()
  98. for op in list_ops_or_scope:
  99. key_op = (op.__module__, op.__name__)
  100. if not has_arg_scope(op):
  101. raise ValueError("%s is not decorated with @add_arg_scope", key_op)
  102. if key_op in current_scope:
  103. current_kwargs = current_scope[key_op].copy()
  104. current_kwargs.update(kwargs)
  105. current_scope[key_op] = current_kwargs
  106. else:
  107. current_scope[key_op] = kwargs.copy()
  108. _get_arg_stack().append(current_scope)
  109. yield current_scope
  110. finally:
  111. _get_arg_stack().pop()
  112. def add_arg_scope(func):
  113. """Decorates a function with args so it can be used within an arg_scope.
  114. Args:
  115. func: function to decorate.
  116. Returns:
  117. A tuple with the decorated function func_with_args().
  118. """
  119. @functools.wraps(func)
  120. def func_with_args(*args, **kwargs):
  121. current_scope = _current_arg_scope()
  122. current_args = kwargs
  123. key_func = (func.__module__, func.__name__)
  124. if key_func in current_scope:
  125. current_args = current_scope[key_func].copy()
  126. current_args.update(kwargs)
  127. return func(*args, **current_args)
  128. _add_op(func)
  129. return func_with_args
  130. def has_arg_scope(func):
  131. """Checks whether a func has been decorated with @add_arg_scope or not.
  132. Args:
  133. func: function to check.
  134. Returns:
  135. a boolean.
  136. """
  137. key_op = (func.__module__, func.__name__)
  138. return key_op in _DECORATED_OPS