# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slim.variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from inception.slim import scopes
from inception.slim import variables


class VariablesTest(tf.test.TestCase):
  """Unit tests for `variables.variable` and the variable lookup helpers.

  Uses `assertEqual` throughout; the `assertEquals` alias is deprecated and
  was removed from `unittest` in Python 3.12.
  """

  def testCreateVariable(self):
    """A variable created in scope 'A' is named 'A/a' with the given shape."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = variables.variable('a', [5])
        self.assertEqual(a.op.name, 'A/a')
        self.assertListEqual(a.get_shape().as_list(), [5])

  def testGetVariableGivenName(self):
    """get_variable_given_name returns the short name regardless of scope."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = variables.variable('a', [5])
      with tf.variable_scope('B'):
        b = variables.variable('a', [5])
      self.assertEqual('a', variables.get_variable_given_name(a))
      self.assertEqual('a', variables.get_variable_given_name(b))

  def testGetVariableGivenNameScoped(self):
    """get_variables_by_name selects variables by their short name."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = variables.variable('a', [5])
        b = variables.variable('b', [5])
        self.assertEqual([a], variables.get_variables_by_name('a'))
        self.assertEqual([b], variables.get_variables_by_name('b'))

  def testGetVariables(self):
    """get_variables(scope) returns only the variables in that scope."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = variables.variable('a', [5])
      with tf.variable_scope('B'):
        b = variables.variable('a', [5])
      self.assertEqual([a], variables.get_variables('A'))
      self.assertEqual([b], variables.get_variables('B'))

  def testGetVariablesSuffix(self):
    """get_variables(suffix=...) filters variables by name suffix."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = variables.variable('a', [5])
      with tf.variable_scope('A'):
        b = variables.variable('b', [5])
      self.assertEqual([a], variables.get_variables(suffix='a'))
      self.assertEqual([b], variables.get_variables(suffix='b'))

  def testGetVariableWithSingleVar(self):
    """get_unique_variable finds a variable by its full scoped name."""
    with self.test_session():
      with tf.variable_scope('parent'):
        a = variables.variable('child', [5])
      self.assertEqual(a, variables.get_unique_variable('parent/child'))

  def testGetVariableWithDistractors(self):
    """get_unique_variable ignores variables nested under the same prefix."""
    with self.test_session():
      with tf.variable_scope('parent'):
        a = variables.variable('child', [5])
        with tf.variable_scope('child'):
          variables.variable('grandchild1', [7])
          variables.variable('grandchild2', [9])
      self.assertEqual(a, variables.get_unique_variable('parent/child'))

  def testGetVariableThrowsExceptionWithNoMatch(self):
    """get_unique_variable raises ValueError for an unknown name."""
    var_name = 'cant_find_me'
    with self.test_session():
      with self.assertRaises(ValueError):
        variables.get_unique_variable(var_name)

  def testGetThrowsExceptionWithChildrenButNoMatch(self):
    """A scope with children but no exact-name variable raises ValueError."""
    var_name = 'parent/child'
    with self.test_session():
      with tf.variable_scope(var_name):
        variables.variable('grandchild1', [7])
        variables.variable('grandchild2', [9])
      with self.assertRaises(ValueError):
        variables.get_unique_variable(var_name)

  def testGetVariablesToRestore(self):
    """Variables created with default arguments land in VARIABLES_TO_RESTORE."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = variables.variable('a', [5])
      with tf.variable_scope('B'):
        b = variables.variable('b', [5])
      self.assertListEqual([a, b],
                           tf.get_collection(variables.VARIABLES_TO_RESTORE))

  def testGetVariablesToRestorePartial(self):
    """restore=False keeps a variable out of VARIABLES_TO_RESTORE only."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = variables.variable('a', [5])
      with tf.variable_scope('B'):
        b = variables.variable('b', [5], restore=False)
      self.assertListEqual([a, b], variables.get_variables())
      self.assertListEqual([a],
                           tf.get_collection(variables.VARIABLES_TO_RESTORE))

  def testReuseVariable(self):
    """Re-entering a scope with reuse=True returns the existing variable."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = variables.variable('a', [])
      with tf.variable_scope('A', reuse=True):
        b = variables.variable('a', [])
      self.assertEqual(a, b)
      self.assertListEqual([a], variables.get_variables())

  def testVariableWithDevice(self):
    """The explicit `device` argument is applied to each variable."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = variables.variable('a', [], device='cpu:0')
        b = variables.variable('b', [], device='cpu:1')
      self.assertDeviceEqual(a.device, 'cpu:0')
      self.assertDeviceEqual(b.device, 'cpu:1')

  def testVariableWithDeviceFromScope(self):
    """A tf.device scope applies unless overridden by `device`."""
    with self.test_session():
      with tf.device('/cpu:0'):
        a = variables.variable('a', [])
        b = variables.variable('b', [], device='cpu:1')
      self.assertDeviceEqual(a.device, 'cpu:0')
      self.assertDeviceEqual(b.device, 'cpu:1')

  def testVariableCollection(self):
    """A single collection name given as a string is honoured."""
    with self.test_session():
      a = variables.variable('a', [], collections='A')
      b = variables.variable('b', [], collections='B')
      self.assertEqual(a, tf.get_collection('A')[0])
      self.assertEqual(b, tf.get_collection('B')[0])

  def testVariableCollections(self):
    """A list of collection names is honoured."""
    with self.test_session():
      a = variables.variable('a', [], collections=['A', 'C'])
      b = variables.variable('b', [], collections=['B', 'C'])
      self.assertEqual(a, tf.get_collection('A')[0])
      self.assertEqual(b, tf.get_collection('B')[0])

  def testVariableCollectionsWithArgScope(self):
    """arg_scope supplies the `collections` argument to variables.variable."""
    with self.test_session():
      with scopes.arg_scope([variables.variable], collections='A'):
        a = variables.variable('a', [])
        b = variables.variable('b', [])
      self.assertListEqual([a, b], tf.get_collection('A'))

  def testVariableCollectionsWithArgScopeNested(self):
    """An inner arg_scope's `collections` applies to variables created in it."""
    with self.test_session():
      with scopes.arg_scope([variables.variable], collections='A'):
        a = variables.variable('a', [])
        with scopes.arg_scope([variables.variable], collections='B'):
          b = variables.variable('b', [])
      self.assertEqual(a, tf.get_collection('A')[0])
      self.assertEqual(b, tf.get_collection('B')[0])

  def testVariableCollectionsWithArgScopeNonNested(self):
    """arg_scope defaults do not leak past the end of the `with` block."""
    with self.test_session():
      with scopes.arg_scope([variables.variable], collections='A'):
        a = variables.variable('a', [])
      with scopes.arg_scope([variables.variable], collections='B'):
        b = variables.variable('b', [])
      variables.variable('c', [])
      self.assertListEqual([a], tf.get_collection('A'))
      self.assertListEqual([b], tf.get_collection('B'))

  def testVariableRestoreWithArgScopeNested(self):
    """Nested arg_scopes combine restore, trainable and collections settings."""
    with self.test_session():
      with scopes.arg_scope([variables.variable], restore=True):
        a = variables.variable('a', [])
        with scopes.arg_scope([variables.variable],
                              trainable=False,
                              collections=['A', 'B']):
          b = variables.variable('b', [])
        c = variables.variable('c', [])
      self.assertListEqual([a, b, c],
                           tf.get_collection(variables.VARIABLES_TO_RESTORE))
      self.assertListEqual([a, c], tf.trainable_variables())
      self.assertListEqual([b], tf.get_collection('A'))
      self.assertListEqual([b], tf.get_collection('B'))


if __name__ == '__main__':
  tf.test.main()