sentence_io_test.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import os
  16. import tensorflow as tf
  17. from tensorflow.python.framework import test_util
  18. from tensorflow.python.platform import googletest
  19. from dragnn.python import sentence_io
  20. from syntaxnet import sentence_pb2
  21. import syntaxnet.load_parser_ops
  22. FLAGS = tf.app.flags.FLAGS
  23. if not hasattr(FLAGS, 'test_srcdir'):
  24. FLAGS.test_srcdir = ''
  25. if not hasattr(FLAGS, 'test_tmpdir'):
  26. FLAGS.test_tmpdir = tf.test.get_temp_dir()
  27. class ConllSentenceReaderTest(test_util.TensorFlowTestCase):
  28. def setUp(self):
  29. # This dataset contains 54 sentences.
  30. self.filepath = os.path.join(
  31. FLAGS.test_srcdir,
  32. 'syntaxnet/testdata/mini-training-set')
  33. self.batch_size = 20
  34. def assertParseable(self, reader, expected_num, expected_last):
  35. sentences, last = reader.read()
  36. self.assertEqual(expected_num, len(sentences))
  37. self.assertEqual(expected_last, last)
  38. for s in sentences:
  39. pb = sentence_pb2.Sentence()
  40. pb.ParseFromString(s)
  41. self.assertGreater(len(pb.token), 0)
  42. def testReadFirstSentence(self):
  43. reader = sentence_io.ConllSentenceReader(self.filepath, 1)
  44. sentences, last = reader.read()
  45. self.assertEqual(1, len(sentences))
  46. pb = sentence_pb2.Sentence()
  47. pb.ParseFromString(sentences[0])
  48. self.assertFalse(last)
  49. self.assertEqual(
  50. u'I knew I could do it properly if given the right kind of support .',
  51. pb.text)
  52. def testReadFromTextFile(self):
  53. reader = sentence_io.ConllSentenceReader(self.filepath, self.batch_size)
  54. self.assertParseable(reader, self.batch_size, False)
  55. self.assertParseable(reader, self.batch_size, False)
  56. self.assertParseable(reader, 14, True)
  57. self.assertParseable(reader, 0, True)
  58. self.assertParseable(reader, 0, True)
  59. def testReadAndProjectivize(self):
  60. reader = sentence_io.ConllSentenceReader(
  61. self.filepath, self.batch_size, projectivize=True)
  62. self.assertParseable(reader, self.batch_size, False)
  63. self.assertParseable(reader, self.batch_size, False)
  64. self.assertParseable(reader, 14, True)
  65. self.assertParseable(reader, 0, True)
  66. self.assertParseable(reader, 0, True)
  67. if __name__ == '__main__':
  68. googletest.main()