|
@@ -917,28 +917,28 @@ def train_embeddings(index, matrix, tmpdir, shard_size=IDEAL_SHARD_SIZE):
|
|
|
swivel.FLAGS.submatrix_cols = shard_size
|
|
|
if len(meta_index) <= IDEAL_SHARD_SIZE / 16:
|
|
|
embedding_size = 50
|
|
|
- num_epochs = 20000
|
|
|
+ num_epochs = 100000
|
|
|
elif len(meta_index) <= IDEAL_SHARD_SIZE:
|
|
|
embedding_size = 50
|
|
|
- num_epochs = 10000
|
|
|
+ num_epochs = 50000
|
|
|
elif len(meta_index) <= IDEAL_SHARD_SIZE * 2:
|
|
|
embedding_size = 60
|
|
|
- num_epochs = 5000
|
|
|
+ num_epochs = 10000
|
|
|
elif len(meta_index) <= IDEAL_SHARD_SIZE * 4:
|
|
|
embedding_size = 70
|
|
|
- num_epochs = 4000
|
|
|
+ num_epochs = 8000
|
|
|
elif len(meta_index) <= IDEAL_SHARD_SIZE * 10:
|
|
|
embedding_size = 80
|
|
|
- num_epochs = 2500
|
|
|
+ num_epochs = 5000
|
|
|
elif len(meta_index) <= IDEAL_SHARD_SIZE * 25:
|
|
|
embedding_size = 100
|
|
|
- num_epochs = 500
|
|
|
+ num_epochs = 1000
|
|
|
elif len(meta_index) <= IDEAL_SHARD_SIZE * 100:
|
|
|
embedding_size = 200
|
|
|
- num_epochs = 300
|
|
|
+ num_epochs = 600
|
|
|
else:
|
|
|
embedding_size = 300
|
|
|
- num_epochs = 200
|
|
|
+ num_epochs = 300
|
|
|
swivel.FLAGS.embedding_size = embedding_size
|
|
|
swivel.FLAGS.input_base_path = tmproot
|
|
|
swivel.FLAGS.output_base_path = tmproot
|