Browse Source

Fix TensorBoard labels

Signed-off-by: Vadim Markovtsev <vadim@sourced.tech>
Vadim Markovtsev 7 years ago
parent
commit
3d2dc109df
1 changed files with 7 additions and 5 deletions
  1. 7 5
      swivel.py

+ 7 - 5
swivel.py

@@ -72,10 +72,10 @@ flags.DEFINE_string("output_base_path", None,
 flags.DEFINE_integer("embedding_size", 300, "Size of the embeddings")
 flags.DEFINE_boolean("trainable_bias", False, "Biases are trainable")
 flags.DEFINE_integer("submatrix_rows", 4096,
-                     "Rows in each training submatrix. This must match"
+                     "Rows in each training submatrix. This must match "
                      "the training data.")
 flags.DEFINE_integer("submatrix_cols", 4096,
-                     "Rows in each training submatrix. This must match"
+                     "Rows in each training submatrix. This must match "
                      "the training data.")
 flags.DEFINE_float("loss_multiplier", 1.0 / 4096,
                    "constant multiplier on loss.")
@@ -205,9 +205,9 @@ class SwivelModel:
 
         # Create paths to input data files
         log("Reading model from: %s", config.input_base_path)
-        count_matrix_files = glob.glob(config.input_base_path + "/shard-*.pb")
-        row_sums_path = config.input_base_path + "/row_sums.txt"
-        col_sums_path = config.input_base_path + "/col_sums.txt"
+        count_matrix_files = glob.glob(os.path.join(config.input_base_path, "shard-*.pb"))
+        row_sums_path = os.path.join(config.input_base_path, "row_sums.txt")
+        col_sums_path = os.path.join(config.input_base_path, "col_sums.txt")
 
         # Read marginals
         row_sums = read_marginals_file(row_sums_path)
@@ -367,6 +367,8 @@ class SwivelModel:
             tf.zeros((length, self._config.embedding_size)),
             name="top10k_embedding")
         embedding_config.tensor_name = self.embedding10k.name
+        embedding_config.metadata_path = os.path.join(
+            self._config.input_base_path, "row_vocab.txt")
         tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
             self.writer, projector_config)
         self.saver = tf.train.Saver((self.embedding10k,), max_to_keep=1)