|
@@ -72,10 +72,10 @@ flags.DEFINE_string("output_base_path", None,
|
|
|
flags.DEFINE_integer("embedding_size", 300, "Size of the embeddings")
|
|
|
flags.DEFINE_boolean("trainable_bias", False, "Biases are trainable")
|
|
|
flags.DEFINE_integer("submatrix_rows", 4096,
|
|
|
- "Rows in each training submatrix. This must match"
|
|
|
+ "Rows in each training submatrix. This must match "
|
|
|
"the training data.")
|
|
|
flags.DEFINE_integer("submatrix_cols", 4096,
|
|
|
- "Rows in each training submatrix. This must match"
|
|
|
+ "Rows in each training submatrix. This must match "
|
|
|
"the training data.")
|
|
|
flags.DEFINE_float("loss_multiplier", 1.0 / 4096,
|
|
|
"constant multiplier on loss.")
|
|
@@ -205,9 +205,9 @@ class SwivelModel:
|
|
|
|
|
|
# Create paths to input data files
|
|
|
log("Reading model from: %s", config.input_base_path)
|
|
|
- count_matrix_files = glob.glob(config.input_base_path + "/shard-*.pb")
|
|
|
- row_sums_path = config.input_base_path + "/row_sums.txt"
|
|
|
- col_sums_path = config.input_base_path + "/col_sums.txt"
|
|
|
+ count_matrix_files = glob.glob(os.path.join(config.input_base_path, "shard-*.pb"))
|
|
|
+ row_sums_path = os.path.join(config.input_base_path, "row_sums.txt")
|
|
|
+ col_sums_path = os.path.join(config.input_base_path, "col_sums.txt")
|
|
|
|
|
|
# Read marginals
|
|
|
row_sums = read_marginals_file(row_sums_path)
|
|
@@ -367,6 +367,8 @@ class SwivelModel:
|
|
|
tf.zeros((length, self._config.embedding_size)),
|
|
|
name="top10k_embedding")
|
|
|
embedding_config.tensor_name = self.embedding10k.name
|
|
|
+ embedding_config.metadata_path = os.path.join(
|
|
|
+ self._config.input_base_path, "row_vocab.txt")
|
|
|
tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
|
|
|
self.writer, projector_config)
|
|
|
self.saver = tf.train.Saver((self.embedding10k,), max_to_keep=1)
|