|
@@ -676,13 +676,13 @@ def train_embeddings(index, matrix, tmpdir, shard_size=4096):
|
|
|
swivel.FLAGS.submatrix_cols = shard_size
|
|
swivel.FLAGS.submatrix_cols = shard_size
|
|
|
if len(meta_index) < 10000:
|
|
if len(meta_index) < 10000:
|
|
|
embedding_size = 50
|
|
embedding_size = 50
|
|
|
- num_epochs = 200
|
|
|
|
|
|
|
+ num_epochs = 500
|
|
|
elif len(meta_index) < 100000:
|
|
elif len(meta_index) < 100000:
|
|
|
embedding_size = 100
|
|
embedding_size = 100
|
|
|
- num_epochs = 250
|
|
|
|
|
|
|
+ num_epochs = 300
|
|
|
elif len(meta_index) < 500000:
|
|
elif len(meta_index) < 500000:
|
|
|
embedding_size = 200
|
|
embedding_size = 200
|
|
|
- num_epochs = 300
|
|
|
|
|
|
|
+ num_epochs = 250
|
|
|
else:
|
|
else:
|
|
|
embedding_size = 300
|
|
embedding_size = 300
|
|
|
num_epochs = 200
|
|
num_epochs = 200
|