Explorar o código

Extend epoch/embed size table

Vadim Markovtsev hai 7 anos
pai
achega
2f2337e654
Modificouse 1 ficheiro con 18 adicións e 7 borrados
  1. 18 7
      labours.py

+ 18 - 7
labours.py

@@ -602,8 +602,10 @@ def plot_ownership(args, repo, names, people, date_range, last):
         output = args.output
     deploy_plot("%s code ownership through time" % repo, output, args.style)
 
+IDEAL_SHARD_SIZE = 4096
 
-def train_embeddings(index, matrix, tmpdir, shard_size=4096):
+
+def train_embeddings(index, matrix, tmpdir, shard_size=IDEAL_SHARD_SIZE):
     try:
         from . import swivel
     except (SystemError, ImportError):
@@ -674,15 +676,24 @@ def train_embeddings(index, matrix, tmpdir, shard_size=4096):
         print("Training Swivel model...")
         swivel.FLAGS.submatrix_rows = shard_size
         swivel.FLAGS.submatrix_cols = shard_size
-        if len(meta_index) < 10000:
+        if len(meta_index) <= IDEAL_SHARD_SIZE:
             embedding_size = 50
-            num_epochs = 500
-        elif len(meta_index) < 100000:
+            num_epochs = 10000
+        elif len(meta_index) <= IDEAL_SHARD_SIZE * 2:
+            embedding_size = 60
+            num_epochs = 5000
+        elif len(meta_index) <= IDEAL_SHARD_SIZE * 4:
+            embedding_size = 70
+            num_epochs = 4000
+        elif len(meta_index) <= IDEAL_SHARD_SIZE * 10:
+            embedding_size = 80
+            num_epochs = 2500
+        elif len(meta_index) <= IDEAL_SHARD_SIZE * 25:
             embedding_size = 100
-            num_epochs = 300
-        elif len(meta_index) < 500000:
+            num_epochs = 500
+        elif len(meta_index) <= IDEAL_SHARD_SIZE * 100:
             embedding_size = 200
-            num_epochs = 250
+            num_epochs = 300
         else:
             embedding_size = 300
             num_epochs = 200