123456789101112131415161718192021222324252627282930313233343536373839404142 |
- # Copyright 2016 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Tests for rmsprop."""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import tensorflow as tf
- import gamma_mixture_model
# Module-level handle to TensorFlow's command-line flags object.
# NOTE(review): FLAGS is not referenced elsewhere in this file; presumably kept
# for flags defined by imported modules -- confirm before removing.
# NOTE(review): `tf.flags` is a TF 1.x API (absent in TF 2.x); switch to
# `absl.flags` if this test is migrated.
FLAGS = tf.flags.FLAGS
class RMSPropTest(tf.test.TestCase):
    """Compares Gamma mixture model training across two RMSProp settings."""

    def testGammaMixtureModel(self):
        """Train the Gamma mixture model under each setting and compare results.

        The model is trained first with 'rmsprop_manual' and then with
        'rmsprop_indexed_slices'. The two returned values must agree to within
        rtol=0.2 and atol=0.2.

        NOTE(review): the tolerance is loose, presumably because training is
        stochastic and/or approximate -- confirm against gamma_mixture_model.
        """
        # Order matters: each train() call runs in sequence, manual first.
        variants = ('rmsprop_manual', 'rmsprop_indexed_slices')
        estimates = {}
        for variant in variants:
            estimates[variant] = gamma_mixture_model.train(variant)

        reference = estimates['rmsprop_manual']
        candidate = estimates['rmsprop_indexed_slices']
        self.assertAllClose(candidate, reference, rtol=2e-1, atol=2e-1)
# Script entry point: run the test cases in this module with TensorFlow's
# test runner.
if __name__ == '__main__':
    tf.test.main()
|