|
@@ -913,7 +913,11 @@ def train_embeddings(index, matrix, tmpdir, shard_size=IDEAL_SHARD_SIZE):
|
|
|
swivel.FLAGS.output_base_path = tmproot
|
|
|
swivel.FLAGS.loss_multiplier = 1.0 / shard_size
|
|
|
swivel.FLAGS.num_epochs = num_epochs
|
|
|
+ # Tensorflow 1.5 parses sys.argv unconditionally *applause*
|
|
|
+ argv_backup = sys.argv[1:]
|
|
|
+ del sys.argv[1:]
|
|
|
swivel.main(None)
|
|
|
+ sys.argv.extend(argv_backup)
|
|
|
print("Reading Swivel embeddings...")
|
|
|
embeddings = []
|
|
|
with open(os.path.join(tmproot, "row_embedding.tsv")) as frow:
|