{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true, "scrolled": false }, "outputs": [], "source": [ "import os\n", "import os.path\n", "import random\n", "import time\n", "import tensorflow as tf\n", "\n", "from IPython.display import HTML\n", "from tensorflow.python.platform import gfile\n", "from tensorflow.python.platform import tf_logging as logging\n", "\n", "from google.protobuf import text_format\n", "\n", "from syntaxnet.ops import gen_parser_ops\n", "from syntaxnet import load_parser_ops # This loads the actual op definitions\n", "from syntaxnet import task_spec_pb2\n", "from syntaxnet import sentence_pb2\n", "\n", "from dragnn.protos import spec_pb2\n", "from dragnn.python.sentence_io import ConllSentenceReader\n", "\n", "from dragnn.python import evaluation\n", "from dragnn.python import graph_builder\n", "from dragnn.python import lexicon\n", "from dragnn.python import load_dragnn_cc_impl\n", "from dragnn.python import render_parse_tree_graphviz\n", "from dragnn.python import render_spec_with_graphviz\n", "from dragnn.python import spec_builder\n", "from dragnn.python import trainer_lib\n", "from dragnn.python import visualization\n", "\n", "DATA_DIR = '/opt/tensorflow/syntaxnet/examples/dragnn/data/es'\n", "TENSORBOARD_DIR = '/notebooks/tensorboard'\n", "CHECKPOINT_FILENAME = '{}/spanish.checkpoint'.format(DATA_DIR)\n", "TRAINING_CORPUS_PATH = '{}/es-universal-train.conll'.format(DATA_DIR)\n", "DEV_CORPUS_PATH = '{}/es-universal-dev.conll'.format(DATA_DIR)\n", "\n", "# Some of the IO functions fail miserably if data is missing.\n", "assert os.path.isfile(TRAINING_CORPUS_PATH), 'Could not find training corpus'\n", "\n", "logging.set_verbosity(logging.WARN)\n", "\n", "# Constructs lexical resources for SyntaxNet in the given resource path, from\n", "# the training data.\n", "lexicon.build_lexicon(DATA_DIR, TRAINING_CORPUS_PATH)\n", "\n", "# Construct the 'lookahead' ComponentSpec. This is a simple right-to-left RNN\n", "# sequence model, which encodes the context to the right of each token. It has\n", "# no loss except for the downstream components.\n", "lookahead = spec_builder.ComponentSpecBuilder('lookahead')\n", "lookahead.set_network_unit(\n", " name='FeedForwardNetwork', hidden_layer_sizes='256')\n", "lookahead.set_transition_system(name='shift-only', left_to_right='true')\n", "lookahead.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)\n", "lookahead.add_rnn_link(embedding_dim=-1)\n", "lookahead.fill_from_resources(DATA_DIR)\n", "\n", "# Construct the ComponentSpec for tagging. This is a simple left-to-right RNN\n", "# sequence tagger.\n", "tagger = spec_builder.ComponentSpecBuilder('tagger')\n", "tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')\n", "tagger.set_transition_system(name='tagger')\n", "tagger.add_rnn_link(embedding_dim=-1)\n", "tagger.add_token_link(source=lookahead, fml='input.focus', embedding_dim=32)\n", "tagger.fill_from_resources(DATA_DIR)\n", "\n", "# Construct the ComponentSpec for parsing.\n", "parser = spec_builder.ComponentSpecBuilder('parser')\n", "parser.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')\n", "parser.set_transition_system(name='arc-standard')\n", "parser.add_token_link(source=lookahead, fml='input.focus', embedding_dim=32)\n", "parser.add_token_link(\n", " source=tagger,\n", " fml='input.focus stack.focus stack(1).focus',\n", " embedding_dim=32)\n", "\n", "# Recurrent connection for the arc-standard parser. For both tokens on the\n", "# stack, we connect to the last time step to either SHIFT or REDUCE that\n", "# token. This allows the parser to build up compositional representations of\n", "# phrases.\n", "parser.add_link(\n", " source=parser, # recurrent connection\n", " name='rnn-stack', # unique identifier\n", " fml='stack.focus stack(1).focus', # look for both stack tokens\n", " source_translator='shift-reduce-step', # maps token indices -> step\n", " embedding_dim=32) # project down to 32 dims\n", "\n", "parser.fill_from_resources(DATA_DIR)\n", "\n", "master_spec = spec_pb2.MasterSpec()\n", "master_spec.component.extend([lookahead.spec, tagger.spec, parser.spec])\n", "HTML(render_spec_with_graphviz.master_spec_graph(master_spec))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "# Build the TensorFlow graph based on the DRAGNN network spec.\n", "graph = tf.Graph()\n", "with graph.as_default():\n", " hyperparam_config = spec_pb2.GridPoint(\n", " learning_method='adam',\n", " learning_rate=0.0005, \n", " adam_beta1=0.9, adam_beta2=0.9, adam_eps=0.001,\n", " dropout_rate=0.8, gradient_clip_norm=1,\n", " use_moving_average=True,\n", " seed=1)\n", " builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)\n", " target = spec_pb2.TrainTarget(\n", " name='all',\n", " unroll_using_oracle=[False, True, True], # train tagger & parser on gold unrolling, skip lookahead\n", " component_weights=[0, 0.5, 0.5]) # tagger and parser losses have equal weights\n", " trainer = builder.add_training_from_config(target)\n", " annotator = builder.add_annotation(enable_tracing=True)\n", " builder.add_saver()\n", "\n", "# Train on Spanish data for N_STEPS steps and evaluate.\n", "N_STEPS = 20\n", "BATCH_SIZE = 64\n", "with tf.Session(graph=graph) as sess:\n", " sess.run(tf.global_variables_initializer())\n", " training_corpus = ConllSentenceReader(\n", " TRAINING_CORPUS_PATH, projectivize=True).corpus()\n", " dev_corpus = ConllSentenceReader(DEV_CORPUS_PATH).corpus()[:200]\n", " for step in xrange(N_STEPS):\n", " trainer_lib.run_training_step(sess, trainer, training_corpus, batch_size=BATCH_SIZE)\n", " tf.logging.warning('Step %d/%d', step + 1, N_STEPS)\n", " parsed_dev_corpus = trainer_lib.annotate_dataset(sess, annotator, dev_corpus)\n", " pos, uas, las = evaluation.calculate_parse_metrics(dev_corpus, parsed_dev_corpus)\n", " tf.logging.warning('POS %.2f UAS %.2f LAS %.2f', pos, uas, las)\n", " builder.saver.save(sess, CHECKPOINT_FILENAME)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "# Visualize the output of our mini-trained model on a test sentence.\n", "\n", "text = '¿ Viste ese gran coche rojo ?'\n", "tokens = [sentence_pb2.Token(word=word, start=-1, end=-1) for word in text.split()]\n", "sentence = sentence_pb2.Sentence()\n", "sentence.token.extend(tokens)\n", "\n", "with tf.Session(graph=graph) as sess:\n", " # Restore the model we just trained.\n", " builder.saver.restore(sess, CHECKPOINT_FILENAME)\n", " annotations, traces = sess.run([annotator['annotations'], annotator['traces']],\n", " feed_dict={annotator['input_batch']: [sentence.SerializeToString()]})\n", "\n", "HTML(visualization.trace_html(traces[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "parsed_sentence = sentence_pb2.Sentence.FromString(annotations[0])\n", "HTML(render_parse_tree_graphviz.parse_tree_graph(parsed_sentence))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.9" } }, "nbformat": 4, "nbformat_minor": 2 }