text_formats_test.py

# coding=utf-8
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for english_tokenizer."""

# pylint: disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member
import os.path

import tensorflow as tf

import syntaxnet.load_parser_ops

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging

from syntaxnet import sentence_pb2
from syntaxnet import task_spec_pb2
from syntaxnet.ops import gen_parser_ops

FLAGS = tf.app.flags.FLAGS


class TextFormatsTest(test_util.TensorFlowTestCase):

  def setUp(self):
    if not hasattr(FLAGS, 'test_srcdir'):
      FLAGS.test_srcdir = ''
    if not hasattr(FLAGS, 'test_tmpdir'):
      FLAGS.test_tmpdir = tf.test.get_temp_dir()
    self.corpus_file = os.path.join(FLAGS.test_tmpdir, 'documents.conll')
    self.context_file = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')

  def AddInput(self, name, file_pattern, record_format, context):
    """Adds an input with the given file pattern and format to the context."""
    inp = context.input.add()
    inp.name = name
    inp.record_format.append(record_format)
    inp.part.add().file_pattern = file_pattern

  def AddParameter(self, name, value, context):
    """Adds a named parameter to the context."""
    param = context.parameter.add()
    param.name = name
    param.value = value

  def WriteContext(self, corpus_format):
    """Writes a task context for corpus_format to self.context_file."""
    context = task_spec_pb2.TaskSpec()
    self.AddInput('documents', self.corpus_file, corpus_format, context)
    for name in ('word-map', 'lcword-map', 'tag-map',
                 'category-map', 'label-map', 'prefix-table',
                 'suffix-table', 'tag-to-category'):
      self.AddInput(name, os.path.join(FLAGS.test_tmpdir, name), '', context)
    logging.info('Writing context to: %s', self.context_file)
    with open(self.context_file, 'w') as f:
      f.write(str(context))

  def ReadNextDocument(self, sess, sentence):
    """Reads the next sentence, or returns None at end of input."""
    sentence_str, = sess.run([sentence])
    if sentence_str:
      sentence_doc = sentence_pb2.Sentence()
      sentence_doc.ParseFromString(sentence_str[0])
    else:
      sentence_doc = None
    return sentence_doc

  def CheckTokenization(self, sentence, tokenization):
    """Checks that the english-text reader tokenizes sentence as expected."""
    self.WriteContext('english-text')
    logging.info('Writing text file to: %s', self.corpus_file)
    with open(self.corpus_file, 'w') as f:
      f.write(sentence)
    sentence, _ = gen_parser_ops.document_source(
        self.context_file, batch_size=1)
    with self.test_session() as sess:
      sentence_doc = self.ReadNextDocument(sess, sentence)
      self.assertEqual(' '.join([t.word for t in sentence_doc.token]),
                       tokenization)

  def CheckUntokenizedDoc(self, sentence, words, starts, ends):
    """Checks character tokens and byte offsets for untokenized text."""
    self.WriteContext('untokenized-text')
    logging.info('Writing text file to: %s', self.corpus_file)
    with open(self.corpus_file, 'w') as f:
      f.write(sentence)
    sentence, _ = gen_parser_ops.document_source(
        self.context_file, batch_size=1)
    with self.test_session() as sess:
      sentence_doc = self.ReadNextDocument(sess, sentence)
      self.assertEqual(len(sentence_doc.token), len(words))
      self.assertEqual(len(sentence_doc.token), len(starts))
      self.assertEqual(len(sentence_doc.token), len(ends))
      for i, token in enumerate(sentence_doc.token):
        self.assertEqual(token.word.encode('utf-8'), words[i])
        self.assertEqual(token.start, starts[i])
        self.assertEqual(token.end, ends[i])

  def testUntokenized(self):
    self.CheckUntokenizedDoc('一个测试', ['一', '个', '测', '试'],
                             [0, 3, 6, 9], [2, 5, 8, 11])
    self.CheckUntokenizedDoc('Hello ', ['H', 'e', 'l', 'l', 'o', ' '],
                             [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5])

  def testSegmentationTrainingData(self):
    doc1_lines = ['测试 NO_SPACE\n',
                  '的 NO_SPACE\n',
                  '句子 NO_SPACE']
    doc1_text = '测试的句子'
    doc1_tokens = ['测', '试', '的', '句', '子']
    doc1_break_levels = [1, 0, 1, 1, 0]
    doc2_lines = ['That NO_SPACE\n',
                  '\'s SPACE\n',
                  'a SPACE\n',
                  'good SPACE\n',
                  'point NO_SPACE\n',
                  '. NO_SPACE']
    doc2_text = 'That\'s a good point.'
    doc2_tokens = ['T', 'h', 'a', 't', '\'', 's', ' ', 'a', ' ', 'g', 'o', 'o',
                   'd', ' ', 'p', 'o', 'i', 'n', 't', '.']
    doc2_break_levels = [1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0,
                         0, 1]
    self.CheckSegmentationTrainingData(doc1_lines, doc1_text, doc1_tokens,
                                       doc1_break_levels)
    self.CheckSegmentationTrainingData(doc2_lines, doc2_text, doc2_tokens,
                                       doc2_break_levels)

  def CheckSegmentationTrainingData(self, doc_lines, doc_text, doc_words,
                                    break_levels):
    """Checks the document converted from segmentation training data."""
    # Prepare context.
    self.WriteContext('segment-train-data')
    # Prepare test sentence.
    with open(self.corpus_file, 'w') as f:
      f.write(''.join(doc_lines))
    # Test converted sentence.
    sentence, _ = gen_parser_ops.document_source(
        self.context_file, batch_size=1)
    with self.test_session() as sess:
      sentence_doc = self.ReadNextDocument(sess, sentence)
      self.assertEqual(doc_text.decode('utf-8'), sentence_doc.text)
      self.assertEqual([t.decode('utf-8') for t in doc_words],
                       [t.word for t in sentence_doc.token])
      self.assertEqual(break_levels,
                       [t.break_level for t in sentence_doc.token])

  def testSimple(self):
    self.CheckTokenization('Hello, world!', 'Hello , world !')
    self.CheckTokenization('"Hello"', "`` Hello ''")
    self.CheckTokenization('{"Hello@#$', '-LRB- `` Hello @ # $')
    self.CheckTokenization('"Hello..."', "`` Hello ... ''")
    self.CheckTokenization('()[]{}<>',
                           '-LRB- -RRB- -LRB- -RRB- -LRB- -RRB- < >')
    self.CheckTokenization('Hello--world', 'Hello -- world')
    self.CheckTokenization("Isn't", "Is n't")
    self.CheckTokenization("n't", "n't")
    self.CheckTokenization('Hello Mr. Smith.', 'Hello Mr. Smith .')
    self.CheckTokenization("It's Mr. Smith's.", "It 's Mr. Smith 's .")
    self.CheckTokenization("It's the Smiths'.", "It 's the Smiths ' .")
    self.CheckTokenization('Gotta go', 'Got ta go')
    self.CheckTokenization('50-year-old', '50-year-old')

  def testUrl(self):
    self.CheckTokenization('http://www.google.com/news is down',
                           'http : //www.google.com/news is down')


if __name__ == '__main__':
  googletest.main()
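
# Note: one way to exercise this test from a SyntaxNet source checkout is via
# bazel, as used elsewhere in the repository; the exact target name below is
# an assumption based on this file's name and is not confirmed by the file
# itself:
#
#   bazel test //syntaxnet:text_formats_test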