{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import ipywidgets as widgets\n",
    "import tensorflow as tf\n",
    "from IPython import display\n",
    "from dragnn.protos import spec_pb2\n",
    "from dragnn.python import graph_builder\n",
    "from dragnn.python import spec_builder\n",
    "from dragnn.python import load_dragnn_cc_impl  # This loads the actual op definitions\n",
    "from dragnn.python import render_parse_tree_graphviz\n",
    "from dragnn.python import visualization\n",
    "from google.protobuf import text_format\n",
    "from syntaxnet import load_parser_ops  # This loads the actual op definitions\n",
    "from syntaxnet import sentence_pb2\n",
    "from syntaxnet.ops import gen_parser_ops\n",
    "from tensorflow.python.platform import tf_logging as logging\n",
    "\n",
    "def load_model(base_dir, master_spec_name, checkpoint_name):\n",
    "    \"\"\"Builds an annotation graph from a master spec and restores a checkpoint.\n",
    "\n",
    "    As a side effect, TensorFlow logging verbosity is lowered to WARN.\n",
    "\n",
    "    Args:\n",
    "      base_dir: Directory that holds both the master spec and the checkpoint;\n",
    "        it is also handed to spec_builder.complete_master_spec.\n",
    "      master_spec_name: File name, relative to base_dir, of the text-format\n",
    "        MasterSpec.\n",
    "      checkpoint_name: File name, relative to base_dir, of the checkpoint to\n",
    "        restore.\n",
    "\n",
    "    Returns:\n",
    "      A function annotate_sentence(sentence) that feeds `sentence` as a\n",
    "      one-element batch and returns the list [annotations, traces] fetched\n",
    "      from the annotator.\n",
    "    \"\"\"\n",
    "    # Read the master spec\n",
    "    master_spec = spec_pb2.MasterSpec()\n",
    "    with open(os.path.join(base_dir, master_spec_name), \"r\") as f:\n",
    "        text_format.Merge(f.read(), master_spec)\n",
    "    spec_builder.complete_master_spec(master_spec, None, base_dir)\n",
    "    logging.set_verbosity(logging.WARN)  # Turn off TensorFlow spam.\n",
    "\n",
    "    # Initialize a graph\n",
    "    graph = tf.Graph()\n",
    "    with graph.as_default():\n",
    "        hyperparam_config = spec_pb2.GridPoint()\n",
    "        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)\n",
    "        # This is the component that will annotate test sentences.\n",
    "        annotator = builder.add_annotation(enable_tracing=True)\n",
    "        builder.add_saver()  # \"Savers\" can save and load models; here, we're only going to load.\n",
    "\n",
    "    # The session is deliberately left open: the returned closure keeps using it.\n",
    "    sess = tf.Session(graph=graph)\n",
    "    with graph.as_default():\n",
    "        builder.saver.restore(sess, os.path.join(base_dir, checkpoint_name))\n",
    "\n",
    "    def annotate_sentence(sentence):\n",
    "        \"\"\"Runs the annotator on a batch containing only `sentence`.\"\"\"\n",
    "        with graph.as_default():\n",
    "            return sess.run([annotator['annotations'], annotator['traces']],\n",
    "                            feed_dict={annotator['input_batch']: [sentence]})\n",
    "    return annotate_sentence\n",
    "\n",
    "segmenter_model = load_model(\"data/en/segmenter\", \"spec.textproto\", \"checkpoint\")\n",
    "parser_model = load_model(\"data/en\", \"parser_spec.textproto\", \"checkpoint\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def annotate_text(text):\n",
    "    \"\"\"Runs the segmenter model and then the parser model on `text`.\n",
    "\n",
    "    Args:\n",
    "      text: The text to analyze.\n",
    "\n",
    "    Returns:\n",
    "      A (sentence, trace) tuple: the sentence_pb2.Sentence decoded from the\n",
    "      parser's single annotation, and the parser's single trace.\n",
    "    \"\"\"\n",
    "    # Wrap the whole text in one Sentence holding one placeholder token.\n",
    "    sentence = sentence_pb2.Sentence(\n",
    "        text=text,\n",
    "        token=[sentence_pb2.Token(word=text, start=-1, end=-1)]\n",
    "    )\n",
    "\n",
    "    # preprocess: run char_token_generator in a throwaway graph and session\n",
    "    with tf.Session(graph=tf.Graph()) as tmp_session:\n",
    "        char_input = gen_parser_ops.char_token_generator([sentence.SerializeToString()])\n",
    "        preprocessed = tmp_session.run(char_input)[0]\n",
    "    segmented, _ = segmenter_model(preprocessed)  # the segmenter's traces are not used\n",
    "\n",
    "    annotations, traces = parser_model(segmented[0])\n",
    "    # One sentence went in, so exactly one annotation and one trace must come out.\n",
    "    assert len(annotations) == 1\n",
    "    assert len(traces) == 1\n",
    "    return sentence_pb2.Sentence.FromString(annotations[0]), traces[0]\n",
    "annotate_text(\"John is eating pizza with a fork\"); None  # just make sure it works"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# Interactive parse tree explorer\n",
    "Run the cell below, and then enter text in the interactive widget."
] }, {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def _parse_tree_explorer():  # put stuff in a function to not pollute global scope\n",
    "    \"\"\"Displays a text box and renders the parse tree of each submitted text.\"\"\"\n",
    "    text = widgets.Text(\"John is eating pizza with anchovies\")\n",
    "    # Also try: John is eating pizza with a fork\n",
    "    display.display(text)\n",
    "    html = widgets.HTML()\n",
    "    display.display(html)\n",
    "\n",
    "    def handle_submit(sender):\n",
    "        \"\"\"Re-parses the current text and replaces the HTML widget's content.\"\"\"\n",
    "        del sender  # unused\n",
    "        parse_tree, _ = annotate_text(text.value)  # the trace is not needed here\n",
    "        # Same characters as the original triple-quoted template, written on one\n",
    "        # line so that no JSON string in the notebook holds a raw line break.\n",
    "        html.value = u\"\\n\\n{}\\n\\n \\n \".format(\n",
    "            render_parse_tree_graphviz.parse_tree_graph(parse_tree))\n",
    "\n",
    "    text.on_submit(handle_submit)\n",
    "_parse_tree_explorer()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# Interactive trace explorer\n",
    "Run the cell below, and then enter text in the interactive widget."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def _trace_explorer():  # put stuff in a function to not pollute global scope\n",
    "    \"\"\"Displays a text box and visualizes the parser trace of each submitted text.\"\"\"\n",
    "    text = widgets.Text(\"John is eating pizza with anchovies\")\n",
    "    display.display(text)\n",
    "\n",
    "    output = visualization.InteractiveVisualization()\n",
    "    display.display(display.HTML(output.initial_html()))\n",
    "\n",
    "    def handle_submit(sender):\n",
    "        \"\"\"Re-parses the current text and displays its trace.\"\"\"\n",
    "        del sender  # unused\n",
    "        _, trace = annotate_text(text.value)  # the parse tree is not needed here\n",
    "        display.display(display.HTML(output.show_trace(trace)))\n",
    "\n",
    "    text.on_submit(handle_submit)\n",
    "_trace_explorer()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}