Browse Source

Open Tensorflow Projector with couples

Vadim Markovtsev 7 years ago
parent
commit
8724bd1703
1 changed file with 79 additions and 5 deletions
  1. 79 5
      labours.py

+ 79 - 5
labours.py

@@ -5,6 +5,8 @@ import os
 import re
 import sys
 import tempfile
+import threading
+import time
 import warnings
 
 try:
@@ -45,6 +47,8 @@ def parse_args():
              "\"month\", \"year\", \"no\", \"raw\" and pandas offset aliases ("
              "http://pandas.pydata.org/pandas-docs/stable/timeseries.html"
              "#offset-aliases).")
+    parser.add_argument("--disable-projector", action="store_true",
+                        help="Do not run Tensorflow Projector on couples.")
     args = parser.parse_args()
     return args
 
@@ -501,7 +505,50 @@ def train_embeddings(coocc_tree, tmpdir, shard_size=4096):
     return meta_index, embeddings
 
 
-def write_embeddings(name, output, index, embeddings):
+class CORSWebServer(object):
+    """
+    Background HTTP file server which adds a permissive CORS header to every response.
+
+    The server is built on the standard library SimpleHTTPRequestHandler and runs in
+    a separate thread so that the caller is not blocked. Every response carries
+    "Access-Control-Allow-Origin: *" so that pages from another origin may fetch
+    the served files.
+
+    Attributes:
+        thread: the threading.Thread which executes serve().
+        server: the HTTPServer instance, or None until serve() has created it.
+    """
+
+    # All interfaces, port 8000 - the same defaults which the standard library's
+    # test() helper used before this class created the server directly.
+    ADDRESS = ("", 8000)
+
+    def __init__(self):
+        self.thread = threading.Thread(target=self.serve)
+        # A daemon thread does not keep the process alive if the main thread dies
+        # before stop() is called.
+        self.thread.daemon = True
+        self.server = None
+
+    def serve(self):
+        """
+        Thread body: create the HTTP server and handle requests until shutdown().
+
+        The server is constructed directly instead of through the undocumented
+        test() helper of the standard library: the Python 2 variant of test()
+        parses sys.argv[1] as the port number, and the Python 3 variant calls
+        sys.exit() on KeyboardInterrupt.
+        """
+        try:
+            from http.server import HTTPServer, SimpleHTTPRequestHandler
+        except ImportError:  # Python 2
+            from BaseHTTPServer import HTTPServer
+            from SimpleHTTPServer import SimpleHTTPRequestHandler
+
+        class CORSRequestHandler(SimpleHTTPRequestHandler):
+            def end_headers(self):
+                # Inject the CORS header right before the header block is finished.
+                self.send_header("Access-Control-Allow-Origin", "*")
+                SimpleHTTPRequestHandler.end_headers(self)
+
+        server = HTTPServer(self.ADDRESS, CORSRequestHandler)
+        # Publishing the instance is what flips `running` to True.
+        self.server = server
+        host, port = server.socket.getsockname()[:2]
+        print("Serving HTTP on %s port %d ..." % (host, port))
+        try:
+            server.serve_forever()
+        finally:
+            # Release the listening socket once shutdown() was requested or
+            # serve_forever() failed.
+            server.server_close()
+
+    def start(self):
+        """
+        Launch the server thread. Repeated calls are no-ops.
+        """
+        # Thread.ident stays None until the thread has been started. Checking it
+        # avoids the RuntimeError from a second Thread.start() when start() is
+        # invoked again before (or without) the server having been created.
+        if self.thread.ident is None:
+            self.thread.start()
+
+    def stop(self):
+        """
+        Shut the server down and wait for the thread to exit; no-op if not running.
+        """
+        if self.running:
+            self.server.shutdown()
+            self.thread.join()
+
+    @property
+    def running(self):
+        """
+        True once serve() has created the server instance.
+        """
+        return self.server is not None
+
+
+# Module-level server instance: write_embeddings() starts it on demand and main()
+# stops it before exiting.
+web_server = CORSWebServer()
+
+
+def write_embeddings(name, output, run_server, index, embeddings):
     print("Writing Tensorflow Projector files...")
     if not output:
         output = "couples_" + name
@@ -517,6 +564,26 @@ def write_embeddings(name, output, index, embeddings):
             fout.write("\t".join(str(v) for v in vec))
             fout.write("\n")
     print("Wrote", dataf)
+    jsonf = "%s_%s.json" % (output, name)
+    with open(jsonf, "w") as fout:
+        fout.write("""{
+  "embeddings": [
+    {
+      "tensorName": "%s %s coupling",
+      "tensorShape": [%s, %s],
+      "tensorPath": "http://0.0.0.0:8000/%s",
+      "metadataPath": "http://0.0.0.0:8000/%s"
+    }
+  ]
+}
+""" % (output, name, len(embeddings), len(embeddings[0]), dataf, metaf))
+    print("Wrote %s", jsonf)
+    if run_server and not web_server.running:
+        web_server.start()
+    url = "http://projector.tensorflow.org/?config=http://0.0.0.0:8000/" + jsonf
+    print(url)
+    if run_server:
+        os.system("xdg-open " + url)
 
 
 def main():
@@ -552,9 +619,9 @@ def main():
             return
         plot_people(args, name, *load_people(header, people_sequence, people_contents))
     elif args.mode == "couples":
-        write_embeddings("files", args.output,
+        write_embeddings("files", args.output, not args.disable_projector,
                          *train_embeddings(files_coocc, args.couples_tmp_dir))
-        write_embeddings("people", args.output,
+        write_embeddings("people", args.output, not args.disable_projector,
                          *train_embeddings(people_coocc, args.couples_tmp_dir))
     elif args.mode == "all":
         plot_burndown(args, "project",
@@ -567,10 +634,17 @@ def main():
             plot_people(args, name, *load_people(header, people_sequence, people_contents))
         if people_coocc:
             assert files_coocc
-            write_embeddings("files", args.output,
+            write_embeddings("files", args.output, not args.disable_projector,
                              *train_embeddings(files_coocc, args.couples_tmp_dir))
-            write_embeddings("people", args.output,
+            write_embeddings("people", args.output, not args.disable_projector,
                              *train_embeddings(people_coocc, args.couples_tmp_dir))
+    if web_server.running:
+        print("Sleeping for 60 seconds, safe to Ctrl-C")
+        try:
+            time.sleep(60)
+        except KeyboardInterrupt:
+            pass
+        web_server.stop()
 
 if __name__ == "__main__":
     sys.exit(main())