visualization.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """Helper library for visualizations.
  2. TODO(googleuser): Find a more reliable way to serve stuff from IPython
  3. notebooks (e.g. determining where the root notebook directory is).
  4. """
  5. from __future__ import absolute_import
  6. from __future__ import division
  7. from __future__ import print_function
  8. import gzip
  9. import os
  10. import uuid
  11. from google.protobuf import json_format
  12. from dragnn.protos import trace_pb2
  13. # Make a guess about where the IPython kernel root is.
  14. _IPYTHON_KERNEL_PATH = os.path.realpath(os.getcwd())
  15. # Bazel uses the 'data' attribute for this library to ensure viz.min.js.gz is
  16. # packaged.
  17. module_path = os.path.dirname(os.path.abspath(__file__))
  18. viz_script = os.path.join(os.path.dirname(module_path), 'viz', 'viz.min.js.gz')
  19. def _load_viz_script():
  20. """Reads the bundled visualization script.
  21. Raises:
  22. EnvironmentError: If the visualization script could not be found.
  23. Returns:
  24. str JavaScript source code.
  25. """
  26. if not os.path.isfile(viz_script):
  27. raise EnvironmentError(
  28. 'Visualization script should be built into {}'.format(viz_script))
  29. with gzip.GzipFile(viz_script) as f:
  30. return f.read()
  31. def parse_trace_json(trace):
  32. """Converts a binary-encoded MasterTrace proto to a JSON parser trace.
  33. Args:
  34. trace: Binary string containing a MasterTrace.
  35. Returns:
  36. JSON str, as expected by visualization tools.
  37. """
  38. as_proto = trace_pb2.MasterTrace.FromString(trace)
  39. as_json = json_format.MessageToJson(
  40. as_proto, preserving_proto_field_name=True)
  41. return as_json
  42. def _container_div(height='700px', contents=''):
  43. elt_id = str(uuid.uuid4())
  44. html = """
  45. <div id="{elt_id}" style="width: 100%; min-width: 200px; height: {height};">
  46. {contents}</div>
  47. """.format(
  48. elt_id=elt_id, height=height, contents=contents)
  49. return elt_id, html
  50. def trace_html(trace, convert_to_unicode=True, height='700px', script=None):
  51. """Generates HTML that will render a master trace.
  52. This will result in a self-contained "div" element.
  53. Args:
  54. trace: binary-encoded MasterTrace string.
  55. convert_to_unicode: Whether to convert the output to unicode. Defaults to
  56. True because IPython.display.HTML expects unicode, and we expect users to
  57. often pass the output of this function to IPython.display.HTML.
  58. height: CSS string representing the height of the element, default '700px'.
  59. script: Visualization script contents, if the defaults are unacceptable.
  60. Returns:
  61. unicode or str with HTML contents.
  62. """
  63. if script is None:
  64. script = _load_viz_script()
  65. json_trace = parse_trace_json(trace)
  66. elt_id, div_html = _container_div(height=height)
  67. as_str = """
  68. <meta charset="utf-8"/>
  69. {div_html}
  70. <script type='text/javascript'>
  71. {script}
  72. visualizeToDiv({json}, "{elt_id}");
  73. </script>
  74. """.format(
  75. script=script, json=json_trace, elt_id=elt_id, div_html=div_html)
  76. return unicode(as_str, 'utf-8') if convert_to_unicode else as_str
  77. def open_in_new_window(html, notebook_html_fcn=None, temp_file_basename=None):
  78. """Opens an HTML visualization in a new window.
  79. This function assumes that the module was loaded when the current working
  80. directory is the IPython/Jupyter notebook root directory. Then it writes a
  81. file ./tmp/_new_window_html/<random-uuid>.html, and returns an HTML display
  82. element, which will call `window.open("/files/<filename>")`. This works
  83. because IPython serves files from the /files root.
  84. Args:
  85. html: HTML to write to a file.
  86. notebook_html_fcn: Function to generate an HTML element; defaults to
  87. IPython.display.HTML (lazily imported).
  88. temp_file_basename: File name to write (defaults to <random-uuid>.html).
  89. Returns:
  90. HTML notebook element, which will trigger the browser to open a new window.
  91. """
  92. if isinstance(html, unicode):
  93. html = html.encode('utf-8')
  94. if notebook_html_fcn is None:
  95. from IPython import display
  96. notebook_html_fcn = display.HTML
  97. if temp_file_basename is None:
  98. temp_file_basename = '{}.html'.format(str(uuid.uuid4()))
  99. rel_path = os.path.join('tmp', '_new_window_html', temp_file_basename)
  100. abs_path = os.path.join(_IPYTHON_KERNEL_PATH, rel_path)
  101. # Write the file, creating the directory if it doesn't exist.
  102. if not os.path.isdir(os.path.dirname(abs_path)):
  103. os.makedirs(os.path.dirname(abs_path))
  104. with open(abs_path, 'w') as f:
  105. f.write(html)
  106. return notebook_html_fcn("""
  107. <script type='text/javascript'>
  108. window.open("/files/{}");
  109. </script>
  110. """.format(rel_path))
  111. class InteractiveVisualization(object):
  112. """Helper class for displaying visualizations interactively.
  113. See usage in examples/dragnn/interactive_text_analyzer.ipynb.
  114. """
  115. def initial_html(self, height='700px', script=None, init_message=None):
  116. """Returns HTML for a container, which will be populated later.
  117. Args:
  118. height: CSS string representing the height of the element, default
  119. '700px'.
  120. script: Visualization script contents, if the defaults are unacceptable.
  121. init_message: Initial message to display.
  122. Returns:
  123. unicode with HTML contents.
  124. """
  125. if script is None:
  126. script = _load_viz_script()
  127. if init_message is None:
  128. init_message = 'Type a sentence and press (enter) to see the trace.'
  129. self.elt_id, div_html = _container_div(
  130. height=height, contents='<strong>{}</strong>'.format(init_message))
  131. html = """
  132. <meta charset="utf-8"/>
  133. {div_html}
  134. <script type='text/javascript'>
  135. {script}
  136. </script>
  137. """.format(
  138. script=script, div_html=div_html)
  139. return unicode(html, 'utf-8') # IPython expects unicode.
  140. def show_trace(self, trace):
  141. """Returns a JS script HTML fragment, which will populate the container.
  142. Args:
  143. trace: binary-encoded MasterTrace string.
  144. Returns:
  145. unicode with HTML contents.
  146. """
  147. html = """
  148. <meta charset="utf-8"/>
  149. <script type='text/javascript'>
  150. document.getElementById("{elt_id}").innerHTML = ""; // Clear previous.
  151. visualizeToDiv({json}, "{elt_id}");
  152. </script>
  153. """.format(
  154. json=parse_trace_json(trace), elt_id=self.elt_id)
  155. return unicode(html, 'utf-8') # IPython expects unicode.