瀏覽代碼

Speed up reading yaml and keep the shard remainder

Vadim Markovtsev 7 年之前
父節點
當前提交
d8e1154edc
共有 1 個文件被更改,包括 11 次插入5 次删除
  1. 11 5
      labours.py

+ 11 - 5
labours.py

@@ -53,11 +53,15 @@ def read_input(args):
     sys.stdout.write("Reading the input... ")
     sys.stdout.flush()
     yaml.reader.Reader.NON_PRINTABLE = re.compile(r"(?!x)x")
+    try:
+        loader = yaml.CLoader
+    except AttributeError:
+        loader = yaml.Loader
     if args.input != "-":
         with open(args.input) as fin:
-            data = yaml.load(fin)
+            data = yaml.load(fin, Loader=loader)
     else:
-        data = yaml.load(sys.stdin)
+        data = yaml.load(sys.stdin, Loader=loader)
     print("done")
     return data["burndown"], data["project"], data.get("files"), data.get("people_sequence"), \
            data.get("people"), data.get("people_interaction"), data.get("files_coocc"), \
@@ -418,9 +422,11 @@ def train_embeddings(coocc_tree, tmpdir, shard_size=4096):
         del bool_sums_str
         reorder = numpy.argsort(-bool_sums)
         nshards = len(index) // shard_size
-        if nshards == 0:
-            nshards = 1
-            shard_size = len(index)
+        if nshards * shard_size < len(index):
+            nshards += 1
+            shard_size = len(index) // nshards
+            nshards = len(index) // shard_size
+
         print("Writing Swivel shards...")
         for row in range(nshards):
             for col in range(nshards):