analogy.cc 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. /* -*- Mode: C++ -*- */
  2. /*
  3. * Copyright 2016 Google Inc. All Rights Reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. /*
  18. * Computes embedding performance on analogy tasks. Accepts as input one or
  19. * more files containing four words per line (A B C D), and determines if:
  20. *
  21. * vec(C) - vec(A) + vec(B) ~= vec(D)
  22. *
  23. * Cosine distance in the embedding space is used to retrieve neighbors. Any
  24. * missing vocabulary items are scored as losses.
  25. */
  #include <errno.h>
  #include <fcntl.h>
  #include <math.h>
  #include <pthread.h>
  #include <stdio.h>
  #include <stdlib.h>
  #include <string.h>
  #include <sys/stat.h>
  #include <sys/types.h>
  #include <unistd.h>

  #include <fstream>
  #include <iostream>
  #include <limits>
  #include <string>
  #include <tuple>
  #include <unordered_map>
  #include <vector>
  // Help text for the command line. main() prints it when the program is run
  // without arguments and when --help is given.
  static const char usage[] = R"(
Performs analogy testing of embedding vectors.
Usage:
analogy --embeddings <embeddings> --vocab <vocab> eval1.tab ...
Options:
--embeddings <filename>
The file containing the binary embedding vectors to evaluate.
--vocab <filename>
The vocabulary file corresponding to the embedding vectors.
--nthreads <integer>
The number of evaluation threads to run (default: 8)
)";
  // Reads the vocabulary file into a map from token to vector index.
  //
  // Each line of the file contributes one entry; the zero-based line number is
  // used as the index. If a line contains a tab, only the text before the first
  // tab is used as the token (assumes the token is the first tab-separated
  // field and anything after it is auxiliary data -- confirm against the tool
  // that writes the vocabulary). If a token occurs on several lines, the last
  // occurrence wins.
  //
  // Returns an empty map when the file cannot be opened or has no lines.
  static std::unordered_map<std::string, int> ReadVocab(
      const std::string& vocab_filename) {
    std::unordered_map<std::string, int> vocab;
    std::ifstream fin(vocab_filename);

    int index = 0;
    for (std::string line; std::getline(fin, line); ++index) {
      // Drop the tab and everything after it. Keeping the tail instead would
      // produce keys that start with '\t', which can never match a query word
      // read with operator>>.
      const auto tab = line.find('\t');
      if (tab != std::string::npos) line.erase(tab);

      vocab[line] = index;
    }

    return vocab;
  }
  // An analogy query: "A is to B as C is to D", stored as the vocabulary
  // indices of (A, B, C, D).
  using AnalogyQuery = std::tuple<int, int, int, int>;

  // Reads analogy queries from `filename`.
  //
  // The file is consumed as a stream of whitespace-separated words, four words
  // per query. A query is returned only if all four of its words are present in
  // `vocab`; queries with unknown words are dropped from the result but still
  // counted in `*total`, so that the caller scores them as losses.
  //
  // Parameters:
  //   filename  path of the query file
  //   vocab     map from token to embedding index
  //   total     out: number of complete four-word queries found in the file,
  //             whether or not their words are in `vocab`
  //
  // If the input ends in the middle of a query, a diagnostic is written to
  // stderr and reading stops; the incomplete query is not counted. A file that
  // cannot be opened yields no queries and a total of zero.
  std::vector<AnalogyQuery> ReadQueries(
      const std::string &filename,
      const std::unordered_map<std::string, int> &vocab, int *total) {
    std::ifstream fin(filename);
    std::vector<AnalogyQuery> queries;

    // Number of complete queries seen so far. This, and not the number of read
    // attempts, is what the caller must divide by.
    int nqueries = 0;

    while (true) {
      // Read the four words.
      std::string words[4];
      int nread = 0;
      for (std::string &word : words) {
        if (!(fin >> word)) break;
        ++nread;
      }

      if (nread == 0) break;  // Clean end of input.

      if (nread < 4) {
        std::cerr << "expected four words at line " << nqueries + 1 << std::endl;
        break;
      }

      ++nqueries;

      // Look up each word's index.
      int ixs[4];
      bool all_known = true;
      for (int i = 0; i < 4; ++i) {
        const auto it = vocab.find(words[i]);
        if (it == vocab.end()) {
          all_known = false;
          break;
        }
        ixs[i] = it->second;
      }

      // If we don't have all the words, the query only contributes to the
      // total, i.e. it is a loss.
      if (all_known) queries.emplace_back(ixs[0], ixs[1], ixs[2], ixs[3]);
    }

    if (total != nullptr) *total = nqueries;
    return queries;
  }
  102. // A thread that evaluates some fraction of the analogies.
  103. class AnalogyEvaluator {
  104. public:
  105. // Creates a new Analogy evaluator for a range of analogy queries.
  106. AnalogyEvaluator(std::vector<AnalogyQuery>::const_iterator begin,
  107. std::vector<AnalogyQuery>::const_iterator end,
  108. const float *embeddings, const int num_embeddings,
  109. const int dim)
  110. : begin_(begin),
  111. end_(end),
  112. embeddings_(embeddings),
  113. num_embeddings_(num_embeddings),
  114. dim_(dim) {}
  115. // A thunk for pthreads.
  116. static void* Run(void *param) {
  117. AnalogyEvaluator *self = static_cast<AnalogyEvaluator*>(param);
  118. self->Evaluate();
  119. return nullptr;
  120. }
  121. // Evaluates the analogies.
  122. void Evaluate();
  123. // Returns the number of correct analogies after evaluation is complete.
  124. int GetNumCorrect() const { return correct_; }
  125. protected:
  126. // The beginning of the range of queries to consider.
  127. std::vector<AnalogyQuery>::const_iterator begin_;
  128. // The end of the range of queries to consider.
  129. std::vector<AnalogyQuery>::const_iterator end_;
  130. // The raw embedding vectors.
  131. const float *embeddings_;
  132. // The number of embedding vectors.
  133. const int num_embeddings_;
  134. // The embedding vector dimensionality.
  135. const int dim_;
  136. // The number of correct analogies.
  137. int correct_;
  138. };
  139. void AnalogyEvaluator::Evaluate() {
  140. float* sum = new float[dim_];
  141. correct_ = 0;
  142. for (auto query = begin_; query < end_; ++query) {
  143. const float* vec;
  144. int a, b, c, d;
  145. std::tie(a, b, c, d) = *query;
  146. // Compute C - A + B.
  147. vec = embeddings_ + dim_ * c;
  148. for (int i = 0; i < dim_; ++i) sum[i] = vec[i];
  149. vec = embeddings_ + dim_ * a;
  150. for (int i = 0; i < dim_; ++i) sum[i] -= vec[i];
  151. vec = embeddings_ + dim_ * b;
  152. for (int i = 0; i < dim_; ++i) sum[i] += vec[i];
  153. // Find the nearest neighbor that isn't one of the query words.
  154. int best_ix = -1;
  155. float best_dot = -1.0;
  156. for (int i = 0; i < num_embeddings_; ++i) {
  157. if (i == a || i == b || i == c) continue;
  158. vec = embeddings_ + dim_ * i;
  159. float dot = 0;
  160. for (int j = 0; j < dim_; ++j) dot += vec[j] * sum[j];
  161. if (dot > best_dot) {
  162. best_ix = i;
  163. best_dot = dot;
  164. }
  165. }
  166. // The fourth word is the answer; did we get it right?
  167. if (best_ix == d) ++correct_;
  168. }
  169. delete[] sum;
  170. }
  171. int main(int argc, char *argv[]) {
  172. if (argc <= 1) {
  173. printf(usage);
  174. return 2;
  175. }
  176. std::string embeddings_filename, vocab_filename;
  177. int nthreads = 8;
  178. std::vector<std::string> input_filenames;
  179. std::vector<std::tuple<int, int, int, int>> queries;
  180. for (int i = 1; i < argc; ++i) {
  181. std::string arg = argv[i];
  182. if (arg == "--embeddings") {
  183. if (++i >= argc) goto argmissing;
  184. embeddings_filename = argv[i];
  185. } else if (arg == "--vocab") {
  186. if (++i >= argc) goto argmissing;
  187. vocab_filename = argv[i];
  188. } else if (arg == "--nthreads") {
  189. if (++i >= argc) goto argmissing;
  190. if ((nthreads = atoi(argv[i])) <= 0) goto badarg;
  191. } else if (arg == "--help") {
  192. std::cout << usage << std::endl;
  193. return 0;
  194. } else if (arg[0] == '-') {
  195. std::cerr << "unknown option: '" << arg << "'" << std::endl;
  196. return 2;
  197. } else {
  198. input_filenames.push_back(arg);
  199. }
  200. continue;
  201. argmissing:
  202. std::cerr << "missing value for '" << argv[i - 1] << "' (--help for help)"
  203. << std::endl;
  204. return 2;
  205. badarg:
  206. std::cerr << "invalid value '" << argv[i] << "' for '" << argv[i - 1]
  207. << "' (--help for help)" << std::endl;
  208. return 2;
  209. }
  210. // Read the vocabulary.
  211. std::unordered_map<std::string, int> vocab = ReadVocab(vocab_filename);
  212. if (!vocab.size()) {
  213. std::cerr << "unable to read vocabulary file '" << vocab_filename << "'"
  214. << std::endl;
  215. return 1;
  216. }
  217. const int n = vocab.size();
  218. // Read the vectors.
  219. int fd;
  220. if ((fd = open(embeddings_filename.c_str(), O_RDONLY)) < 0) {
  221. std::cerr << "unable to open embeddings file '" << embeddings_filename
  222. << "'" << std::endl;
  223. return 1;
  224. }
  225. off_t nbytes = lseek(fd, 0, SEEK_END);
  226. if (nbytes == -1) {
  227. std::cerr << "unable to determine file size for '" << embeddings_filename
  228. << "'" << std::endl;
  229. return 1;
  230. }
  231. if (nbytes % (sizeof(float) * n) != 0) {
  232. std::cerr << "'" << embeddings_filename
  233. << "' has a strange file size; expected it to be "
  234. "a multiple of the vocabulary size"
  235. << std::endl;
  236. return 1;
  237. }
  238. const int dim = nbytes / (sizeof(float) * n);
  239. float *embeddings = static_cast<float *>(malloc(nbytes));
  240. lseek(fd, 0, SEEK_SET);
  241. if (read(fd, embeddings, nbytes) < nbytes) {
  242. std::cerr << "unable to read embeddings from " << embeddings_filename
  243. << std::endl;
  244. return 1;
  245. }
  246. close(fd);
  247. /* Normalize the vectors. */
  248. for (int i = 0; i < n; ++i) {
  249. float *vec = embeddings + dim * i;
  250. float norm = 0;
  251. for (int j = 0; j < dim; ++j) norm += vec[j] * vec[j];
  252. norm = sqrt(norm);
  253. for (int j = 0; j < dim; ++j) vec[j] /= norm;
  254. }
  255. pthread_attr_t attr;
  256. if (pthread_attr_init(&attr) != 0) {
  257. std::cerr << "unable to initalize pthreads" << std::endl;
  258. return 1;
  259. }
  260. /* Read each input file. */
  261. for (const auto filename : input_filenames) {
  262. int total = 0;
  263. std::vector<AnalogyQuery> queries =
  264. ReadQueries(filename.c_str(), vocab, &total);
  265. const int queries_per_thread = queries.size() / nthreads;
  266. std::vector<AnalogyEvaluator*> evaluators;
  267. std::vector<pthread_t> threads;
  268. for (int i = 0; i < nthreads; ++i) {
  269. auto begin = queries.begin() + i * queries_per_thread;
  270. auto end = (i + 1 < nthreads)
  271. ? queries.begin() + (i + 1) * queries_per_thread
  272. : queries.end();
  273. AnalogyEvaluator *evaluator =
  274. new AnalogyEvaluator(begin, end, embeddings, n, dim);
  275. pthread_t thread;
  276. pthread_create(&thread, &attr, AnalogyEvaluator::Run, evaluator);
  277. evaluators.push_back(evaluator);
  278. threads.push_back(thread);
  279. }
  280. for (auto &thread : threads) pthread_join(thread, 0);
  281. int correct = 0;
  282. for (const AnalogyEvaluator* evaluator : evaluators) {
  283. correct += evaluator->GetNumCorrect();
  284. delete evaluator;
  285. }
  286. printf("%0.3f %s\n", static_cast<float>(correct) / total, filename.c_str());
  287. }
  288. return 0;
  289. }