decoder_test.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for decoder."""
  16. import os
  17. import tensorflow as tf
  18. import decoder
  def _testdata(filename):
    """Returns the path of a file in the test data directory.

    Args:
      filename: Name of the test data file.

    Returns:
      The result of joining '../testdata/' with filename.
    """
    testdata_dir = '../testdata/'
    return os.path.join(testdata_dir, filename)
  class DecoderTest(tf.test.TestCase):
    """Unit tests for the CTC decoding methods of decoder.Decoder."""

    def testCodesFromCTC(self):
      """Checks null removal and duplicate merging in _CodesFromCTC."""
      null = 9
      raw_labels = [9, 9, 9, 1, 9, 2, 2, 3, 9, 9, 0, 0, 1, 9, 1, 9, 9, 9]
      ctc_decoder = decoder.Decoder(filename=None)

      # With merge_dups off, only the null label is expected to disappear.
      without_nulls = ctc_decoder._CodesFromCTC(
          raw_labels, merge_dups=False, null_label=null)
      self.assertEqual(without_nulls, [1, 2, 2, 3, 0, 0, 1, 1])

      # A second pass without merging should leave the labels untouched.
      second_pass = ctc_decoder._CodesFromCTC(
          without_nulls, merge_dups=False, null_label=null)
      self.assertEqual(second_pass, without_nulls)

      # With merge_dups on, adjacent repeats collapse; the two trailing 1s
      # survive because a null separates them in the input.
      merged = ctc_decoder._CodesFromCTC(
          raw_labels, merge_dups=True, null_label=null)
      self.assertEqual(merged, [1, 2, 3, 0, 1, 1])

      # Merging again is not idempotent: the separating null is gone, so the
      # trailing 1s are now adjacent and collapse into one.
      merged_again = ctc_decoder._CodesFromCTC(
          merged, merge_dups=True, null_label=null)
      self.assertEqual(merged_again, [1, 2, 3, 0, 1])

    def testStringFromCTC(self):
      """Checks that StringFromCTC decodes labels that include multi-codes."""
      # Intended reading of the label sequence below:
      # - f - a r - m(1/2)m -junk sp b a r - n -
      raw_labels = [9, 6, 9, 1, 3, 9, 4, 9, 5, 5, 9, 5, 0, 2, 1, 3, 9, 4, 9]
      charset_path = _testdata('charset_size_10.txt')
      ctc_decoder = decoder.Decoder(filename=charset_path)
      decoded = ctc_decoder.StringFromCTC(
          raw_labels, merge_dups=True, null_label=9)
      self.assertEqual(decoded, 'farm barn')
  # Hand control to the TensorFlow test runner when executed as a script.
  if __name__ == '__main__':
    tf.test.main()