# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Construct the spec for the CONLL2017 Parser baseline."""
  16. import tensorflow as tf
  17. from tensorflow.python.platform import gfile
  18. from dragnn.protos import spec_pb2
  19. from dragnn.python import spec_builder
  20. flags = tf.app.flags
  21. FLAGS = flags.FLAGS
  22. flags.DEFINE_string('spec_file', 'parser_spec.textproto',
  23. 'Filename to save the spec to.')
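

# This script assembles a four-component DRAGNN pipeline and serializes it as
# a MasterSpec: a character LSTM builds word representations, a right-to-left
# "lookahead" LSTM summarizes each word's right context, an LSTM tagger runs
# over the lookahead states, and an arc-standard parser links to both the
# lookahead and tagger layers.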
def main(unused_argv):
  # Left-to-right, character-based LSTM.
  char2word = spec_builder.ComponentSpecBuilder('char_lstm')
  char2word.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork',
      hidden_layer_sizes='256')
  char2word.set_transition_system(name='char-shift-only', left_to_right='true')
  char2word.add_fixed_feature(name='chars', fml='char-input.text-char',
                              embedding_dim=16)

  # Lookahead LSTM reads right-to-left to represent the rightmost context of
  # the words. It gets word embeddings from the char model.
  lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  lookahead.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork',
      hidden_layer_sizes='256')
  lookahead.set_transition_system(name='shift-only', left_to_right='false')
  lookahead.add_link(source=char2word, fml='input.last-char-focus',
                     embedding_dim=64)

  # Construct the tagger. This is a simple left-to-right LSTM sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork',
      hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_token_link(source=lookahead, fml='input.focus', embedding_dim=64)

  # Construct the parser.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256',
                          layer_norm_hidden='true')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(source=lookahead, fml='input.focus', embedding_dim=64)
  parser.add_token_link(
      source=tagger, fml='input.focus stack.focus stack(1).focus',
      embedding_dim=64)

  # Add discrete features of the predicted parse tree so far, like in Parsey
  # McParseface.
  parser.add_fixed_feature(name='labels', embedding_dim=16,
                           fml=' '.join([
                               'stack.child(1).label',
                               'stack.child(1).sibling(-1).label',
                               'stack.child(-1).label',
                               'stack.child(-1).sibling(1).label',
                               'stack(1).child(1).label',
                               'stack(1).child(1).sibling(-1).label',
                               'stack(1).child(-1).label',
                               'stack(1).child(-1).sibling(1).label',
                               'stack.child(2).label',
                               'stack.child(-2).label',
                               'stack(1).child(2).label',
                               'stack(1).child(-2).label']))
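  # Each FML pattern above reads the dependency label of a child (or a child's
  # sibling) of one of the top two stack tokens in the partial tree; by
  # convention, positive child indices count from the left edge and negative
  # ones from the right.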

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step to either SHIFT or REDUCE that
  # token. This allows the parser to build up compositional representations of
  # phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=64)  # project down to 64 dims

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend(
      [char2word.spec, lookahead.spec, tagger.spec, parser.spec])

  with gfile.FastGFile(FLAGS.spec_file, 'w') as f:
    f.write(str(master_spec).encode('utf-8'))
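  # str(master_spec) emits protobuf text format, so the saved spec can be read
  # back into a MasterSpec message later. A minimal sketch, assuming the
  # standard protobuf text_format module (not part of the original script):
  #   from google.protobuf import text_format
  #   with gfile.FastGFile(FLAGS.spec_file) as f:
  #     text_format.Merge(f.read(), spec_pb2.MasterSpec())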


if __name__ == '__main__':
  tf.app.run()