|
@@ -710,7 +710,10 @@ def train_embeddings(index, matrix, tmpdir, shard_size=IDEAL_SHARD_SIZE):
|
|
print("Training Swivel model...")
|
|
print("Training Swivel model...")
|
|
swivel.FLAGS.submatrix_rows = shard_size
|
|
swivel.FLAGS.submatrix_rows = shard_size
|
|
swivel.FLAGS.submatrix_cols = shard_size
|
|
swivel.FLAGS.submatrix_cols = shard_size
|
|
- if len(meta_index) <= IDEAL_SHARD_SIZE:
|
|
|
|
|
|
+ if len(meta_index) <= IDEAL_SHARD_SIZE / 16:
|
|
|
|
+ embedding_size = 50
|
|
|
|
+ num_epochs = 20000
|
|
|
|
+ elif len(meta_index) <= IDEAL_SHARD_SIZE:
|
|
embedding_size = 50
|
|
embedding_size = 50
|
|
num_epochs = 10000
|
|
num_epochs = 10000
|
|
elif len(meta_index) <= IDEAL_SHARD_SIZE * 2:
|
|
elif len(meta_index) <= IDEAL_SHARD_SIZE * 2:
|