Просмотр исходного кода

Fix the sharding and the convergence

Vadim Markovtsev 7 лет назад
Родитель
Сommit
f52cd6e01a
1 измененных файлов с 24 добавлено и 15 удалено
  1. 24 15
      labours.py

+ 24 - 15
labours.py

@@ -64,7 +64,7 @@ def read_input(args):
                 data = yaml.load(fin, Loader=loader)
         else:
             data = yaml.load(sys.stdin, Loader=loader)
-    except UnicodeEncodeError as e:
+    except (UnicodeEncodeError, yaml.reader.ReaderError) as e:
         print("\nInvalid unicode in the input: %s\nPlease filter it through fix_yaml_unicode.py" %
               e)
         sys.exit(1)
@@ -398,6 +398,17 @@ def train_embeddings(coocc_tree, tmpdir, shard_size=4096):
         import swivel
 
     index = coocc_tree["index"]
+    nshards = len(index) // shard_size
+    if nshards * shard_size < len(index):
+        nshards += 1
+        shard_size = len(index) // nshards
+        nshards = len(index) // shard_size
+    remainder = len(index) - nshards * shard_size
+    if remainder > 0:
+        lengths = numpy.array([len(cd) for cd in coocc_tree["matrix"]])
+        filtered = sorted(numpy.argsort(lengths)[remainder:])
+    else:
+        filtered = list(range(len(index)))
     print("Reading the sparse matrix...")
     data = []
     indices = []
@@ -408,9 +419,12 @@ def train_embeddings(coocc_tree, tmpdir, shard_size=4096):
             indices.append(col)
         indptr.append(indptr[-1] + len(cd))
     matrix = csr_matrix((data, indices, indptr), shape=(len(index), len(index)))
+    if len(filtered) < len(index):
+        matrix = matrix[filtered, :][:, filtered]
     meta_index = []
-    for i, name in enumerate(index):
-        meta_index.append((name, matrix[i, i]))
+    for i, j in enumerate(filtered):
+        meta_index.append((index[j], matrix[i, i]))
+    index = [mi[0] for mi in meta_index]
     with tempfile.TemporaryDirectory(prefix="hercules_labours_", dir=tmpdir or None) as tmproot:
         print("Writing Swivel metadata...")
         vocabulary = "\n".join(index)
@@ -427,11 +441,6 @@ def train_embeddings(coocc_tree, tmpdir, shard_size=4096):
             out.write(bool_sums_str)
         del bool_sums_str
         reorder = numpy.argsort(-bool_sums)
-        nshards = len(index) // shard_size
-        if nshards * shard_size < len(index):
-            nshards += 1
-            shard_size = len(index) // nshards
-            nshards = len(index) // shard_size
 
         print("Writing Swivel shards...")
         for row in range(nshards):
@@ -460,18 +469,18 @@ def train_embeddings(coocc_tree, tmpdir, shard_size=4096):
         print("Training Swivel model...")
         swivel.FLAGS.submatrix_rows = shard_size
         swivel.FLAGS.submatrix_cols = shard_size
-        if len(index) < 10000:
+        if len(meta_index) < 10000:
             embedding_size = 50
-            num_epochs = 40
-        elif len(index) < 100000:
+            num_epochs = 100
+        elif len(meta_index) < 100000:
             embedding_size = 100
-            num_epochs = 50
-        elif len(index) < 500000:
+            num_epochs = 200
+        elif len(meta_index) < 500000:
             embedding_size = 200
-            num_epochs = 60
+            num_epochs = 300
         else:
             embedding_size = 300
-            num_epochs = 80
+            num_epochs = 200
         swivel.FLAGS.embedding_size = embedding_size
         swivel.FLAGS.input_base_path = tmproot
         swivel.FLAGS.output_base_path = tmproot