
Merge pull request #4 from tensorflow/swivel

added swivel model
Martin Wicke 9 years ago
parent
commit
1ecaf090a5
12 changed files with 2,431 additions and 0 deletions
  1. swivel/.gitignore (+12 -0)
  2. swivel/README.md (+182 -0)
  3. swivel/analogy.cc (+365 -0)
  4. swivel/eval.mk (+98 -0)
  5. swivel/fastprep.cc (+680 -0)
  6. swivel/fastprep.mk (+87 -0)
  7. swivel/nearest.py (+75 -0)
  8. swivel/prep.py (+317 -0)
  9. swivel/swivel.py (+347 -0)
  10. swivel/text2bin.py (+88 -0)
  11. swivel/vecs.py (+90 -0)
  12. swivel/wordsim.py (+90 -0)

+ 12 - 0
swivel/.gitignore

@@ -0,0 +1,12 @@
+*.an.tab
+*.pyc
+*.ws.tab
+MEN.tar.gz
+Mtruk.csv
+SimLex-999.zip
+analogy
+fastprep
+myz_naacl13_test_set.tgz
+questions-words.txt
+rw.zip
+ws353simrel.tar.gz

+ 182 - 0
swivel/README.md

@@ -0,0 +1,182 @@
+# Swivel in Tensorflow
+
+This is a [TensorFlow](http://www.tensorflow.org/) implementation of the
+[Swivel algorithm](http://arxiv.org/abs/1602.02215) for generating word
+embeddings.
+
+Swivel works as follows:
+
+1. Compute the co-occurrence statistics from a corpus; that is, determine how
+   often a word *c* appears in the context (e.g., "within ten words") of a focus
+   word *f*.  This results in a sparse *co-occurrence matrix* whose rows
+   represent the focus words, and whose columns represent the context
+   words. Each cell value is the number of times the focus and context words
+   were observed together.
+2. Re-organize the co-occurrence matrix and chop it into smaller pieces.
+3. Assign a random *embedding vector* of fixed dimension (say, 300) to each
+   focus word and to each context word.
+4. Iteratively attempt to approximate the
+   [pointwise mutual information](https://en.wikipedia.org/wiki/Pointwise_mutual_information)
+   (PMI) between words with the dot product of the corresponding embedding
+   vectors.
+
+Note that the resulting co-occurrence matrix is very sparse (i.e., contains many
+zeros) since most words won't have been observed in the context of other words.
+In the case of very rare words, it seems reasonable to assume that you just
+haven't sampled enough data to spot their co-occurrence yet.  On the other hand,
+if we've failed to observe two common words co-occurring, it seems likely that
+they are *anti-correlated*.
+
+Swivel attempts to capture this intuition by using both the observed and the
+un-observed co-occurrences to inform the way it iteratively adjusts vectors.
+Empirically, this seems to lead to better embeddings, especially for rare words.
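To make the PMI target in step 4 concrete, here is a minimal illustrative NumPy sketch (the toy matrix and variable names are invented for this example and are not part of the Swivel code):

    import numpy as np

    # Toy co-occurrence matrix: rows are focus words, columns are context words.
    X = np.array([[10.0, 2.0, 0.0],
                  [ 2.0, 6.0, 1.0],
                  [ 0.0, 1.0, 4.0]])

    N = X.sum()          # total co-occurrence mass
    row = X.sum(axis=1)  # marginal counts for the focus words
    col = X.sum(axis=0)  # marginal counts for the context words

    # PMI(f, c) = log(N * X[f, c] / (row[f] * col[c])).  Unobserved (zero) cells
    # come out as -inf, which is why Swivel treats observed and unobserved
    # co-occurrences differently during training.
    with np.errstate(divide='ignore'):
        pmi = np.log(N * X / np.outer(row, col))

    # Training nudges row_vectors[f].dot(col_vectors[c]) toward pmi[f, c].
    print(pmi)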
+
+# Contents
+
+This release includes the following programs.
+
+* `prep.py` is a program that takes a text corpus and pre-processes it for
+  training. Specifically, it computes a vocabulary and token co-occurrence
+  statistics for the corpus.  It then outputs the information into a format that
+  can be digested by the TensorFlow trainer.
+* `swivel.py` is a TensorFlow program that generates embeddings from the
+  co-occurrence statistics.  It uses the files created by `prep.py` as input,
+  and generates two text files as output: the row and column embeddings.
+* `text2bin.py` combines the row and column vectors generated by Swivel into a
+  flat binary file that can be quickly loaded into memory to perform vector
+  arithmetic.  This can also be used to convert embeddings from
+  [Glove](http://nlp.stanford.edu/projects/glove/) and
+  [word2vec](https://code.google.com/archive/p/word2vec/) into a form that can
+  be used by the following tools.
+* `nearest.py` is a program that you can use to manually inspect binary
+  embeddings.
+* `eval.mk` is a GNU makefile that will retrieve and normalize several common
+  word similarity and analogy evaluation data sets.
+* `wordsim.py` performs word similarity evaluation of the resulting vectors.
+* `analogy` performs analogy evaluation of the resulting vectors.
+* `fastprep` is a C++ program that works much more quickly than `prep.py`, but
+  also has some additional dependencies to build.
+
+# Building Embeddings with Swivel
+
+To build your own word embeddings with Swivel, you'll need the following:
+
+* A large corpus of text; for example, the
+  [dump of English Wikipedia](https://dumps.wikimedia.org/enwiki/).
+* A working [TensorFlow](http://www.tensorflow.org/) implementation.
+* A machine with plenty of disk space and, ideally, a beefy GPU card.  (We've
+  experimented with the
+  [Nvidia Titan X](http://www.geforce.com/hardware/desktop-gpus/geforce-gtx-titan-x),
+  for example.)
+
+You'll then run `prep.py` (or `fastprep`) to prepare the data for Swivel and run
+`swivel.py` to create the embeddings. The resulting embeddings will be output
+into two large text files: one for the row vectors and one for the column
+vectors.  You can use those "as is", or convert them into a binary file using
+`text2bin.py` and then use the tools here to experiment with the resulting
+vectors.
+
+## Preparing the data for training
+
+Once you've downloaded the corpus (e.g., to `/tmp/wiki.txt`), run `prep.py` to
+prepare the data for training:
+
+    ./prep.py --output_dir /tmp/swivel_data --input /tmp/wiki.txt
+
+By default, `prep.py` will make one pass through the text file to compute a
+"vocabulary" of the most frequent words, and then a second pass to compute the
+co-occurrence statistics.  The following options allow you to control this
+behavior:
+
+| Option | Description |
+|:--- |:--- |
+| `--min_count <n>` | Only include words in the generated vocabulary that appear at least *n* times. |
+| `--max_vocab <n>` | Admit at most *n* words into the vocabulary. |
+| `--vocab <filename>` | Use the specified filename as the vocabulary instead of computing it from the corpus.  The file should contain one word per line. |
+
+The `prep.py` program is pretty simple.  Notably, it does almost no text
+processing: it does no case translation and simply breaks text into tokens by
+splitting on spaces. Feel free to experiment with the `words` function if you'd
+like to do something more sophisticated.
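As an illustrative sketch, a drop-in variant of `words` that also lower-cases the text and strips punctuation might look like the following (the original in `prep.py`, shown later in this change, simply calls `line.strip().split()`):

    import re

    def words(line):
      """Splits a line into lower-cased tokens of letters, digits and apostrophes."""
      return re.findall(r"[a-z0-9']+", line.lower())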
+
+Unfortunately, `prep.py` is pretty slow.  Also included is `fastprep`, a C++
+equivalent that works much more quickly.  Building `fastprep.cc` is a bit more
+involved: it requires you to pull and build the Tensorflow source code in order
+to provide the libraries and headers that it needs.  See `fastprep.mk` for more
+details.
+
+## Training the embeddings
+
+When `prep.py` completes, it will have produced a directory containing the data
+that the Swivel trainer needs to run.  Train embeddings as follows:
+
+    ./swivel.py --input_base_path /tmp/swivel_data \
+       --output_base_path /tmp/swivel_data
+
+There are a variety of parameters that you can fiddle with to customize the
+embeddings; some that you may want to experiment with include:
+
+| Option | Description |
+|:--- |:--- |
+| `--embedding_size <dim>` | The dimensionality of the embeddings that are created.  By default, 300 dimensional embeddings are created. |
+| `--num_epochs <n>` | The number of iterations through the data that are performed.  By default, 40 epochs are trained. |
+
+As mentioned above, access to a beefy GPU will dramatically reduce the amount of
+time it takes Swivel to train embeddings.
+
+When complete, you should find `row_embedding.tsv` and `col_embedding.tsv` in
+the directory specified by `--output_base_path`.  These files are tab-delimited
+files that contain one embedding per line.  Each line contains the token
+followed by *dim* floating point numbers.
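Because each line is just a token followed by its floating point values, the embeddings are easy to load by hand; for example, a minimal NumPy sketch (the path assumes the `--output_base_path` used above):

    import numpy as np

    tokens, rows = [], []
    with open('/tmp/swivel_data/row_embedding.tsv') as fh:
      for line in fh:
        parts = line.rstrip('\n').split('\t')
        tokens.append(parts[0])                     # the token
        rows.append([float(x) for x in parts[1:]])  # dim floating point values

    embeddings = np.array(rows, dtype=np.float32)   # shape: (vocab_size, dim)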
+
+## Exploring and evaluating the embeddings
+
+There are also some simple tools you can use to explore the embeddings.  These tools
+work with a simple binary vector format that can be `mmap`-ed into memory along
+with a separate vocabulary file.  Use `text2bin.py` to generate these files:
+
+    ./text2bin.py -o vecs.bin -v vocab.txt /tmp/swivel_data/*_embedding.tsv
+
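The binary format is deliberately simple: as `analogy.cc` below shows, `vecs.bin` is a flat array of `float32` values, one *dim*-sized vector per vocabulary entry, so it can be memory-mapped directly. A minimal NumPy sketch (assuming `vocab.txt` contains one token per line):

    import numpy as np

    vocab = [line.split('\t')[0].strip() for line in open('vocab.txt')]

    flat = np.memmap('vecs.bin', dtype=np.float32, mode='r')
    dim = flat.size // len(vocab)
    vecs = flat.reshape(len(vocab), dim)

    # Normalize into a regular array so dot products are cosine similarities.
    unit = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

The `Vecs` class in `vecs.py` wraps this format for tools like `nearest.py`.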
+You can do some simple exploration using `nearest.py`:
+
+    ./nearest.py -v vocab.txt -e vecs.bin
+    query> dog
+    dog
+    dogs
+    cat
+    ...
+    query> man woman king
+    king
+    queen
+    princess
+    ...
+
+To evaluate the embeddings using common word similarity and analogy datasets,
+use `eval.mk` to retrieve the data sets and build the tools:
+
+    make -f eval.mk
+    ./wordsim.py -v vocab.txt -e vecs.bin *.ws.tab
+    ./analogy --vocab vocab.txt --embeddings vecs.bin *.an.tab
+
+The word similarity evaluation compares the embeddings' estimate of "similarity"
+with human judgement using
+[Spearman's rho](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient)
+as the measure of correlation.  (Bigger numbers are better.)
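Concretely, an evaluation of this kind scores each word pair with the embeddings' cosine similarity and correlates those scores with the human judgements. A minimal sketch of the computation (the pair format and out-of-vocabulary handling here are simplifying assumptions; the real logic lives in `wordsim.py`):

    from scipy.stats import spearmanr

    def wordsim_rho(pairs, lookup):
      """pairs: (word1, word2, human_score) triples; lookup: token -> unit vector or None."""
      model, human = [], []
      for w1, w2, score in pairs:
        v1, v2 = lookup(w1), lookup(w2)
        if v1 is None or v2 is None:
          continue                       # skip out-of-vocabulary pairs
        model.append(float(v1.dot(v2)))  # cosine similarity of unit vectors
        human.append(score)
      rho, _ = spearmanr(model, human)
      return rho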
+
+The analogy evaluation tests how well the embeddings can predict analogies like
+"man is to woman as king is to queen".
+
+Note that `eval.mk` forces all evaluation data into lower case.  From there,
+both the word similarity and analogy evaluations assume that the eval data and
+the embeddings use consistent capitalization: if you train embeddings using
+mixed case and evaluate them using lower case, things won't work well.
+
+# Contact
+
+If you have any questions about Swivel, feel free to post to
+[swivel-embeddings@googlegroups.com](https://groups.google.com/forum/#!forum/swivel-embeddings)
+or contact us directly:
+
+* Noam Shazeer (`noam@google.com`)
+* Ryan Doherty (`portalfire@google.com`)
+* Colin Evans (`colinhevans@google.com`)
+* Chris Waterson (`waterson@google.com`)
+

+ 365 - 0
swivel/analogy.cc

@@ -0,0 +1,365 @@
+/* -*- Mode: C++ -*- */
+
+/*
+ * Copyright 2016 Google Inc. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Computes embedding performance on analogy tasks.  Accepts as input one or
+ * more files containing four words per line (A B C D), and determines if:
+ *
+ *   vec(C) - vec(A) + vec(B) ~= vec(D)
+ *
+ * Cosine distance in the embedding space is used to retrieve neighbors. Any
+ * missing vocabulary items are scored as losses.
+ */
+#include <fcntl.h>
+#include <math.h>
+#include <pthread.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+static const char usage[] = R"(
+Performs analogy testing of embedding vectors.
+
+Usage:
+
+  analogy --embeddings <embeddings> --vocab <vocab> eval1.tab ...
+
+Options:
+
+  --embeddings <filename>
+    The file containing the binary embedding vectors to evaluate.
+
+  --vocab <filename>
+    The vocabulary file corresponding to the embedding vectors.
+
+  --nthreads <integer>
+    The number of evaluation threads to run (default: 8)
+)";
+
+// Reads the vocabulary file into a map from token to vector index.
+static std::unordered_map<std::string, int> ReadVocab(
+    const std::string& vocab_filename) {
+  std::unordered_map<std::string, int> vocab;
+  std::ifstream fin(vocab_filename);
+
+  int index = 0;
+  for (std::string token; std::getline(fin, token); ++index) {
+    auto n = token.find('\t');
+    if (n != std::string::npos) token = token.substr(0, n);
+
+    vocab[token] = index;
+  }
+
+  return vocab;
+}
+
+// An analogy query: "A is to B as C is to D".
+typedef std::tuple<int, int, int, int> AnalogyQuery;
+
+std::vector<AnalogyQuery> ReadQueries(
+    const std::string &filename,
+    const std::unordered_map<std::string, int> &vocab, int *total) {
+  std::ifstream fin(filename);
+
+  std::vector<AnalogyQuery> queries;
+  int lineno = 0;
+  while (1) {
+    // Read the four words.
+    std::string words[4];
+    int nread = 0;
+    for (int i = 0; i < 4; ++i) {
+      fin >> words[i];
+      if (!words[i].empty()) ++nread;
+    }
+
+    ++lineno;
+    if (nread == 0) break;
+
+    if (nread < 4) {
+      std::cerr << "expected four words at line " << lineno << std::endl;
+      break;
+    }
+
+    // Look up each word's index.
+    int ixs[4], nvalid;
+    for (nvalid = 0; nvalid < 4; ++nvalid) {
+      std::unordered_map<std::string, int>::const_iterator it =
+          vocab.find(words[nvalid]);
+
+      if (it == vocab.end()) break;
+
+      ixs[nvalid] = it->second;
+    }
+
+    // If we don't have all the words, count it as a loss.
+    if (nvalid >= 4)
+      queries.push_back(std::make_tuple(ixs[0], ixs[1], ixs[2], ixs[3]));
+  }
+
+  *total = lineno;
+  return queries;
+}
+
+
+// A thread that evaluates some fraction of the analogies.
+class AnalogyEvaluator {
+ public:
+  // Creates a new Analogy evaluator for a range of analogy queries.
+  AnalogyEvaluator(std::vector<AnalogyQuery>::const_iterator begin,
+                   std::vector<AnalogyQuery>::const_iterator end,
+                   const float *embeddings, const int num_embeddings,
+                   const int dim)
+      : begin_(begin),
+        end_(end),
+        embeddings_(embeddings),
+        num_embeddings_(num_embeddings),
+        dim_(dim) {}
+
+  // A thunk for pthreads.
+  static void* Run(void *param) {
+    AnalogyEvaluator *self = static_cast<AnalogyEvaluator*>(param);
+    self->Evaluate();
+    return nullptr;
+  }
+
+  // Evaluates the analogies.
+  void Evaluate();
+
+  // Returns the number of correct analogies after evaluation is complete.
+  int GetNumCorrect() const { return correct_; }
+
+ protected:
+  // The beginning of the range of queries to consider.
+  std::vector<AnalogyQuery>::const_iterator begin_;
+
+  // The end of the range of queries to consider.
+  std::vector<AnalogyQuery>::const_iterator end_;
+
+  // The raw embedding vectors.
+  const float *embeddings_;
+
+  // The number of embedding vectors.
+  const int num_embeddings_;
+
+  // The embedding vector dimensionality.
+  const int dim_;
+
+  // The number of correct analogies.
+  int correct_;
+};
+
+
+void AnalogyEvaluator::Evaluate() {
+  float* sum = new float[dim_];
+
+  correct_ = 0;
+  for (auto query = begin_; query < end_; ++query) {
+    const float* vec;
+    int a, b, c, d;
+    std::tie(a, b, c, d) = *query;
+
+    // Compute C - A + B.
+    vec = embeddings_ + dim_ * c;
+    for (int i = 0; i < dim_; ++i) sum[i] = vec[i];
+
+    vec = embeddings_ + dim_ * a;
+    for (int i = 0; i < dim_; ++i) sum[i] -= vec[i];
+
+    vec = embeddings_ + dim_ * b;
+    for (int i = 0; i < dim_; ++i) sum[i] += vec[i];
+
+    // Find the nearest neighbor that isn't one of the query words.
+    int best_ix = -1;
+    float best_dot = -1.0;
+    for (int i = 0; i < num_embeddings_; ++i) {
+      if (i == a || i == b || i == c) continue;
+
+      vec = embeddings_ + dim_ * i;
+
+      float dot = 0;
+      for (int j = 0; j < dim_; ++j) dot += vec[j] * sum[j];
+
+      if (dot > best_dot) {
+        best_ix = i;
+        best_dot = dot;
+      }
+    }
+
+    // The fourth word is the answer; did we get it right?
+    if (best_ix == d) ++correct_;
+  }
+
+  delete[] sum;
+}
+
+
+int main(int argc, char *argv[]) {
+  if (argc <= 1) {
+    printf(usage);
+    return 2;
+  }
+
+  std::string embeddings_filename, vocab_filename;
+  int nthreads = 8;
+
+  std::vector<std::string> input_filenames;
+  std::vector<std::tuple<int, int, int, int>> queries;
+
+  for (int i = 1; i < argc; ++i) {
+    std::string arg = argv[i];
+    if (arg == "--embeddings") {
+      if (++i >= argc) goto argmissing;
+      embeddings_filename = argv[i];
+    } else if (arg == "--vocab") {
+      if (++i >= argc) goto argmissing;
+      vocab_filename = argv[i];
+    } else if (arg == "--nthreads") {
+      if (++i >= argc) goto argmissing;
+      if ((nthreads = atoi(argv[i])) <= 0) goto badarg;
+    } else if (arg == "--help") {
+      std::cout << usage << std::endl;
+      return 0;
+    } else if (arg[0] == '-') {
+      std::cerr << "unknown option: '" << arg << "'" << std::endl;
+      return 2;
+    } else {
+      input_filenames.push_back(arg);
+    }
+
+    continue;
+
+  argmissing:
+    std::cerr << "missing value for '" << argv[i - 1] << "' (--help for help)"
+              << std::endl;
+    return 2;
+
+  badarg:
+    std::cerr << "invalid value '" << argv[i] << "' for '" << argv[i - 1]
+              << "' (--help for help)" << std::endl;
+
+    return 2;
+  }
+
+  // Read the vocabulary.
+  std::unordered_map<std::string, int> vocab = ReadVocab(vocab_filename);
+  if (!vocab.size()) {
+    std::cerr << "unable to read vocabulary file '" << vocab_filename << "'"
+              << std::endl;
+    return 1;
+  }
+
+  const int n = vocab.size();
+
+  // Read the vectors.
+  int fd;
+  if ((fd = open(embeddings_filename.c_str(), O_RDONLY)) < 0) {
+    std::cerr << "unable to open embeddings file '" << embeddings_filename
+              << "'" << std::endl;
+    return 1;
+  }
+
+  off_t nbytes = lseek(fd, 0, SEEK_END);
+  if (nbytes == -1) {
+    std::cerr << "unable to determine file size for '" << embeddings_filename
+              << "'" << std::endl;
+    return 1;
+  }
+
+  if (nbytes % (sizeof(float) * n) != 0) {
+    std::cerr << "'" << embeddings_filename
+              << "' has a strange file size; expected it to be "
+                 "a multiple of the vocabulary size"
+              << std::endl;
+
+    return 1;
+  }
+
+  const int dim = nbytes / (sizeof(float) * n);
+  float *embeddings = static_cast<float *>(malloc(nbytes));
+  lseek(fd, 0, SEEK_SET);
+  if (read(fd, embeddings, nbytes) < nbytes) {
+    std::cerr << "unable to read embeddings from " << embeddings_filename
+              << std::endl;
+    return 1;
+  }
+
+  close(fd);
+
+  /* Normalize the vectors. */
+  for (int i = 0; i < n; ++i) {
+    float *vec = embeddings + dim * i;
+    float norm = 0;
+    for (int j = 0; j < dim; ++j) norm += vec[j] * vec[j];
+
+    norm = sqrt(norm);
+    for (int j = 0; j < dim; ++j) vec[j] /= norm;
+  }
+
+  pthread_attr_t attr;
+  if (pthread_attr_init(&attr) != 0) {
+    std::cerr << "unable to initalize pthreads" << std::endl;
+    return 1;
+  }
+
+  /* Read each input file. */
+  for (const auto filename : input_filenames) {
+    int total = 0;
+    std::vector<AnalogyQuery> queries =
+        ReadQueries(filename.c_str(), vocab, &total);
+
+    const int queries_per_thread = queries.size() / nthreads;
+    std::vector<AnalogyEvaluator*> evaluators;
+    std::vector<pthread_t> threads;
+
+    for (int i = 0; i < nthreads; ++i) {
+      auto begin = queries.begin() + i * queries_per_thread;
+      auto end = (i + 1 < nthreads)
+                     ? queries.begin() + (i + 1) * queries_per_thread
+                     : queries.end();
+
+      AnalogyEvaluator *evaluator =
+          new AnalogyEvaluator(begin, end, embeddings, n, dim);
+
+      pthread_t thread;
+      pthread_create(&thread, &attr, AnalogyEvaluator::Run, evaluator);
+      evaluators.push_back(evaluator);
+      threads.push_back(thread);
+    }
+
+    for (auto &thread : threads) pthread_join(thread, 0);
+
+    int correct = 0;
+    for (const AnalogyEvaluator* evaluator : evaluators) {
+      correct += evaluator->GetNumCorrect();
+      delete evaluator;
+    }
+
+    printf("%0.3f %s\n", static_cast<float>(correct) / total, filename.c_str());
+  }
+
+  return 0;
+}

+ 98 - 0
swivel/eval.mk

@@ -0,0 +1,98 @@
+# -*- Mode: Makefile -*-
+#
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This makefile pulls down the evaluation datasets and formats them uniformly.
+# Word similarity evaluations are formatted to contain exactly three columns:
+# the two words being compared and the human judgement.
+#
+# Use wordsim.py and analogy to run the actual evaluations.
+
+CXXFLAGS=-std=c++11 -m64 -mavx -g -Ofast -Wall
+LDLIBS=-lpthread -lm
+
+WORDSIM_EVALS=	ws353sim.ws.tab \
+		ws353rel.ws.tab \
+		men.ws.tab	\
+		mturk.ws.tab \
+		rarewords.ws.tab \
+		simlex999.ws.tab \
+		$(NULL)
+
+ANALOGY_EVALS=	mikolov.an.tab \
+		msr.an.tab \
+		$(NULL)
+
+all: $(WORDSIM_EVALS) $(ANALOGY_EVALS) analogy
+
+ws353sim.ws.tab: ws353simrel.tar.gz
+	tar Oxfz $^ wordsim353_sim_rel/wordsim_similarity_goldstandard.txt > $@
+
+ws353rel.ws.tab: ws353simrel.tar.gz
+	tar Oxfz $^ wordsim353_sim_rel/wordsim_relatedness_goldstandard.txt > $@
+
+men.ws.tab: MEN.tar.gz
+	tar Oxfz $^ MEN/MEN_dataset_natural_form_full | tr ' ' '\t' > $@
+
+mturk.ws.tab: Mtruk.csv
+	cat $^ | tr -d '\r' | tr ',' '\t' > $@
+
+rarewords.ws.tab: rw.zip
+	unzip -p $^ rw/rw.txt | cut -f1-3 -d $$'\t' > $@
+
+simlex999.ws.tab: SimLex-999.zip
+	unzip -p $^ SimLex-999/SimLex-999.txt \
+	| tail -n +2 | cut -f1,2,4 -d $$'\t' > $@
+
+mikolov.an.tab: questions-words.txt
+	egrep -v -E '^:' $^ | tr '[A-Z] ' '[a-z]\t' > $@
+
+msr.an.tab: myz_naacl13_test_set.tgz
+	tar Oxfz $^ test_set/word_relationship.questions | tr ' ' '\t' > /tmp/q
+	tar Oxfz $^ test_set/word_relationship.answers | cut -f2 -d ' ' > /tmp/a
+	paste /tmp/q /tmp/a > $@
+	rm -f /tmp/q /tmp/a
+
+
+# wget commands to fetch the datasets.  Please see the original datasets for
+# appropriate references if you use these.
+ws353simrel.tar.gz:
+	wget http://alfonseca.org/pubs/ws353simrel.tar.gz
+
+MEN.tar.gz:
+	wget http://clic.cimec.unitn.it/~elia.bruni/resources/MEN.tar.gz
+
+Mtruk.csv:
+	wget http://tx.technion.ac.il/~kirar/files/Mtruk.csv
+
+rw.zip:
+	wget http://www-nlp.stanford.edu/~lmthang/morphoNLM/rw.zip
+
+SimLex-999.zip:
+	wget http://www.cl.cam.ac.uk/~fh295/SimLex-999.zip
+
+questions-words.txt:
+	wget http://word2vec.googlecode.com/svn/trunk/questions-words.txt
+
+myz_naacl13_test_set.tgz:
+	wget http://research.microsoft.com/en-us/um/people/gzweig/Pubs/myz_naacl13_test_set.tgz
+
+analogy: analogy.cc
+
+clean:
+	rm -f *.ws.tab *.an.tab analogy *.pyc
+
+distclean: clean
+	rm -f *.tgz *.tar.gz *.zip Mtruk.csv questions-words.txt

+ 680 - 0
swivel/fastprep.cc

@@ -0,0 +1,680 @@
+/* -*- Mode: C++ -*- */
+
+/*
+ * Copyright 2016 Google Inc. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This program starts with a text file (and optionally a vocabulary file) and
+ * computes co-occurrence statistics. It emits output in a format that can be
+ * consumed by the "swivel" program.  It's functionally equivalent to "prep.py",
+ * but works much more quickly.
+ */
+
+#include <assert.h>
+#include <fcntl.h>
+#include <pthread.h>
+#include <stdio.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <map>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+#include "google/protobuf/io/zero_copy_stream_impl.h"
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
+
+static const char usage[] = R"(
+Prepares a corpus for processing by Swivel.
+
+Usage:
+
+  prep --output_dir <output-dir> --input <text-file>
+
+Options:
+
+  --input <filename>
+      The input text.
+
+  --output_dir <directory>
+      Specifies the output directory where the various Swivel data
+      files should be placed.  This directory must exist.
+
+  --shard_size <int>
+      Specifies the shard size; default 4096.
+
+  --min_count <int>
+      The minimum number of times a word should appear to be included in the
+      generated vocabulary; default 5.  (Ignored if --vocab is used.)
+
+  --max_vocab <int>
+      The maximum vocabulary size to generate from the input corpus; default
+      102,400.  (Ignored if --vocab is used.)
+
+  --vocab <filename>
+      Use the specified unigram vocabulary instead of generating
+      it from the corpus.
+
+  --window_size <int>
+      Specifies the window size for computing co-occurrence stats;
+      default 10.
+)";
+
+struct cooc_t {
+  int row;
+  int col;
+  float cnt;
+};
+
+typedef std::map<long long, float> cooc_counts_t;
+
+// Retrieves the next word from the input stream, treating words as simply being
+// delimited by whitespace.  Returns true if this is the end of a "sentence";
+// i.e., a newline.
+bool NextWord(std::ifstream &fin, std::string* word) {
+  std::string buf;
+  char c;
+
+  if (fin.eof()) {
+    word->erase();
+    return true;
+  }
+
+  // Skip leading whitespace.
+  do {
+    c = fin.get();
+  } while (!fin.eof() && std::isspace(c));
+
+  if (fin.eof()) {
+    word->erase();
+    return true;
+  }
+
+  // Read the next word.
+  do {
+    buf += c;
+    c = fin.get();
+  } while (!fin.eof() && !std::isspace(c));
+
+  *word = buf;
+  if (c == '\n' || fin.eof()) return true;
+
+  // Skip trailing whitespace.
+  do {
+    c = fin.get();
+  } while (!fin.eof() && std::isspace(c));
+
+  if (fin.eof()) return true;
+
+  fin.unget();
+  return false;
+}
+
+// Creates a vocabulary from the most frequent terms in the input file.
+std::vector<std::string> CreateVocabulary(const std::string input_filename,
+                                          const int shard_size,
+                                          const int min_vocab_count,
+                                          const int max_vocab_size) {
+  std::vector<std::string> vocab;
+
+  // Count all the distinct tokens in the file.  (XXX this will eventually
+  // consume all memory and should be re-written to periodically trim the data.)
+  std::unordered_map<std::string, long long> counts;
+
+  std::ifstream fin(input_filename, std::ifstream::ate);
+
+  if (!fin) {
+    std::cerr << "couldn't read input file '" << input_filename << "'"
+              << std::endl;
+
+    return vocab;
+  }
+
+  const auto input_size = fin.tellg();
+  fin.seekg(0);
+
+  long long ntokens = 0;
+  while (!fin.eof()) {
+    std::string word;
+    NextWord(fin, &word);
+    counts[word] += 1;
+
+    if (++ntokens % 1000000 == 0) {
+      const float pct = 100.0 * static_cast<float>(fin.tellg()) / input_size;
+      fprintf(stdout, "\rComputing vocabulary: %0.1f%% complete...", pct);
+      std::flush(std::cout);
+    }
+  }
+
+  std::cout << counts.size() << " distinct tokens" << std::endl;
+
+  // Sort the vocabulary from most frequent to least frequent.
+  std::vector<std::pair<std::string, long long>> buf;
+  std::copy(counts.begin(), counts.end(), std::back_inserter(buf));
+  std::sort(buf.begin(), buf.end(),
+            [](const std::pair<std::string, long long> &a,
+               const std::pair<std::string, long long> &b) {
+              return b.second < a.second;
+            });
+
+  // Truncate to the maximum vocabulary size
+  if (static_cast<int>(buf.size()) > max_vocab_size) buf.resize(max_vocab_size);
+  if (buf.empty()) return vocab;
+
+  // Eliminate rare tokens and truncate to a size modulo the shard size.
+  int vocab_size = buf.size();
+  while (vocab_size > 0 && buf[vocab_size - 1].second < min_vocab_count)
+    --vocab_size;
+
+  vocab_size -= vocab_size % shard_size;
+  if (static_cast<int>(buf.size()) > vocab_size) buf.resize(vocab_size);
+
+  // Copy out the tokens.
+  for (const auto& pair : buf) vocab.push_back(pair.first);
+
+  return vocab;
+}
+
+std::vector<std::string> ReadVocabulary(const std::string vocab_filename) {
+  std::vector<std::string> vocab;
+
+  std::ifstream fin(vocab_filename);
+  int index = 0;
+  for (std::string token; std::getline(fin, token); ++index) {
+    auto n = token.find('\t');
+    if (n != std::string::npos) token = token.substr(0, n);
+
+    vocab.push_back(token);
+  }
+
+  return vocab;
+}
+
+void WriteVocabulary(const std::vector<std::string> &vocab,
+                     const std::string &output_dirname) {
+  for (const std::string filename : {"row_vocab.txt", "col_vocab.txt"}) {
+    std::ofstream fout(output_dirname + "/" + filename);
+    for (const auto &token : vocab) fout << token << std::endl;
+  }
+}
+
+// Manages accumulation of co-occurrence data into temporary disk buffer files.
+class CoocBuffer {
+ public:
+  CoocBuffer(const std::string &output_dirname, const int num_shards,
+             const int shard_size);
+
+  // Accumulate the co-occurrence counts to the buffer.
+  void AccumulateCoocs(const cooc_counts_t &coocs);
+
+  // Read the buffer to produce shard files.
+  void WriteShards();
+
+ protected:
+  // The output directory. Also used for temporary buffer files.
+  const std::string output_dirname_;
+
+  // The number of row/column shards.
+  const int num_shards_;
+
+  // The number of elements per shard.
+  const int shard_size_;
+
+  // Parallel arrays of temporary file paths and file descriptors.
+  std::vector<std::string> paths_;
+  std::vector<int> fds_;
+
+  // Ensures that only one buffer file is getting written at a time.
+  pthread_mutex_t writer_mutex_;
+};
+
+CoocBuffer::CoocBuffer(const std::string &output_dirname, const int num_shards,
+                       const int shard_size)
+    : output_dirname_(output_dirname),
+      num_shards_(num_shards),
+      shard_size_(shard_size),
+      writer_mutex_(PTHREAD_MUTEX_INITIALIZER) {
+  for (int row = 0; row < num_shards_; ++row) {
+    for (int col = 0; col < num_shards_; ++col) {
+      char filename[256];
+      sprintf(filename, "shard-%03d-%03d.tmp", row, col);
+
+      std::string path = output_dirname + "/" + filename;
+      int fd = open(path.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0666);
+      assert(fd > 0);
+
+      paths_.push_back(path);
+      fds_.push_back(fd);
+    }
+  }
+}
+
+void CoocBuffer::AccumulateCoocs(const cooc_counts_t &coocs) {
+  std::vector<std::vector<cooc_t>> bufs(fds_.size());
+
+  for (const auto &cooc : coocs) {
+    const int row_id = cooc.first >> 32;
+    const int col_id = cooc.first & 0xffffffff;
+    const float cnt = cooc.second;
+
+    const int row_shard = row_id % num_shards_;
+    const int row_off = row_id / num_shards_;
+    const int col_shard = col_id % num_shards_;
+    const int col_off = col_id / num_shards_;
+
+    const int top_shard_idx = row_shard * num_shards_ + col_shard;
+    bufs[top_shard_idx].push_back(cooc_t{row_off, col_off, cnt});
+
+    const int bot_shard_idx = col_shard * num_shards_ + row_shard;
+    bufs[bot_shard_idx].push_back(cooc_t{col_off, row_off, cnt});
+  }
+
+  // XXX TODO: lock
+  for (int i = 0; i < static_cast<int>(fds_.size()); ++i) {
+    int rv = pthread_mutex_lock(&writer_mutex_);
+    assert(rv == 0);
+    const int nbytes = bufs[i].size() * sizeof(cooc_t);
+    int nwritten = write(fds_[i], bufs[i].data(), nbytes);
+    assert(nwritten == nbytes);
+    pthread_mutex_unlock(&writer_mutex_);
+  }
+}
+
+void CoocBuffer::WriteShards() {
+  for (int shard = 0; shard < static_cast<int>(fds_.size()); ++shard) {
+    const int row_shard = shard / num_shards_;
+    const int col_shard = shard % num_shards_;
+
+    std::cout << "\rwriting shard " << (shard + 1) << "/"
+              << (num_shards_ * num_shards_);
+    std::flush(std::cout);
+
+    // Construct the tf::Example proto.  First, we add the global rows and
+    // column that are present in the shard.
+    tensorflow::Example example;
+
+    auto &feature = *example.mutable_features()->mutable_feature();
+    auto global_row = feature["global_row"].mutable_int64_list();
+    auto global_col = feature["global_col"].mutable_int64_list();
+
+    for (int i = 0; i < shard_size_; ++i) {
+      global_row->add_value(row_shard + i * num_shards_);
+      global_col->add_value(col_shard + i * num_shards_);
+    }
+
+    // Next we add co-occurrences as a sparse representation.  Map the
+    // co-occurrence counts that we've spooled off to disk: these are in
+    // arbitrary order and may contain duplicates.
+    const off_t nbytes = lseek(fds_[shard], 0, SEEK_END);
+    cooc_t *coocs = static_cast<cooc_t*>(
+        mmap(0, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, fds_[shard], 0));
+
+    const int ncoocs = nbytes / sizeof(cooc_t);
+    cooc_t* cur = coocs;
+    cooc_t* end = coocs + ncoocs;
+
+    auto sparse_value = feature["sparse_value"].mutable_float_list();
+    auto sparse_local_row = feature["sparse_local_row"].mutable_int64_list();
+    auto sparse_local_col = feature["sparse_local_col"].mutable_int64_list();
+
+    std::sort(cur, end, [](const cooc_t &a, const cooc_t &b) {
+      return a.row < b.row || (a.row == b.row && a.col < b.col);
+    });
+
+    // Accumulate the counts into the protocol buffer.
+    int last_row = -1, last_col = -1;
+    float count = 0;
+    for (; cur != end; ++cur) {
+      if (cur->row != last_row || cur->col != last_col) {
+        if (last_row >= 0 && last_col >= 0) {
+          sparse_local_row->add_value(last_row);
+          sparse_local_col->add_value(last_col);
+          sparse_value->add_value(count);
+        }
+
+        last_row = cur->row;
+        last_col = cur->col;
+        count = 0;
+      }
+
+      count += cur->cnt;
+    }
+
+    if (last_row >= 0 && last_col >= 0) {
+      sparse_local_row->add_value(last_row);
+      sparse_local_col->add_value(last_col);
+      sparse_value->add_value(count);
+    }
+
+    munmap(coocs, nbytes);
+    close(fds_[shard]);
+
+    // Write the protocol buffer as a binary blob to disk.
+    char filename[256];
+    snprintf(filename, sizeof(filename), "shard-%03d-%03d.pb", row_shard,
+             col_shard);
+
+    const std::string path = output_dirname_ + "/" + filename;
+    int fd = open(path.c_str(), O_WRONLY | O_TRUNC | O_CREAT, 0666);
+    assert(fd != -1);
+
+    google::protobuf::io::FileOutputStream fout(fd);
+    example.SerializeToZeroCopyStream(&fout);
+    fout.Close();
+
+    // Remove the temporary file.
+    unlink(paths_[shard].c_str());
+  }
+
+  std::cout << std::endl;
+}
+
+// Counts the co-occurrences in part of the file.
+class CoocCounter {
+ public:
+  CoocCounter(const std::string &input_filename, const off_t start,
+              const off_t end, const int window_size,
+              const std::unordered_map<std::string, int> &token_to_id_map,
+              CoocBuffer *coocbuf)
+      : fin_(input_filename, std::ifstream::ate),
+        start_(start),
+        end_(end),
+        window_size_(window_size),
+        token_to_id_map_(token_to_id_map),
+        coocbuf_(coocbuf),
+        marginals_(token_to_id_map.size()) {}
+
+  // Pthreads-friendly thunk to Count.
+  static void* Run(void* param) {
+    CoocCounter* self = static_cast<CoocCounter*>(param);
+    self->Count();
+    return nullptr;
+  }
+
+  // Counts the co-occurrences.
+  void Count();
+
+  const std::vector<double>& Marginals() const { return marginals_; }
+
+ protected:
+  // The input stream.
+  std::ifstream fin_;
+
+  // The range of the file to which this counter should attend.
+  const off_t start_;
+  const off_t end_;
+
+  // The window size for computing co-occurrences.
+  const int window_size_;
+
+  // A reference to the mapping from tokens to IDs.
+  const std::unordered_map<std::string, int> &token_to_id_map_;
+
+  // The buffer into which counts are to be accumulated.
+  CoocBuffer* coocbuf_;
+
+  // The marginal counts accumulated by this counter.
+  std::vector<double> marginals_;
+};
+
+void CoocCounter::Count() {
+  const int max_coocs_size = 16 * 1024 * 1024;
+
+  // A buffer of co-occurrence counts that we'll periodically sort into
+  // shards.
+  cooc_counts_t coocs;
+
+  fin_.seekg(start_);
+
+  int nlines = 0;
+  for (off_t filepos = start_; filepos < end_; filepos = fin_.tellg()) {
+    // Buffer a single sentence.
+    std::vector<int> sentence;
+    bool eos;
+    do {
+      std::string word;
+      eos = NextWord(fin_, &word);
+      auto it = token_to_id_map_.find(word);
+      if (it != token_to_id_map_.end()) sentence.push_back(it->second);
+    } while (!eos);
+
+    // Generate the co-occurrences for the sentence.
+    for (int pos = 0; pos < static_cast<int>(sentence.size()); ++pos) {
+      const int left_id = sentence[pos];
+
+      const int window_extent =
+          std::min(static_cast<int>(sentence.size()) - pos, 1 + window_size_);
+
+      for (int off = 1; off < window_extent; ++off) {
+        const int right_id = sentence[pos + off];
+        const double count = 1.0 / static_cast<double>(off);
+        const long long lo = std::min(left_id, right_id);
+        const long long hi = std::max(left_id, right_id);
+        const long long key = (hi << 32) | lo;
+        coocs[key] += count;
+
+        marginals_[left_id] += count;
+        marginals_[right_id] += count;
+      }
+
+      marginals_[left_id] += 1.0;
+      const long long key = (static_cast<long long>(left_id) << 32) |
+                            static_cast<long long>(left_id);
+
+      coocs[key] += 0.5;
+    }
+
+    // Periodically flush the co-occurrences to disk.
+    if (coocs.size() > max_coocs_size) {
+      coocbuf_->AccumulateCoocs(coocs);
+      coocs.clear();
+    }
+
+    if (start_ == 0 && ++nlines % 1000 == 0) {
+      const double pct = 100.0 * filepos / end_;
+      fprintf(stdout, "\rComputing co-occurrences: %0.1f%% complete...", pct);
+      std::flush(std::cout);
+    }
+  }
+
+  // Accumulate anything we haven't flushed yet.
+  coocbuf_->AccumulateCoocs(coocs);
+
+  if (start_ == 0) std::cout << "done." << std::endl;
+}
+
+void WriteMarginals(const std::vector<double> &marginals,
+                    const std::string &output_dirname) {
+  for (const std::string filename : {"row_sums.txt", "col_sums.txt"}) {
+    std::ofstream fout(output_dirname + "/" + filename);
+    fout.setf(std::ios::fixed);
+    for (double sum : marginals) fout << sum << std::endl;
+  }
+}
+
+int main(int argc, char *argv[]) {
+  std::string input_filename;
+  std::string vocab_filename;
+  std::string output_dirname;
+  bool generate_vocab = true;
+  int max_vocab_size = 100 * 1024;
+  int min_vocab_count = 5;
+  int window_size = 10;
+  int shard_size = 4096;
+  int num_threads = 4;
+
+  for (int i = 1; i < argc; ++i) {
+    std::string arg(argv[i]);
+    if (arg == "--vocab") {
+      if (++i >= argc) goto argmissing;
+      generate_vocab = false;
+      vocab_filename = argv[i];
+    } else if (arg == "--max_vocab") {
+      if (++i >= argc) goto argmissing;
+      if ((max_vocab_size = atoi(argv[i])) <= 0) goto badarg;
+    } else if (arg == "--min_count") {
+      if (++i >= argc) goto argmissing;
+      if ((min_vocab_count = atoi(argv[i])) <= 0) goto badarg;
+    } else if (arg == "--window_size") {
+      if (++i >= argc) goto argmissing;
+      if ((window_size = atoi(argv[i])) <= 0) goto badarg;
+    } else if (arg == "--input") {
+      if (++i >= argc) goto argmissing;
+      input_filename = argv[i];
+    } else if (arg == "--output_dir") {
+      if (++i >= argc) goto argmissing;
+      output_dirname = argv[i];
+    } else if (arg == "--shard_size") {
+      if (++i >= argc) goto argmissing;
+      shard_size = atoi(argv[i]);
+    } else if (arg == "--num_threads") {
+      if (++i >= argc) goto argmissing;
+      num_threads = atoi(argv[i]);
+    } else if (arg == "--help") {
+      std::cout << usage << std::endl;
+      return 0;
+    } else {
+      std::cerr << "unknown arg '" << arg << "'; try --help?" << std::endl;
+      return 2;
+    }
+
+    continue;
+
+  badarg:
+    std::cerr << "'" << argv[i] << "' is not a valid value for '" << arg
+              << "'; try --help?" << std::endl;
+
+    return 2;
+
+  argmissing:
+    std::cerr << arg << " requires an argument; try --help?" << std::endl;
+  }
+
+  if (input_filename.empty()) {
+    std::cerr << "please specify the input text with '--input'; try --help?"
+              << std::endl;
+    return 2;
+  }
+
+  if (output_dirname.empty()) {
+    std::cerr << "please specify the output directory with '--output_dir'"
+              << std::endl;
+
+    return 2;
+  }
+
+  struct stat sb;
+  if (lstat(output_dirname.c_str(), &sb) != 0 || !S_ISDIR(sb.st_mode)) {
+    std::cerr << "output directory '" << output_dirname
+              << "' does not exist of is not a directory." << std::endl;
+
+    return 1;
+  }
+
+  if (lstat(input_filename.c_str(), &sb) != 0 || !S_ISREG(sb.st_mode)) {
+    std::cerr << "input file '" << input_filename
+              << "' does not exist or is not a file." << std::endl;
+
+    return 1;
+  }
+
+  // The total size of the input.
+  const off_t input_size = sb.st_size;
+
+  const std::vector<std::string> vocab =
+      generate_vocab ? CreateVocabulary(input_filename, shard_size,
+                                        min_vocab_count, max_vocab_size)
+                     : ReadVocabulary(vocab_filename);
+
+  if (!vocab.size()) {
+    std::cerr << "Empty vocabulary." << std::endl;
+    return 1;
+  }
+
+  std::cout << "Generating Swivel co-occurrence data into " << output_dirname
+            << std::endl;
+
+  std::cout << "Shard size: " << shard_size << "x" << shard_size << std::endl;
+  std::cout << "Vocab size: " << vocab.size() << std::endl;
+
+  // Write the vocabulary files into the output directory.
+  WriteVocabulary(vocab, output_dirname);
+
+  const int num_shards = vocab.size() / shard_size;
+  CoocBuffer coocbuf(output_dirname, num_shards, shard_size);
+
+  // Build a mapping from the token to its position in the vocabulary file.
+  std::unordered_map<std::string, int> token_to_id_map;
+  for (int i = 0; i < static_cast<int>(vocab.size()); ++i)
+    token_to_id_map[vocab[i]] = i;
+
+  // Compute the co-occurrences
+  std::vector<pthread_t> threads;
+  std::vector<CoocCounter*> counters;
+  const off_t nbytes_per_thread = input_size / num_threads;
+
+  pthread_attr_t attr;
+  if (pthread_attr_init(&attr) != 0) {
+    std::cerr << "unable to initalize pthreads" << std::endl;
+    return 1;
+  }
+
+  for (int i = 0; i < num_threads; ++i) {
+    // We could make this smarter and look around for newlines.  But
+    // realistically that's not going to change things much.
+    const off_t start = i * nbytes_per_thread;
+    const off_t end =
+        i < num_threads - 1 ? (i + 1) * nbytes_per_thread : input_size;
+
+    CoocCounter *counter = new CoocCounter(
+        input_filename, start, end, window_size, token_to_id_map, &coocbuf);
+
+    counters.push_back(counter);
+
+    pthread_t thread;
+    pthread_create(&thread, &attr, CoocCounter::Run, counter);
+
+    threads.push_back(thread);
+  }
+
+  // Wait for threads to finish and collect marginals.
+  std::vector<double> marginals(vocab.size());
+  for (int i = 0; i < num_threads; ++i) {
+    pthread_join(threads[i], 0);
+
+    const std::vector<double>& counter_marginals = counters[i]->Marginals();
+    for (int j = 0; j < static_cast<int>(vocab.size()); ++j)
+      marginals[j] += counter_marginals[j];
+
+    delete counters[i];
+  }
+
+  std::cout << "writing marginals..." << std::endl;
+  WriteMarginals(marginals, output_dirname);
+
+  std::cout << "writing shards..." << std::endl;
+  coocbuf.WriteShards();
+
+  return 0;
+}

+ 87 - 0
swivel/fastprep.mk

@@ -0,0 +1,87 @@
+# -*- Mode: Makefile -*-
+
+#
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+
+# This makefile builds "fastprep", a faster version of prep.py that can be used
+# to build training data for Swivel.  Building "fastprep" is a bit more
+# involved: you'll need to pull and build the Tensorflow source, and then build
+# and install compatible protobuf software.  We've tested this with Tensorflow
+# version 0.7.
+#
+# = Step 1. Pull and Build Tensorflow. =
+#
+# These instructions are somewhat abridged; for pre-requisites and the most
+# up-to-date instructions, refer to:
+#
+#   <https://www.tensorflow.org/versions/r0.7/get_started/os_setup.html#installing-from-sources>
+#
+# To build the Tensorflow components required for "fastpret", you'll need to
+# install Bazel, Numpy, Swig, and Python development headers as described in at
+# the above URL.  Run the "configure" script as appropriate for your
+# environment and then build the "build_pip_package" target:
+#
+#   bazel build -c opt [--config=cuda] //tensorflow/tools/pip_package:build_pip_package
+#
+# This will generate the Tensorflow headers and libraries necessary for
+# "fastprep".
+#
+#
+# = Step 2. Build and Install Compatible Protobuf Libraries =
+#
+# "fastprep" also needs compatible protocol buffer libraries, which you can
+# build from the protobuf implementation included with the Tensorflow
+# distribution:
+#
+#   cd ${TENSORFLOW_SRCDIR}/google/protobuf
+#   ./autogen.sh
+#   ./configure --prefix=${HOME}  # ...or whatever
+#   make
+#   make install  # ...or maybe "sudo make install"
+#
+# This will install the headers and libraries appropriately.
+#
+#
+# = Step 3. Build "fastprep". =
+#
+# Finally modify this file (if necessary) to update PB_DIR and TF_DIR to refer
+# to appropriate locations, and:
+#
+#   make -f fastprep.mk
+#
+# If all goes well, you should have a program that is "flag compatible" with
+# "prep.py" and runs significantly faster.  Use it to generate the co-occurrence
+# matrices and other files necessary to train a Swivel matrix.
+
+
+# The root directory where the Google Protobuf software is installed.
+# Alternative locations might be "/usr" or "/usr/local".
+PB_DIR=$(HOME)
+
+# Assuming you've got the Tensorflow source unpacked and built in ${HOME}/src:
+TF_DIR=$(HOME)/src/tensorflow
+
+PB_INCLUDE=$(PB_DIR)/include
+TF_INCLUDE=$(TF_DIR)/bazel-genfiles
+CXXFLAGS=-std=c++11 -m64 -mavx -g -Ofast -Wall -I$(TF_INCLUDE) -I$(PB_INCLUDE)
+
+PB_LIBDIR=$(PB_DIR)/lib
+TF_LIBDIR=$(TF_DIR)/bazel-bin/tensorflow/core
+LDFLAGS=-L$(TF_LIBDIR) -L$(PB_LIBDIR)
+LDLIBS=-lprotos_all_cc -lprotobuf -lpthread -lm
+
+fastprep: fastprep.cc

+ 75 - 0
swivel/nearest.py

@@ -0,0 +1,75 @@
+#!/usr/bin/env python
+#
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple tool for inspecting nearest neighbors and analogies."""
+
+import re
+import sys
+from getopt import GetoptError, getopt
+
+from vecs import Vecs
+
+try:
+  opts, args = getopt(sys.argv[1:], 'v:e:', ['vocab=', 'embeddings='])
+except GetoptError, e:
+  print >> sys.stderr, e
+  sys.exit(2)
+
+opt_vocab = 'vocab.txt'
+opt_embeddings = None
+
+for o, a in opts:
+  if o in ('-v', '--vocab'):
+    opt_vocab = a
+  if o in ('-e', '--embeddings'):
+    opt_embeddings = a
+
+vecs = Vecs(opt_vocab, opt_embeddings)
+
+while True:
+  sys.stdout.write('query> ')
+  sys.stdout.flush()
+
+  query = sys.stdin.readline().strip()
+  if not query:
+    break
+
+  parts = re.split(r'\s+', query)
+
+  if len(parts) == 1:
+    res = vecs.neighbors(parts[0])
+
+  elif len(parts) == 3:
+    vs = [vecs.lookup(w) for w in parts]
+    if any(v is None for v in vs):
+      print 'not in vocabulary: %s' % (
+          ', '.join(tok for tok, v in zip(parts, vs) if v is None))
+
+      continue
+
+    res = vecs.neighbors(vs[2] - vs[0] + vs[1])
+
+  else:
+    print 'use a single word to query neighbors, or three words for analogy'
+    continue
+
+  if not res:
+    continue
+
+  for word, sim in res[:20]:
+    print '%0.4f: %s' % (sim, word)
+
+  print

+ 317 - 0
swivel/prep.py

@@ -0,0 +1,317 @@
+#!/usr/bin/env python
+#
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Prepare a corpus for processing by swivel.
+
+Creates a sharded word co-occurrence matrix from a text file input corpus.
+
+Usage:
+
+  prep.py --output_dir <output-dir> --input <text-file>
+
+Options:
+
+  --input <filename>
+      The input text.
+
+  --output_dir <directory>
+      Specifies the output directory where the various Swivel data
+      files should be placed.
+
+  --shard_size <int>
+      Specifies the shard size; default 4096.
+
+  --min_count <int>
+      Specifies the minimum number of times a word should appear
+      to be included in the vocabulary; default 5.
+
+  --max_vocab <int>
+      Specifies the maximum vocabulary size; default shard size
+      times 1024.
+
+  --vocab <filename>
+      Use the specified unigram vocabulary instead of generating
+      it from the corpus.
+
+  --window_size <int>
+      Specifies the window size for computing co-occurrence stats;
+      default 10.
+
+  --bufsz <int>
+      The number of co-occurrences that are buffered; default 16M.
+
+"""
+
+import itertools
+import math
+import os
+import struct
+import sys
+
+import tensorflow as tf
+
+flags = tf.app.flags
+
+flags.DEFINE_string('input', '', 'The input text.')
+flags.DEFINE_string('output_dir', '/tmp/swivel_data',
+                    'Output directory for Swivel data')
+flags.DEFINE_integer('shard_size', 4096, 'The size for each shard')
+flags.DEFINE_integer('min_count', 5,
+                     'The minimum number of times a word should occur to be '
+                     'included in the vocabulary')
+flags.DEFINE_integer('max_vocab', 4096 * 64, 'The maximum vocabulary size')
+flags.DEFINE_string('vocab', '', 'Vocabulary to use instead of generating one')
+flags.DEFINE_integer('window_size', 10, 'The window size')
+flags.DEFINE_integer('bufsz', 16 * 1024 * 1024,
+                     'The number of co-occurrences to buffer')
+
+FLAGS = flags.FLAGS
+
+shard_cooc_fmt = struct.Struct('iif')
+
+
+def words(line):
+  """Splits a line of text into tokens."""
+  return line.strip().split()
+
+
+def create_vocabulary(lines):
+  """Reads text lines and generates a vocabulary."""
+  lines.seek(0, os.SEEK_END)
+  nbytes = lines.tell()
+  lines.seek(0, os.SEEK_SET)
+
+  vocab = {}
+  for lineno, line in enumerate(lines, start=1):
+    for word in words(line):
+      vocab.setdefault(word, 0)
+      vocab[word] += 1
+
+    if lineno % 100000 == 0:
+      pos = lines.tell()
+      sys.stdout.write('\rComputing vocabulary: %0.1f%% (%d/%d)...' % (
+          100.0 * pos / nbytes, pos, nbytes))
+      sys.stdout.flush()
+
+  sys.stdout.write('\n')
+
+  vocab = [(tok, n) for tok, n in vocab.iteritems() if n >= FLAGS.min_count]
+  vocab.sort(key=lambda kv: (-kv[1], kv[0]))
+
+  num_words = max(len(vocab), FLAGS.shard_size)
+  num_words = min(len(vocab), FLAGS.max_vocab)
+  if num_words % FLAGS.shard_size != 0:
+    num_words -= num_words % FLAGS.shard_size
+
+  if not num_words:
+    raise Exception('empty vocabulary')
+
+  print 'vocabulary contains %d tokens' % num_words
+
+  vocab = vocab[:num_words]
+  return [tok for tok, n in vocab]
+
+
+def write_vocab_and_sums(vocab, sums, vocab_filename, sums_filename):
+  """Writes vocabulary and marginal sum files."""
+  with open(os.path.join(FLAGS.output_dir, vocab_filename), 'w') as vocab_out:
+    with open(os.path.join(FLAGS.output_dir, sums_filename), 'w') as sums_out:
+      for tok, cnt in itertools.izip(vocab, sums):
+        print >> vocab_out, tok
+        print >> sums_out, cnt
+
+
+def compute_coocs(lines, vocab):
+  """Compute the co-occurrence statistics from the text.
+
+  This generates a temporary file for each shard that contains the intermediate
+  counts from the shard: these counts must be subsequently sorted and collated.
+
+  """
+  word_to_id = {tok: idx for idx, tok in enumerate(vocab)}
+
+  lines.seek(0, os.SEEK_END)
+  nbytes = lines.tell()
+  lines.seek(0, os.SEEK_SET)
+
+  num_shards = len(vocab) / FLAGS.shard_size
+
+  shardfiles = {}
+  for row in range(num_shards):
+    for col in range(num_shards):
+      filename = os.path.join(
+          FLAGS.output_dir, 'shard-%03d-%03d.tmp' % (row, col))
+
+      shardfiles[(row, col)] = open(filename, 'w+')
+
+  def flush_coocs():
+    for (row_id, col_id), cnt in coocs.iteritems():
+      row_shard = row_id % num_shards
+      row_off = row_id / num_shards
+      col_shard = col_id % num_shards
+      col_off = col_id / num_shards
+
+      # Since we only stored (a, b), we emit both (a, b) and (b, a).
+      shardfiles[(row_shard, col_shard)].write(
+          shard_cooc_fmt.pack(row_off, col_off, cnt))
+
+      shardfiles[(col_shard, row_shard)].write(
+          shard_cooc_fmt.pack(col_off, row_off, cnt))
+
+  coocs = {}
+  sums = [0.0] * len(vocab)
+
+  for lineno, line in enumerate(lines, start=1):
+    # Computes the word IDs for each word in the sentence.  This has the effect
+    # of "stretching" the window past OOV tokens.
+    wids = filter(
+        lambda wid: wid is not None,
+        (word_to_id.get(w) for w in words(line)))
+
+    for pos in xrange(len(wids)):
+      lid = wids[pos]
+      window_extent = min(FLAGS.window_size + 1, len(wids) - pos)
+      for off in xrange(1, window_extent):
+        rid = wids[pos + off]
+        pair = (min(lid, rid), max(lid, rid))
+        count = 1.0 / off
+        sums[lid] += count
+        sums[rid] += count
+        coocs.setdefault(pair, 0.0)
+        coocs[pair] += count
+
+      sums[lid] += 1.0
+      pair = (lid, lid)
+      coocs.setdefault(pair, 0.0)
+      coocs[pair] += 0.5  # Only add 1/2 since we output (a, b) and (b, a)
+
+    if lineno % 10000 == 0:
+      pos = lines.tell()
+      sys.stdout.write('\rComputing co-occurrences: %0.1f%% (%d/%d)...' % (
+          100.0 * pos / nbytes, pos, nbytes))
+      sys.stdout.flush()
+
+      if len(coocs) > FLAGS.bufsz:
+        flush_coocs()
+        coocs = {}
+
+  flush_coocs()
+  sys.stdout.write('\n')
+
+  return shardfiles, sums
+
+
+def write_shards(vocab, shardfiles):
+  """Processes the temporary files to generate the final shard data.
+
+  The shard data is stored as serialized tf.Example protos, one file per shard. The
+  temporary files are removed from the filesystem once they've been processed.
+
+  """
+  num_shards = len(vocab) / FLAGS.shard_size
+
+  ix = 0
+  for (row, col), fh in shardfiles.iteritems():
+    ix += 1
+    sys.stdout.write('\rwriting shard %d/%d' % (ix, len(shardfiles)))
+    sys.stdout.flush()
+
+    # Read the entire binary co-occurrence and unpack it into an array.
+    fh.seek(0)
+    buf = fh.read()
+    os.unlink(fh.name)
+    fh.close()
+
+    coocs = [
+        shard_cooc_fmt.unpack_from(buf, off)
+        for off in range(0, len(buf), shard_cooc_fmt.size)]
+
+    # Sort and merge co-occurrences for the same pairs.
+    coocs.sort()
+
+    if coocs:
+      current_pos = 0
+      current_row_col = (coocs[current_pos][0], coocs[current_pos][1])
+      for next_pos in range(1, len(coocs)):
+        next_row_col = (coocs[next_pos][0], coocs[next_pos][1])
+        if current_row_col == next_row_col:
+          coocs[current_pos] = (
+              coocs[current_pos][0],
+              coocs[current_pos][1],
+              coocs[current_pos][2] + coocs[next_pos][2])
+        else:
+          current_pos += 1
+          if current_pos < next_pos:
+            coocs[current_pos] = coocs[next_pos]
+
+          current_row_col = (coocs[current_pos][0], coocs[current_pos][1])
+
+      coocs = coocs[:(1 + current_pos)]
+
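+    # The global row (column) ids covered by shard (row, col) are
+    # row + num_shards * i for i in [0, shard_size); the sparse local indices
+    # below index into these lists.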
+    # Convert to a TF Example proto.
+    def _int64s(xs):
+      return tf.train.Feature(int64_list=tf.train.Int64List(value=list(xs)))
+
+    def _floats(xs):
+      return tf.train.Feature(float_list=tf.train.FloatList(value=list(xs)))
+
+    example = tf.train.Example(features=tf.train.Features(feature={
+        'global_row': _int64s(
+            row + num_shards * i for i in range(FLAGS.shard_size)),
+        'global_col': _int64s(
+            col + num_shards * i for i in range(FLAGS.shard_size)),
+
+        'sparse_local_row': _int64s(cooc[0] for cooc in coocs),
+        'sparse_local_col': _int64s(cooc[1] for cooc in coocs),
+        'sparse_value': _floats(cooc[2] for cooc in coocs),
+    }))
+
+    filename = os.path.join(FLAGS.output_dir, 'shard-%03d-%03d.pb' % (row, col))
+    with open(filename, 'w') as out:
+      out.write(example.SerializeToString())
+
+  sys.stdout.write('\n')
+
+
+def main(_):
+  # Create the output directory, if necessary
+  if FLAGS.output_dir and not os.path.isdir(FLAGS.output_dir):
+    os.makedirs(FLAGS.output_dir)
+
+  # Read the file once to create the vocabulary.
+  if FLAGS.vocab:
+    with open(FLAGS.vocab, 'r') as lines:
+      vocab = [line.strip() for line in lines]
+  else:
+    with open(FLAGS.input, 'r') as lines:
+      vocab = create_vocabulary(lines)
+
+  # Now read the file again to determine the co-occurrence stats.
+  with open(FLAGS.input, 'r') as lines:
+    shardfiles, sums = compute_coocs(lines, vocab)
+
+  # Process the temporary shard files into the final shard protos.
+  write_shards(vocab, shardfiles)
+
+  # Now write the marginals.  They're symmetric for this application.
+  write_vocab_and_sums(vocab, sums, 'row_vocab.txt', 'row_sums.txt')
+  write_vocab_and_sums(vocab, sums, 'col_vocab.txt', 'col_sums.txt')
+
+  print 'done!'
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 347 - 0
swivel/swivel.py

@@ -0,0 +1,347 @@
+#!/usr/bin/env python
+#
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Submatrix-wise Vector Embedding Learner.
+
+Implementation of SwiVel algorithm described at:
+http://arxiv.org/abs/1602.02215
+
+This program expects an input directory that contains the following files.
+
+  row_vocab.txt, col_vocab.txt
+
+    The row and column vocabulary files.  Each file should contain one token
+    per line; these will be used to generate a tab-separated file containing
+    the trained embeddings.
+
+  row_sums.txt, col_sums.txt
+
+    The matrix row and column marginal sums.  Each file should contain one
+    decimal floating point number per line which corresponds to the marginal
+    count of the matrix for that row or column.
+
+  shard-*.pb
+
+    The sub-matrix shards, one file per shard.  Each shard is expected to be a
+    serialized tf.Example protocol buffer with the following properties:
+
+      global_row: the global row indices contained in the shard
+      global_col: the global column indices contained in the shard
+      sparse_local_row, sparse_local_col, sparse_value: three parallel arrays
+      that are a sparse representation of the submatrix counts.
+
+The program trains on the shards in the input directory for the specified
+number of epochs.  When complete, it writes the trained vectors to
+tab-separated files that contain one line per embedding; row and column
+embeddings are stored in separate files.
+
+"""
+
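+# For reference, a single shard written by prep.py can be inspected with the
+# protocol buffer API (a minimal sketch; the shard filename is a placeholder):
+#
+#   example = tf.train.Example()
+#   with open('shard-000-000.pb', 'rb') as f:
+#     example.ParseFromString(f.read())
+#   print example.features.feature['global_row'].int64_list.value[:10]
+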
+import argparse
+import glob
+import math
+import os
+import sys
+import time
+import threading
+
+import numpy as np
+import tensorflow as tf
+
+flags = tf.app.flags
+
+flags.DEFINE_string('input_base_path', '/tmp/swivel_data',
+                    'Directory containing input shards, vocabularies, '
+                    'and marginals.')
+flags.DEFINE_string('output_base_path', '/tmp/swivel_data',
+                    'Path where to write the trained embeddings.')
+flags.DEFINE_integer('embedding_size', 300, 'Size of the embeddings')
+flags.DEFINE_boolean('trainable_bias', False, 'Biases are trainable')
+flags.DEFINE_integer('submatrix_rows', 4096, 'Rows in each training submatrix. '
+                     'This must match the training data.')
+flags.DEFINE_integer('submatrix_cols', 4096, 'Cols in each training submatrix. '
+                     'This must match the training data.')
+flags.DEFINE_float('loss_multiplier', 1.0 / 4096,
+                   'constant multiplier on loss.')
+flags.DEFINE_float('confidence_exponent', 0.5,
+                   'Exponent for l2 confidence function')
+flags.DEFINE_float('confidence_scale', 0.25, 'Scale for l2 confidence function')
+flags.DEFINE_float('confidence_base', 0.1, 'Base for l2 confidence function')
+flags.DEFINE_float('learning_rate', 1.0, 'Initial learning rate')
+flags.DEFINE_integer('num_concurrent_steps', 2,
+                     'Number of threads to train with')
+flags.DEFINE_float('num_epochs', 40, 'Number of epochs to train for')
+flags.DEFINE_float('per_process_gpu_memory_fraction', 0.25,
+                   'Fraction of GPU memory to use')
+
+FLAGS = flags.FLAGS
+
+
+def embeddings_with_init(vocab_size, embedding_dim, name):
+  """Creates and initializes the embedding tensors."""
+  return tf.get_variable(name=name,
+                         shape=[vocab_size, embedding_dim],
+                         initializer=tf.random_normal_initializer(
+                             stddev=math.sqrt(1.0 / embedding_dim)))
+
+
+def count_matrix_input(filenames, submatrix_rows, submatrix_cols):
+  """Reads submatrix shards from disk."""
+  filename_queue = tf.train.string_input_producer(filenames)
+  reader = tf.WholeFileReader()
+  _, serialized_example = reader.read(filename_queue)
+  features = tf.parse_single_example(
+      serialized_example,
+      features={
+          'global_row': tf.FixedLenFeature([submatrix_rows], dtype=tf.int64),
+          'global_col': tf.FixedLenFeature([submatrix_cols], dtype=tf.int64),
+          'sparse_local_row': tf.VarLenFeature(dtype=tf.int64),
+          'sparse_local_col': tf.VarLenFeature(dtype=tf.int64),
+          'sparse_value': tf.VarLenFeature(dtype=tf.float32)
+      })
+
+  global_row = features['global_row']
+  global_col = features['global_col']
+
+  sparse_local_row = features['sparse_local_row'].values
+  sparse_local_col = features['sparse_local_col'].values
+  sparse_count = features['sparse_value'].values
+
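+  # Scatter the sparse counts into a dense [submatrix_rows, submatrix_cols]
+  # matrix; cells with no observed co-occurrence become zero.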
+  sparse_indices = tf.concat(1, [tf.expand_dims(sparse_local_row, 1),
+                                 tf.expand_dims(sparse_local_col, 1)])
+  count = tf.sparse_to_dense(sparse_indices, [submatrix_rows, submatrix_cols],
+                             sparse_count)
+
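+  # Queue submatrices so that multiple reader threads can keep the trainer fed.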
+  queued_global_row, queued_global_col, queued_count = tf.train.batch(
+      [global_row, global_col, count],
+      batch_size=1,
+      num_threads=4,
+      capacity=32)
+
+  queued_global_row = tf.reshape(queued_global_row, [submatrix_rows])
+  queued_global_col = tf.reshape(queued_global_col, [submatrix_cols])
+  queued_count = tf.reshape(queued_count, [submatrix_rows, submatrix_cols])
+
+  return queued_global_row, queued_global_col, queued_count
+
+
+def read_marginals_file(filename):
+  """Reads text file with one number per line to an array."""
+  with open(filename) as lines:
+    return [float(line) for line in lines]
+
+
+def write_embedding_tensor_to_disk(vocab_path, output_path, sess, embedding):
+  """Writes tensor to output_path as tsv"""
+  # Fetch the embedding values from the model
+  embeddings = sess.run(embedding)
+
+  with open(output_path, 'w') as out_f:
+    with open(vocab_path) as vocab_f:
+      for index, word in enumerate(vocab_f):
+        word = word.strip()
+        embedding = embeddings[index]
+        out_f.write(word + '\t' + '\t'.join([str(x) for x in embedding]) + '\n')
+
+
+def write_embeddings_to_disk(config, model, sess):
+  """Writes row and column embeddings disk"""
+  # Row Embedding
+  row_vocab_path = config.input_base_path + '/row_vocab.txt'
+  row_embedding_output_path = config.output_base_path + '/row_embedding.tsv'
+  print 'Writing row embeddings to:', row_embedding_output_path
+  write_embedding_tensor_to_disk(row_vocab_path, row_embedding_output_path,
+                                 sess, model.row_embedding)
+
+  # Column Embedding
+  col_vocab_path = config.input_base_path + '/col_vocab.txt'
+  col_embedding_output_path = config.output_base_path + '/col_embedding.tsv'
+  print 'Writing column embeddings to:', col_embedding_output_path
+  write_embedding_tensor_to_disk(col_vocab_path, col_embedding_output_path,
+                                 sess, model.col_embedding)
+
+
+class SwivelModel(object):
+  """Small class to gather needed pieces from a Graph being built."""
+
+  def __init__(self, config):
+    """Construct graph for dmc."""
+    self._config = config
+
+    # Create paths to input data files
+    print 'Reading model from:', config.input_base_path
+    count_matrix_files = glob.glob(config.input_base_path + '/shard-*.pb')
+    row_sums_path = config.input_base_path + '/row_sums.txt'
+    col_sums_path = config.input_base_path + '/col_sums.txt'
+
+    # Read marginals
+    row_sums = read_marginals_file(row_sums_path)
+    col_sums = read_marginals_file(col_sums_path)
+
+    self.n_rows = len(row_sums)
+    self.n_cols = len(col_sums)
+    print 'Matrix dim: (%d,%d) SubMatrix dim: (%d,%d) ' % (
+        self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols)
+    self.n_submatrices = (self.n_rows * self.n_cols /
+                          (config.submatrix_rows * config.submatrix_cols))
+    print 'n_submatrices: %d' % (self.n_submatrices)
+
+    # ===== CREATE VARIABLES ======
+
+    with tf.device('/cpu:0'):
+      # embeddings
+      self.row_embedding = embeddings_with_init(
+          embedding_dim=config.embedding_size,
+          vocab_size=self.n_rows,
+          name='row_embedding')
+      self.col_embedding = embeddings_with_init(
+          embedding_dim=config.embedding_size,
+          vocab_size=self.n_cols,
+          name='col_embedding')
+      tf.histogram_summary('row_emb', self.row_embedding)
+      tf.histogram_summary('col_emb', self.col_embedding)
+
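+      # Initialize the biases to the log marginal counts; together with the
+      # log of the total count they turn raw co-occurrence counts into PMI in
+      # the training objective below.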
+      matrix_log_sum = math.log(np.sum(row_sums) + 1)
+      row_bias_init = [math.log(x + 1) for x in row_sums]
+      col_bias_init = [math.log(x + 1) for x in col_sums]
+      self.row_bias = tf.Variable(row_bias_init,
+                                  trainable=config.trainable_bias)
+      self.col_bias = tf.Variable(col_bias_init,
+                                  trainable=config.trainable_bias)
+      tf.histogram_summary('row_bias', self.row_bias)
+      tf.histogram_summary('col_bias', self.col_bias)
+
+    # ===== CREATE GRAPH =====
+
+    # Get input
+    with tf.device('/cpu:0'):
+      global_row, global_col, count = count_matrix_input(
+          count_matrix_files, config.submatrix_rows, config.submatrix_cols)
+
+      # Fetch embeddings.
+      selected_row_embedding = tf.nn.embedding_lookup(self.row_embedding,
+                                                      global_row)
+      selected_col_embedding = tf.nn.embedding_lookup(self.col_embedding,
+                                                      global_col)
+
+      # Fetch biases.
+      selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row)
+      selected_col_bias = tf.nn.embedding_lookup([self.col_bias], global_col)
+
+    # Multiply the row and column embeddings to generate predictions.
+    predictions = tf.matmul(
+        selected_row_embedding, selected_col_embedding, transpose_b=True)
+
+    # These binary masks separate zero from non-zero values.
+    count_is_nonzero = tf.to_float(tf.cast(count, tf.bool))
+    count_is_zero = 1 - tf.to_float(tf.cast(count, tf.bool))
+
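+    # For observed co-occurrences, the target is the PMI of the pair:
+    # log(count) - log(row marginal) - log(col marginal) + log(total count).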
+    objectives = count_is_nonzero * tf.log(count + 1e-30)
+    objectives -= tf.reshape(selected_row_bias, [config.submatrix_rows, 1])
+    objectives -= selected_col_bias
+    objectives += matrix_log_sum
+
+    err = predictions - objectives
+
+    # The confidence function scales the L2 loss based on the raw co-occurrence
+    # count.
+    l2_confidence = (config.confidence_base + config.confidence_scale * tf.pow(
+        count, config.confidence_exponent))
+
+    l2_loss = config.loss_multiplier * tf.reduce_sum(
+        0.5 * l2_confidence * err * err * count_is_nonzero)
+
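+    # For unobserved co-occurrences, penalize over-estimation of PMI with a
+    # "soft hinge" (softplus) loss.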
+    sigmoid_loss = config.loss_multiplier * tf.reduce_sum(
+        tf.nn.softplus(err) * count_is_zero)
+
+    self.loss = l2_loss + sigmoid_loss
+
+    tf.scalar_summary("l2_loss", l2_loss)
+    tf.scalar_summary("sigmoid_loss", sigmoid_loss)
+    tf.scalar_summary("loss", self.loss)
+
+    # Add optimizer.
+    self.global_step = tf.Variable(0, name='global_step')
+    opt = tf.train.AdagradOptimizer(config.learning_rate)
+    self.train_op = opt.minimize(self.loss, global_step=self.global_step)
+    self.saver = tf.train.Saver(sharded=True)
+
+
+def main(_):
+  # Create the output path.  If this fails, it really ought to fail
+  # now. :)
+  if not os.path.isdir(FLAGS.output_base_path):
+    os.makedirs(FLAGS.output_base_path)
+
+  # Create and run model
+  with tf.Graph().as_default():
+    model = SwivelModel(FLAGS)
+
+    # Create a session for running Ops on the Graph.
+    gpu_options = tf.GPUOptions(
+        per_process_gpu_memory_fraction=FLAGS.per_process_gpu_memory_fraction)
+    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
+
+    # Run the Op to initialize the variables.
+    sess.run(tf.initialize_all_variables())
+
+    # Start feeding input
+    coord = tf.train.Coordinator()
+    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+    # Calculate how many steps each thread should run
+    n_total_steps = int(FLAGS.num_epochs * model.n_rows * model.n_cols) / (
+        FLAGS.submatrix_rows * FLAGS.submatrix_cols)
+    n_steps_per_thread = n_total_steps / FLAGS.num_concurrent_steps
+    n_submatrices_to_train = model.n_submatrices * FLAGS.num_epochs
+    t0 = [time.time()]
+
+    def TrainingFn():
+      for _ in range(n_steps_per_thread):
+        _, global_step = sess.run([model.train_op, model.global_step])
+        n_steps_between_status_updates = 100
+        if (global_step % n_steps_between_status_updates) == 0:
+          elapsed = float(time.time() - t0[0])
+          print '%d/%d submatrices trained (%.1f%%), %.1f submatrices/sec' % (
+              global_step, n_submatrices_to_train,
+              100.0 * global_step / n_submatrices_to_train,
+              n_steps_between_status_updates / elapsed)
+          t0[0] = time.time()
+
+    # Start training threads
+    train_threads = []
+    for _ in range(FLAGS.num_concurrent_steps):
+      t = threading.Thread(target=TrainingFn)
+      train_threads.append(t)
+      t.start()
+
+    # Wait for threads to finish.
+    for t in train_threads:
+      t.join()
+
+    coord.request_stop()
+    coord.join(threads)
+
+    # Write out vectors
+    write_embeddings_to_disk(FLAGS, model, sess)
+
+    # Shutdown.
+    sess.close()
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 88 - 0
swivel/text2bin.py

@@ -0,0 +1,88 @@
+#!/usr/bin/env python
+#
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Converts vectors from text to a binary format for quicker manipulation.
+
+Usage:
+
+  text2bin.py -o <out> -v <vocab> vec1.txt [vec2.txt ...]
+
+Options:
+
+  -o <filename>, --output <filename>
+    The name of the file into which the binary vectors are written.
+
+  -v <filename>, --vocab <filename>
+    The name of the file into which the vocabulary is written.
+
+Description:
+
+This program merges one or more whitespace separated vector files into a single
+binary vector file that can be used by downstream evaluation tools in this
+directory ("wordsim.py" and "analogy").
+
+If more than one vector file is specified, then the files must be aligned
+row-wise (i.e., each line must correspond to the same embedding), and they must
+have the same number of columns (i.e., be the same dimension).
+
+"""
+
+from itertools import izip
+from getopt import GetoptError, getopt
+import os
+import struct
+import sys
+
+try:
+  opts, args = getopt(
+      sys.argv[1:], 'o:v:', ['output=', 'vocab='])
+except GetoptError, e:
+  print >> sys.stderr, e
+  sys.exit(2)
+
+opt_output = 'vecs.bin'
+opt_vocab = 'vocab.txt'
+for o, a in opts:
+  if o in ('-o', '--output'):
+    opt_output = a
+  if o in ('-v', '--vocab'):
+    opt_vocab = a
+
+def go(fhs):
+  fmt = None
+  with open(opt_vocab, 'w') as vocab_out:
+    with open(opt_output, 'w') as vecs_out:
+      for lines in izip(*fhs):
+        parts = [line.split() for line in lines]
+        token = parts[0][0]
+        if any(part[0] != token for part in parts[1:]):
+          raise IOError('vector files must be aligned')
+
+        print >> vocab_out, token
+
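+        # "Merging" the input files means summing their vectors element-wise;
+        # this is how the row and column embeddings can be folded into a
+        # single vector per token.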
+        vec = [sum(float(x) for x in xs) for xs in zip(*parts)[1:]]
+        if not fmt:
+          fmt = struct.Struct('%df' % len(vec))
+
+        vecs_out.write(fmt.pack(*vec))
+
+if args:
+  fhs = [open(filename) for filename in args]
+  go(fhs)
+  for fh in fhs:
+    fh.close()
+else:
+  go([sys.stdin])

+ 90 - 0
swivel/vecs.py

@@ -0,0 +1,90 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mmap
+import numpy as np
+import os
+import struct
+
+class Vecs(object):
+  def __init__(self, vocab_filename, rows_filename, cols_filename=None):
+    """Initializes the vectors from a text vocabulary and binary data."""
+    with open(vocab_filename, 'r') as lines:
+      self.vocab = [line.split()[0] for line in lines]
+      self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
+
+    n = len(self.vocab)
+
+    with open(rows_filename, 'r') as rows_fh:
+      rows_fh.seek(0, os.SEEK_END)
+      size = rows_fh.tell()
+
+      # Make sure that the file size seems reasonable.
+      if size % (4 * n) != 0:
+        raise IOError(
+            'unexpected file size for binary vector file %s' % rows_filename)
+
+      # Memory map the rows.
+      dim = size / (4 * n)
+      rows_mm = mmap.mmap(rows_fh.fileno(), 0, prot=mmap.PROT_READ)
+      rows = np.matrix(
+          np.frombuffer(rows_mm, dtype=np.float32).reshape(n, dim))
+
+      # If column vectors were specified, then open them and add them to the row
+      # vectors.
+      if cols_filename:
+        with open(cols_filename, 'r') as cols_fh:
+          cols_mm = mmap.mmap(cols_fh.fileno(), 0, prot=mmap.PROT_READ)
+          cols_fh.seek(0, os.SEEK_END)
+          if cols_fh.tell() != size:
+            raise IOError('row and column vector files have different sizes')
+
+          cols = np.matrix(
+              np.frombuffer(cols_mm, dtype=np.float32).reshape(n, dim))
+
+          rows += cols
+          cols_mm.close()
+
+      # Normalize so that dot products are just cosine similarity.
+      self.vecs = rows / np.linalg.norm(rows, axis=1).reshape(n, 1)
+      rows_mm.close()
+
+  def similarity(self, word1, word2):
+    """Computes the similarity of two tokens."""
+    idx1 = self.word_to_idx.get(word1)
+    idx2 = self.word_to_idx.get(word2)
+    if idx1 is None or idx2 is None:
+      return None
+
+    return float(self.vecs[idx1] * self.vecs[idx2].transpose())
+
+  def neighbors(self, query):
+    """Returns the nearest neighbors to the query (a word or vector)."""
+    if isinstance(query, basestring):
+      idx = self.word_to_idx.get(query)
+      if idx is None:
+        return None
+
+      query = self.vecs[idx]
+
+    neighbors = self.vecs * query.transpose()
+
+    return sorted(
+      zip(self.vocab, neighbors.flat),
+      key=lambda kv: kv[1], reverse=True)
+
+  def lookup(self, word):
+    """Returns the embedding for a token, or None if no embedding exists."""
+    idx = self.word_to_idx.get(word)
+    return None if idx is None else self.vecs[idx]
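+
+
+# Example usage (a minimal sketch; the file names are placeholders):
+#
+#   vecs = Vecs('vocab.txt', 'vecs.bin')
+#   print vecs.similarity('cat', 'dog')
+#   for word, score in vecs.neighbors('cat')[:10]:
+#     print word, score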

+ 90 - 0
swivel/wordsim.py

@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+#
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Computes Spearman's rho with respect to human judgements.
+
+Given a set of row (and potentially column) embeddings, this computes Spearman's
+rho between the rank ordering of predicted word similarity and human judgements.
+
+Usage:
+
+  wordsim.py --embeddings=<binvecs> --vocab=<vocab> eval1.tab eval2.tab ...
+
+Options:
+
+  --embeddings=<filename>: the vectors to test
+  --vocab=<filename>: the vocabulary file
+
+Evaluation files are assumed to be tab-separated files with exactly three
+columns.  The first two columns contain the words, and the third column contains
+the human judgement score.
+
+"""
+
+import scipy.stats
+import sys
+from getopt import GetoptError, getopt
+
+from vecs import Vecs
+
+try:
+  opts, args = getopt(sys.argv[1:], '', ['embeddings=', 'vocab='])
+except GetoptError, e:
+  print >> sys.stderr, e
+  sys.exit(2)
+
+opt_embeddings = None
+opt_vocab = None
+
+for o, a in opts:
+  if o == '--embeddings':
+    opt_embeddings = a
+  if o == '--vocab':
+    opt_vocab = a
+
+if not opt_vocab:
+  print >> sys.stderr, 'please specify a vocabulary file with "--vocab"'
+  sys.exit(2)
+
+if not opt_embeddings:
+  print >> sys.stderr, 'please specify the embeddings with "--embeddings"'
+  sys.exit(2)
+
+try:
+  vecs = Vecs(opt_vocab, opt_embeddings)
+except IOError, e:
+  print >> sys.stderr, e
+  sys.exit(1)
+
+def evaluate(lines):
+  acts, preds = [], []
+
+  for line in lines:
+    w1, w2, act = line.strip().split('\t')
+    pred = vecs.similarity(w1, w2)
+    if pred is None:
+      continue
+
+    acts.append(float(act))
+    preds.append(pred)
+
+  rho, _ = scipy.stats.spearmanr(acts, preds)
+  return rho
+
+for filename in args:
+  with open(filename, 'r') as lines:
+    print '%0.3f %s' % (evaluate(lines), filename)