Browse Source

Add shotness couples support to labours

Vadim Markovtsev 7 years ago
parent
commit
83b52ec7c0
1 changed file with 46 additions and 0 deletions
  1. 46 0
      labours.py

+ 46 - 0
labours.py

@@ -104,6 +104,9 @@ class Reader(object):
     def get_people_coocc(self):
         """Abstract accessor: the base class provides no implementation.

         Always raises NotImplementedError here; concrete readers are expected
         to override it (NOTE(review): the return contract is defined by the
         overrides, which are not fully visible in this view).
         """
         raise NotImplementedError
 
+    def get_shotness_coocc(self):
+        """Abstract accessor for the shotness co-occurrence data.
+
+        Always raises NotImplementedError here. The YamlReader and
+        ProtobufReader overrides return a tuple of (list of "file:name"
+        labels, square scipy.sparse.csr_matrix).
+        """
+        raise NotImplementedError
+
     def get_shotness(self):
         """Abstract accessor for the shotness records.

         Always raises NotImplementedError here; concrete readers are expected
         to override it.
         """
         raise NotImplementedError
 
@@ -168,6 +171,24 @@ class YamlReader(Reader):
         coocc = self.data["Couples"]["people_coocc"]
         return coocc["index"], self._parse_coocc_matrix(coocc["matrix"])
 
+    def get_shotness_coocc(self):
+        shotness = self.data["Shotness"]
+        index = ["%s:%s" % (i["file"], i["name"]) for i in shotness]
+        indptr = numpy.zeros(len(shotness) + 1, dtype=numpy.int64)
+        indices = []
+        data = []
+        for i, record in enumerate(shotness):
+            pairs = [(int(k), v) for k, v in record["counters"].items()]
+            pairs.sort()
+            indptr[i + 1] = indptr[i] + len(pairs)
+            for k, v in pairs:
+                indices.append(k)
+                data.append(v)
+        indices = numpy.array(indices, dtype=numpy.int32)
+        data = numpy.array(data, dtype=numpy.int32)
+        from scipy.sparse import csr_matrix
+        return index, csr_matrix((data, indices, indptr), shape=(len(shotness),) * 2)
+
     def get_shotness(self):
         from munch import munchify
         obj = munchify(self.data["Shotness"])
@@ -254,6 +275,25 @@ class ProtobufReader(Reader):
         node = self.contents["Couples"].people_couples
         return list(node.index), self._parse_sparse_matrix(node.matrix)
 
+    def get_shotness_coocc(self):
+        shotness = self.get_shotness()
+        index = ["%s:%s" % (i.file, i.name) for i in shotness]
+        indptr = numpy.zeros(len(shotness) + 1, dtype=numpy.int32)
+        indices = []
+        data = []
+        for i, record in enumerate(shotness):
+            pairs = list(record.counters.items())
+            pairs.sort()
+            indptr[i + 1] = indptr[i] + len(pairs)
+            for k, v in pairs:
+                indices.append(k)
+                data.append(v)
+        indices = numpy.array(indices, dtype=numpy.int32)
+        data = numpy.array(data, dtype=numpy.int32)
+        from scipy.sparse import csr_matrix
+        return index, csr_matrix((data, indices, indptr), shape=(len(shotness),) * 2)
+
+
     def get_shotness(self):
         """Return the ``records`` attribute of the "Shotness" entry in ``self.contents``.

         NOTE(review): presumably raises KeyError when the analysis result has
         no "Shotness" section - confirm against the type of ``self.contents``.
         """
         return self.contents["Shotness"].records
 
@@ -1064,6 +1104,12 @@ def main():
                                                tmpdir=args.couples_tmp_dir))
         except KeyError:
             print(couples_warning)
+        try:
+            write_embeddings("shotness", args.output, not args.disable_projector,
+                             *train_embeddings(*reader.get_shotness_coocc(),
+                                               tmpdir=args.couples_tmp_dir))
+        except KeyError:
+            print(shotness_warning)
 
     def shotness():
         try: