123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- # Copyright 2016 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Contains the new arg_scope used for TF-Slim ops.
- Allows one to define models much more compactly by eliminating boilerplate
- code. This is accomplished through the use of argument scoping (arg_scope).
- Example of how to use scopes.arg_scope:
- with slim.arg_scope(ops.conv2d, padding='SAME',
- stddev=0.01, weight_decay=0.0005):
- net = ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
- net = ops.conv2d(net, 256, [5, 5], scope='conv2')
- The first call to conv2d will use predefined args:
- ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
- stddev=0.01, weight_decay=0.0005, scope='conv1')
- The second call to Conv will overwrite padding:
- ops.conv2d(inputs, 256, [5, 5], padding='SAME',
- stddev=0.01, weight_decay=0.0005, scope='conv2')
- Example of how to use scopes.add_arg_scope:
- @scopes.add_arg_scope
- def conv2d(*args, **kwargs)
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import contextlib
- import functools
- from tensorflow.python.framework import ops
- _ARGSTACK_KEY = ("__arg_stack",)
- _DECORATED_OPS = set()
- def _get_arg_stack():
- stack = ops.get_collection(_ARGSTACK_KEY)
- if stack:
- return stack[0]
- else:
- stack = [{}]
- ops.add_to_collection(_ARGSTACK_KEY, stack)
- return stack
- def _current_arg_scope():
- stack = _get_arg_stack()
- return stack[-1]
- def _add_op(op):
- key_op = (op.__module__, op.__name__)
- if key_op not in _DECORATED_OPS:
- _DECORATED_OPS.add(key_op)
- @contextlib.contextmanager
- def arg_scope(list_ops, **kwargs):
- """Stores the default arguments for the given set of list_ops.
- Args:
- list_ops: List or tuple of operations to set argument scope for. Every op in
- list_ops need to be decorated with @add_arg_scope to work.
- **kwargs: keyword=value that will define the defaults for each op in
- list_ops. All the ops need to accept the given set of arguments.
- Yields:
- the current_scope, which is a dictionary of {op: {arg: value}}
- Raises:
- TypeError: if list_ops is not a list or a tuple.
- ValueError: if any op in list_ops has not be decorated with @add_arg_scope.
- """
- if not isinstance(list_ops, (list, tuple)):
- raise TypeError("list_ops is not a list or a tuple")
- try:
- current_scope = _current_arg_scope().copy()
- for op in list_ops:
- key_op = (op.__module__, op.__name__)
- if not has_arg_scope(op):
- raise ValueError("%s is not decorated with @add_arg_scope", key_op)
- if key_op in current_scope:
- current_kwargs = current_scope[key_op].copy()
- current_kwargs.update(kwargs)
- current_scope[key_op] = current_kwargs
- else:
- current_scope[key_op] = kwargs.copy()
- _get_arg_stack().append(current_scope)
- yield current_scope
- finally:
- _get_arg_stack().pop()
- def add_arg_scope(func):
- """Decorates a function with args so it can be used within an arg_scope.
- Args:
- func: function to decorate.
- Returns:
- A tuple with the decorated function func_with_args().
- """
- @functools.wraps(func)
- def func_with_args(*args, **kwargs):
- current_scope = _current_arg_scope()
- current_args = kwargs
- key_func = (func.__module__, func.__name__)
- if key_func in current_scope:
- current_args = current_scope[key_func].copy()
- current_args.update(kwargs)
- return func(*args, **current_args)
- _add_op(func)
- return func_with_args
- def has_arg_scope(func):
- """Checks whether a func has been decorated with @add_arg_scope or not.
- Args:
- func: function to check.
- Returns:
- a boolean.
- """
- key_op = (func.__module__, func.__name__)
- return key_op in _DECORATED_OPS
|