scopes_test.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests slim.scopes."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from inception.slim import scopes
  21. @scopes.add_arg_scope
  22. def func1(*args, **kwargs):
  23. return (args, kwargs)
  24. @scopes.add_arg_scope
  25. def func2(*args, **kwargs):
  26. return (args, kwargs)
  27. class ArgScopeTest(tf.test.TestCase):
  28. def testEmptyArgScope(self):
  29. with self.test_session():
  30. self.assertEqual(scopes._current_arg_scope(), {})
  31. def testSimpleArgScope(self):
  32. func1_args = (0,)
  33. func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
  34. with self.test_session():
  35. with scopes.arg_scope([func1], a=1, b=None, c=[1]):
  36. args, kwargs = func1(0)
  37. self.assertTupleEqual(args, func1_args)
  38. self.assertDictEqual(kwargs, func1_kwargs)
  39. def testSimpleArgScopeWithTuple(self):
  40. func1_args = (0,)
  41. func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
  42. with self.test_session():
  43. with scopes.arg_scope((func1,), a=1, b=None, c=[1]):
  44. args, kwargs = func1(0)
  45. self.assertTupleEqual(args, func1_args)
  46. self.assertDictEqual(kwargs, func1_kwargs)
  47. def testOverwriteArgScope(self):
  48. func1_args = (0,)
  49. func1_kwargs = {'a': 1, 'b': 2, 'c': [1]}
  50. with scopes.arg_scope([func1], a=1, b=None, c=[1]):
  51. args, kwargs = func1(0, b=2)
  52. self.assertTupleEqual(args, func1_args)
  53. self.assertDictEqual(kwargs, func1_kwargs)
  54. def testNestedArgScope(self):
  55. func1_args = (0,)
  56. func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
  57. with scopes.arg_scope([func1], a=1, b=None, c=[1]):
  58. args, kwargs = func1(0)
  59. self.assertTupleEqual(args, func1_args)
  60. self.assertDictEqual(kwargs, func1_kwargs)
  61. func1_kwargs['b'] = 2
  62. with scopes.arg_scope([func1], b=2):
  63. args, kwargs = func1(0)
  64. self.assertTupleEqual(args, func1_args)
  65. self.assertDictEqual(kwargs, func1_kwargs)
  66. def testSharedArgScope(self):
  67. func1_args = (0,)
  68. func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
  69. with scopes.arg_scope([func1, func2], a=1, b=None, c=[1]):
  70. args, kwargs = func1(0)
  71. self.assertTupleEqual(args, func1_args)
  72. self.assertDictEqual(kwargs, func1_kwargs)
  73. args, kwargs = func2(0)
  74. self.assertTupleEqual(args, func1_args)
  75. self.assertDictEqual(kwargs, func1_kwargs)
  76. def testSharedArgScopeTuple(self):
  77. func1_args = (0,)
  78. func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
  79. with scopes.arg_scope((func1, func2), a=1, b=None, c=[1]):
  80. args, kwargs = func1(0)
  81. self.assertTupleEqual(args, func1_args)
  82. self.assertDictEqual(kwargs, func1_kwargs)
  83. args, kwargs = func2(0)
  84. self.assertTupleEqual(args, func1_args)
  85. self.assertDictEqual(kwargs, func1_kwargs)
  86. def testPartiallySharedArgScope(self):
  87. func1_args = (0,)
  88. func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
  89. func2_args = (1,)
  90. func2_kwargs = {'a': 1, 'b': None, 'd': [2]}
  91. with scopes.arg_scope([func1, func2], a=1, b=None):
  92. with scopes.arg_scope([func1], c=[1]), scopes.arg_scope([func2], d=[2]):
  93. args, kwargs = func1(0)
  94. self.assertTupleEqual(args, func1_args)
  95. self.assertDictEqual(kwargs, func1_kwargs)
  96. args, kwargs = func2(1)
  97. self.assertTupleEqual(args, func2_args)
  98. self.assertDictEqual(kwargs, func2_kwargs)
  99. if __name__ == '__main__':
  100. tf.test.main()