fastprep.cc 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. /* -*- Mode: C++ -*- */
  2. /*
  3. * Copyright 2016 Google Inc. All Rights Reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. /*
  18. * This program starts with a text file (and optionally a vocabulary file) and
  19. * computes co-occurrence statistics. It emits output in a format that can be
  20. * consumed by the "swivel" program. It's functionally equivalent to "prep.py",
  21. * but works much more quickly.
  22. */
  #include <assert.h>
  #include <fcntl.h>
  #include <pthread.h>
  #include <stdio.h>
  #include <sys/mman.h>
  #include <sys/stat.h>
  #include <unistd.h>
  #include <algorithm>
  #include <cctype>
  #include <cerrno>
  #include <cstdlib>
  #include <cstring>
  #include <fstream>
  #include <iomanip>
  #include <iostream>
  #include <map>
  #include <memory>
  #include <string>
  #include <tuple>
  #include <unordered_map>
  #include <vector>
  #include "google/protobuf/io/zero_copy_stream_impl.h"
  #include "tensorflow/core/example/example.pb.h"
  #include "tensorflow/core/example/feature.pb.h"
  // Help text printed by --help. The defaults quoted below must match the
  // initial values of the corresponding variables in main().
  static const char usage[] = R"(
Prepares a corpus for processing by Swivel.
Usage:
prep --output_dir <output-dir> --input <text-file>
Options:
--input <filename>
The input text.
--output_dir <directory>
Specifies the output directory where the various Swivel data
files should be placed. This directory must exist.
--shard_size <int>
Specifies the shard size; default 4096.
--min_count <int>
The minimum number of times a word should appear to be included in the
generated vocabulary; default 5. (Ignored if --vocab is used.)
--max_vocab <int>
The maximum vocabulary size to generate from the input corpus; default
102,400. (Ignored if --vocab is used.)
--vocab <filename>
Use the specified unigram vocabulary instead of generating
it from the corpus.
--window_size <int>
Specifies the window size for computing co-occurrence stats;
default 10.
--num_threads <int>
The number of workers to calculate the co-occurrence matrix;
default 4.
)";
  // A single co-occurrence record as spooled to the temporary shard files:
  // the within-shard row and column offsets plus the count to add there.
  // Records are written with write() and read back through mmap() as raw
  // bytes, so the field order and types must not change.
  struct cooc_t {
    int row;    // Local row offset inside the shard.
    int col;    // Local column offset inside the shard.
    float cnt;  // Count contributed by this record.
  };

  // In-memory co-occurrence counts. The key packs two token IDs as
  // (row_id << 32) | col_id; the value is the accumulated count.
  using cooc_counts_t = std::map<long long, float>;
  // Retrieves the next word from the input stream, treating words as simply
  // being delimited by whitespace. Returns true if this is the end of a
  // "sentence", i.e. the word is followed by a newline (possibly after other
  // whitespace such as the '\r' of a CRLF line ending) or by end of input.
  //
  // `*word` receives the word; it is empty when there are no more words. The
  // stream is left positioned at the first character of the next word or just
  // after the newline that ended the sentence.
  bool NextWord(std::ifstream &fin, std::string* word) {
    // get() returns either an unsigned-char value widened to int or this
    // sentinel, so keeping characters in an int (not a char) makes every
    // std::isspace call well-defined, including for non-ASCII bytes.
    const int kEof = std::ifstream::traits_type::eof();
    word->clear();
    if (fin.eof()) return true;

    int ch;
    // Skip leading whitespace. Testing get()'s result rather than eof() also
    // stops on a stream that failed for some other reason, which would
    // otherwise loop forever.
    do {
      ch = fin.get();
    } while (ch != kEof && std::isspace(ch));
    if (ch == kEof) return true;

    // Read the next word.
    do {
      word->push_back(static_cast<char>(ch));
      ch = fin.get();
    } while (ch != kEof && !std::isspace(ch));
    if (ch == kEof || ch == '\n') return true;

    // Skip trailing whitespace. A newline found here still ends the sentence
    // (previously it was swallowed, merging lines that had trailing spaces or
    // CRLF endings).
    do {
      ch = fin.get();
      if (ch == '\n') return true;
    } while (ch != kEof && std::isspace(ch));
    if (ch == kEof) return true;

    // We consumed the first character of the next word; put it back.
    fin.unget();
    return false;
  }
  109. // Creates a vocabulary from the most frequent terms in the input file.
  110. std::vector<std::string> CreateVocabulary(const std::string input_filename,
  111. const int shard_size,
  112. const int min_vocab_count,
  113. const int max_vocab_size) {
  114. std::vector<std::string> vocab;
  115. // Count all the distinct tokens in the file. (XXX this will eventually
  116. // consume all memory and should be re-written to periodically trim the data.)
  117. std::unordered_map<std::string, long long> counts;
  118. std::ifstream fin(input_filename, std::ifstream::ate);
  119. if (!fin) {
  120. std::cerr << "couldn't read input file '" << input_filename << "'"
  121. << std::endl;
  122. return vocab;
  123. }
  124. const auto input_size = fin.tellg();
  125. fin.seekg(0);
  126. long long ntokens = 0;
  127. while (!fin.eof()) {
  128. std::string word;
  129. NextWord(fin, &word);
  130. counts[word] += 1;
  131. if (++ntokens % 1000000 == 0) {
  132. const float pct = 100.0 * static_cast<float>(fin.tellg()) / input_size;
  133. fprintf(stdout, "\rComputing vocabulary: %0.1f%% complete...", pct);
  134. std::flush(std::cout);
  135. }
  136. }
  137. std::cout << counts.size() << " distinct tokens" << std::endl;
  138. // Sort the vocabulary from most frequent to least frequent.
  139. std::vector<std::pair<std::string, long long>> buf;
  140. std::copy(counts.begin(), counts.end(), std::back_inserter(buf));
  141. std::sort(buf.begin(), buf.end(),
  142. [](const std::pair<std::string, long long> &a,
  143. const std::pair<std::string, long long> &b) {
  144. return b.second < a.second;
  145. });
  146. // Truncate to the maximum vocabulary size
  147. if (static_cast<int>(buf.size()) > max_vocab_size) buf.resize(max_vocab_size);
  148. if (buf.empty()) return vocab;
  149. // Eliminate rare tokens and truncate to a size modulo the shard size.
  150. int vocab_size = buf.size();
  151. while (vocab_size > 0 && buf[vocab_size - 1].second < min_vocab_count)
  152. --vocab_size;
  153. vocab_size -= vocab_size % shard_size;
  154. if (static_cast<int>(buf.size()) > vocab_size) buf.resize(vocab_size);
  155. // Copy out the tokens.
  156. for (const auto& pair : buf) vocab.push_back(pair.first);
  157. return vocab;
  158. }
  // Reads a unigram vocabulary file with one token per line; the i-th token
  // read gets ID i.
  //
  // If a line contains a tab, only the text before the first tab is used as
  // the token, so "token<TAB>count" files are accepted. The original code kept
  // the text *from* the tab onward instead. A trailing '\r' (CRLF files) is
  // stripped, and blank lines are skipped so that no empty token is ever
  // registered. Returns an empty vector if the file can't be opened.
  std::vector<std::string> ReadVocabulary(const std::string vocab_filename) {
    std::vector<std::string> vocab;
    std::ifstream fin(vocab_filename);
    if (!fin) {
      std::cerr << "couldn't read vocabulary file '" << vocab_filename << "'"
                << std::endl;
      return vocab;
    }
    for (std::string line; std::getline(fin, line);) {
      if (!line.empty() && line.back() == '\r') line.pop_back();
      const auto tab = line.find('\t');
      if (tab != std::string::npos) line.erase(tab);
      if (line.empty()) continue;
      vocab.push_back(line);
    }
    return vocab;
  }
  // Writes `vocab`, one token per line in ID order, to both row_vocab.txt and
  // col_vocab.txt inside `output_dirname`. The same vocabulary is used for the
  // rows and the columns of the co-occurrence matrix.
  //
  // Lines end with '\n' rather than std::endl, so the stream is not flushed
  // once per token. A failure to open or write a file is reported on stderr.
  void WriteVocabulary(const std::vector<std::string> &vocab,
                       const std::string &output_dirname) {
    for (const char *filename : {"row_vocab.txt", "col_vocab.txt"}) {
      const std::string path = output_dirname + "/" + filename;
      std::ofstream fout(path);
      for (const auto &token : vocab) fout << token << '\n';
      fout.flush();
      if (!fout) std::cerr << "error writing '" << path << "'" << std::endl;
    }
  }
  177. // Manages accumulation of co-occurrence data into temporary disk buffer files.
  178. class CoocBuffer {
  179. public:
  180. CoocBuffer(const std::string &output_dirname, const int num_shards,
  181. const int shard_size);
  182. // Accumulate the co-occurrence counts to the buffer.
  183. void AccumulateCoocs(const cooc_counts_t &coocs);
  184. // Read the buffer to produce shard files.
  185. void WriteShards();
  186. protected:
  187. // The output directory. Also used for temporary buffer files.
  188. const std::string output_dirname_;
  189. // The number of row/column shards.
  190. const int num_shards_;
  191. // The number of elements per shard.
  192. const int shard_size_;
  193. // Parallel arrays of temporary file paths and file descriptors.
  194. std::vector<std::string> paths_;
  195. std::vector<int> fds_;
  196. // Ensures that only one buffer file is getting written at a time.
  197. pthread_mutex_t writer_mutex_;
  198. };
  199. CoocBuffer::CoocBuffer(const std::string &output_dirname, const int num_shards,
  200. const int shard_size)
  201. : output_dirname_(output_dirname),
  202. num_shards_(num_shards),
  203. shard_size_(shard_size),
  204. writer_mutex_(PTHREAD_MUTEX_INITIALIZER) {
  205. for (int row = 0; row < num_shards_; ++row) {
  206. for (int col = 0; col < num_shards_; ++col) {
  207. char filename[256];
  208. sprintf(filename, "shard-%03d-%03d.tmp", row, col);
  209. std::string path = output_dirname + "/" + filename;
  210. int fd = open(path.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0666);
  211. assert(fd > 0);
  212. paths_.push_back(path);
  213. fds_.push_back(fd);
  214. }
  215. }
  216. }
  217. void CoocBuffer::AccumulateCoocs(const cooc_counts_t &coocs) {
  218. std::vector<std::vector<cooc_t>> bufs(fds_.size());
  219. for (const auto &cooc : coocs) {
  220. const int row_id = cooc.first >> 32;
  221. const int col_id = cooc.first & 0xffffffff;
  222. const float cnt = cooc.second;
  223. const int row_shard = row_id % num_shards_;
  224. const int row_off = row_id / num_shards_;
  225. const int col_shard = col_id % num_shards_;
  226. const int col_off = col_id / num_shards_;
  227. const int top_shard_idx = row_shard * num_shards_ + col_shard;
  228. bufs[top_shard_idx].push_back(cooc_t{row_off, col_off, cnt});
  229. const int bot_shard_idx = col_shard * num_shards_ + row_shard;
  230. bufs[bot_shard_idx].push_back(cooc_t{col_off, row_off, cnt});
  231. }
  232. // XXX TODO: lock
  233. for (int i = 0; i < static_cast<int>(fds_.size()); ++i) {
  234. int rv = pthread_mutex_lock(&writer_mutex_);
  235. assert(rv == 0);
  236. const int nbytes = bufs[i].size() * sizeof(cooc_t);
  237. int nwritten = write(fds_[i], bufs[i].data(), nbytes);
  238. assert(nwritten == nbytes);
  239. pthread_mutex_unlock(&writer_mutex_);
  240. }
  241. }
  242. void CoocBuffer::WriteShards() {
  243. for (int shard = 0; shard < static_cast<int>(fds_.size()); ++shard) {
  244. const int row_shard = shard / num_shards_;
  245. const int col_shard = shard % num_shards_;
  246. std::cout << "\rwriting shard " << (shard + 1) << "/"
  247. << (num_shards_ * num_shards_);
  248. std::flush(std::cout);
  249. // Construct the tf::Example proto. First, we add the global rows and
  250. // column that are present in the shard.
  251. tensorflow::Example example;
  252. auto &feature = *example.mutable_features()->mutable_feature();
  253. auto global_row = feature["global_row"].mutable_int64_list();
  254. auto global_col = feature["global_col"].mutable_int64_list();
  255. for (int i = 0; i < shard_size_; ++i) {
  256. global_row->add_value(row_shard + i * num_shards_);
  257. global_col->add_value(col_shard + i * num_shards_);
  258. }
  259. // Next we add co-occurrences as a sparse representation. Map the
  260. // co-occurrence counts that we've spooled off to disk: these are in
  261. // arbitrary order and may contain duplicates.
  262. const off_t nbytes = lseek(fds_[shard], 0, SEEK_END);
  263. cooc_t *coocs = static_cast<cooc_t*>(
  264. mmap(0, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, fds_[shard], 0));
  265. const int ncoocs = nbytes / sizeof(cooc_t);
  266. cooc_t* cur = coocs;
  267. cooc_t* end = coocs + ncoocs;
  268. auto sparse_value = feature["sparse_value"].mutable_float_list();
  269. auto sparse_local_row = feature["sparse_local_row"].mutable_int64_list();
  270. auto sparse_local_col = feature["sparse_local_col"].mutable_int64_list();
  271. std::sort(cur, end, [](const cooc_t &a, const cooc_t &b) {
  272. return a.row < b.row || (a.row == b.row && a.col < b.col);
  273. });
  274. // Accumulate the counts into the protocol buffer.
  275. int last_row = -1, last_col = -1;
  276. float count = 0;
  277. for (; cur != end; ++cur) {
  278. if (cur->row != last_row || cur->col != last_col) {
  279. if (last_row >= 0 && last_col >= 0) {
  280. sparse_local_row->add_value(last_row);
  281. sparse_local_col->add_value(last_col);
  282. sparse_value->add_value(count);
  283. }
  284. last_row = cur->row;
  285. last_col = cur->col;
  286. count = 0;
  287. }
  288. count += cur->cnt;
  289. }
  290. if (last_row >= 0 && last_col >= 0) {
  291. sparse_local_row->add_value(last_row);
  292. sparse_local_col->add_value(last_col);
  293. sparse_value->add_value(count);
  294. }
  295. munmap(coocs, nbytes);
  296. close(fds_[shard]);
  297. // Write the protocol buffer as a binary blob to disk.
  298. char filename[256];
  299. snprintf(filename, sizeof(filename), "shard-%03d-%03d.pb", row_shard,
  300. col_shard);
  301. const std::string path = output_dirname_ + "/" + filename;
  302. int fd = open(path.c_str(), O_WRONLY | O_TRUNC | O_CREAT, 0666);
  303. assert(fd != -1);
  304. google::protobuf::io::FileOutputStream fout(fd);
  305. example.SerializeToZeroCopyStream(&fout);
  306. fout.Close();
  307. // Remove the temporary file.
  308. unlink(paths_[shard].c_str());
  309. }
  310. std::cout << std::endl;
  311. }
  312. // Counts the co-occurrences in part of the file.
  313. class CoocCounter {
  314. public:
  315. CoocCounter(const std::string &input_filename, const off_t start,
  316. const off_t end, const int window_size,
  317. const std::unordered_map<std::string, int> &token_to_id_map,
  318. CoocBuffer *coocbuf)
  319. : fin_(input_filename, std::ifstream::ate),
  320. start_(start),
  321. end_(end),
  322. window_size_(window_size),
  323. token_to_id_map_(token_to_id_map),
  324. coocbuf_(coocbuf),
  325. marginals_(token_to_id_map.size()) {}
  326. // PTthreads-friendly thunk to Count.
  327. static void* Run(void* param) {
  328. CoocCounter* self = static_cast<CoocCounter*>(param);
  329. self->Count();
  330. return nullptr;
  331. }
  332. // Counts the co-occurrences.
  333. void Count();
  334. const std::vector<double>& Marginals() const { return marginals_; }
  335. protected:
  336. // The input stream.
  337. std::ifstream fin_;
  338. // The range of the file to which this counter should attend.
  339. const off_t start_;
  340. const off_t end_;
  341. // The window size for computing co-occurrences.
  342. const int window_size_;
  343. // A reference to the mapping from tokens to IDs.
  344. const std::unordered_map<std::string, int> &token_to_id_map_;
  345. // The buffer into which counts are to be accumulated.
  346. CoocBuffer* coocbuf_;
  347. // The marginal counts accumulated by this counter.
  348. std::vector<double> marginals_;
  349. };
  350. void CoocCounter::Count() {
  351. const int max_coocs_size = 16 * 1024 * 1024;
  352. // A buffer of co-occurrence counts that we'll periodically sort into
  353. // shards.
  354. cooc_counts_t coocs;
  355. fin_.seekg(start_);
  356. int nlines = 0;
  357. for (off_t filepos = start_; filepos < end_ && !fin_.eof(); filepos = fin_.tellg()) {
  358. // Buffer a single sentence.
  359. std::vector<int> sentence;
  360. bool eos;
  361. do {
  362. std::string word;
  363. eos = NextWord(fin_, &word);
  364. auto it = token_to_id_map_.find(word);
  365. if (it != token_to_id_map_.end()) sentence.push_back(it->second);
  366. } while (!eos);
  367. // Generate the co-occurrences for the sentence.
  368. for (int pos = 0; pos < static_cast<int>(sentence.size()); ++pos) {
  369. const int left_id = sentence[pos];
  370. const int window_extent =
  371. std::min(static_cast<int>(sentence.size()) - pos, 1 + window_size_);
  372. for (int off = 1; off < window_extent; ++off) {
  373. const int right_id = sentence[pos + off];
  374. const double count = 1.0 / static_cast<double>(off);
  375. const long long lo = std::min(left_id, right_id);
  376. const long long hi = std::max(left_id, right_id);
  377. const long long key = (hi << 32) | lo;
  378. coocs[key] += count;
  379. marginals_[left_id] += count;
  380. marginals_[right_id] += count;
  381. }
  382. marginals_[left_id] += 1.0;
  383. const long long key = (static_cast<long long>(left_id) << 32) |
  384. static_cast<long long>(left_id);
  385. coocs[key] += 0.5;
  386. }
  387. // Periodically flush the co-occurrences to disk.
  388. if (coocs.size() > max_coocs_size) {
  389. coocbuf_->AccumulateCoocs(coocs);
  390. coocs.clear();
  391. }
  392. if (start_ == 0 && ++nlines % 1000 == 0) {
  393. const double pct = 100.0 * filepos / end_;
  394. fprintf(stdout, "\rComputing co-occurrences: %0.1f%% complete...", pct);
  395. std::flush(std::cout);
  396. }
  397. }
  398. // Accumulate anything we haven't flushed yet.
  399. coocbuf_->AccumulateCoocs(coocs);
  400. if (start_ == 0) std::cout << "done." << std::endl;
  401. }
  // Writes `marginals`, one value per line in token-ID order and in fixed
  // notation, to both row_sums.txt and col_sums.txt inside `output_dirname`.
  // The same sums are used for the rows and the columns.
  //
  // Lines end with '\n' rather than std::endl, so the stream is not flushed
  // once per value. A failure to open or write a file is reported on stderr.
  void WriteMarginals(const std::vector<double> &marginals,
                      const std::string &output_dirname) {
    for (const char *filename : {"row_sums.txt", "col_sums.txt"}) {
      const std::string path = output_dirname + "/" + filename;
      std::ofstream fout(path);
      fout.setf(std::ios::fixed);
      for (double sum : marginals) fout << sum << '\n';
      fout.flush();
      if (!fout) std::cerr << "error writing '" << path << "'" << std::endl;
    }
  }
  410. int main(int argc, char *argv[]) {
  411. std::string input_filename;
  412. std::string vocab_filename;
  413. std::string output_dirname;
  414. bool generate_vocab = true;
  415. int max_vocab_size = 100 * 1024;
  416. int min_vocab_count = 5;
  417. int window_size = 10;
  418. int shard_size = 4096;
  419. int num_threads = 4;
  420. for (int i = 1; i < argc; ++i) {
  421. std::string arg(argv[i]);
  422. if (arg == "--vocab") {
  423. if (++i >= argc) goto argmissing;
  424. generate_vocab = false;
  425. vocab_filename = argv[i];
  426. } else if (arg == "--max_vocab") {
  427. if (++i >= argc) goto argmissing;
  428. if ((max_vocab_size = atoi(argv[i])) <= 0) goto badarg;
  429. } else if (arg == "--min_count") {
  430. if (++i >= argc) goto argmissing;
  431. if ((min_vocab_count = atoi(argv[i])) <= 0) goto badarg;
  432. } else if (arg == "--window_size") {
  433. if (++i >= argc) goto argmissing;
  434. if ((window_size = atoi(argv[i])) <= 0) goto badarg;
  435. } else if (arg == "--input") {
  436. if (++i >= argc) goto argmissing;
  437. input_filename = argv[i];
  438. } else if (arg == "--output_dir") {
  439. if (++i >= argc) goto argmissing;
  440. output_dirname = argv[i];
  441. } else if (arg == "--shard_size") {
  442. if (++i >= argc) goto argmissing;
  443. shard_size = atoi(argv[i]);
  444. } else if (arg == "--num_threads") {
  445. if (++i >= argc) goto argmissing;
  446. num_threads = atoi(argv[i]);
  447. } else if (arg == "--help") {
  448. std::cout << usage << std::endl;
  449. return 0;
  450. } else {
  451. std::cerr << "unknown arg '" << arg << "'; try --help?" << std::endl;
  452. return 2;
  453. }
  454. continue;
  455. badarg:
  456. std::cerr << "'" << argv[i] << "' is not a valid value for '" << arg
  457. << "'; try --help?" << std::endl;
  458. return 2;
  459. argmissing:
  460. std::cerr << arg << " requires an argument; try --help?" << std::endl;
  461. }
  462. if (input_filename.empty()) {
  463. std::cerr << "please specify the input text with '--input'; try --help?"
  464. << std::endl;
  465. return 2;
  466. }
  467. if (output_dirname.empty()) {
  468. std::cerr << "please specify the output directory with '--output_dir'"
  469. << std::endl;
  470. return 2;
  471. }
  472. struct stat sb;
  473. if (lstat(output_dirname.c_str(), &sb) != 0 || !S_ISDIR(sb.st_mode)) {
  474. std::cerr << "output directory '" << output_dirname
  475. << "' does not exist of is not a directory." << std::endl;
  476. return 1;
  477. }
  478. if (lstat(input_filename.c_str(), &sb) != 0 || !S_ISREG(sb.st_mode)) {
  479. std::cerr << "input file '" << input_filename
  480. << "' does not exist or is not a file." << std::endl;
  481. return 1;
  482. }
  483. // The total size of the input.
  484. const off_t input_size = sb.st_size;
  485. const std::vector<std::string> vocab =
  486. generate_vocab ? CreateVocabulary(input_filename, shard_size,
  487. min_vocab_count, max_vocab_size)
  488. : ReadVocabulary(vocab_filename);
  489. if (!vocab.size()) {
  490. std::cerr << "Empty vocabulary." << std::endl;
  491. return 1;
  492. }
  493. std::cout << "Generating Swivel co-occurrence data into " << output_dirname
  494. << std::endl;
  495. std::cout << "Shard size: " << shard_size << "x" << shard_size << std::endl;
  496. std::cout << "Vocab size: " << vocab.size() << std::endl;
  497. // Write the vocabulary files into the output directory.
  498. WriteVocabulary(vocab, output_dirname);
  499. const int num_shards = vocab.size() / shard_size;
  500. CoocBuffer coocbuf(output_dirname, num_shards, shard_size);
  501. // Build a mapping from the token to its position in the vocabulary file.
  502. std::unordered_map<std::string, int> token_to_id_map;
  503. for (int i = 0; i < static_cast<int>(vocab.size()); ++i)
  504. token_to_id_map[vocab[i]] = i;
  505. // Compute the co-occurrences
  506. std::vector<pthread_t> threads;
  507. std::vector<CoocCounter*> counters;
  508. const off_t nbytes_per_thread = input_size / num_threads;
  509. std::cout << "Running " << num_threads << " threads, each on "
  510. << nbytes_per_thread << " bytes" << std::endl;
  511. pthread_attr_t attr;
  512. if (pthread_attr_init(&attr) != 0) {
  513. std::cerr << "unable to initalize pthreads" << std::endl;
  514. return 1;
  515. }
  516. for (int i = 0; i < num_threads; ++i) {
  517. // We could make this smarter and look around for newlines. But
  518. // realistically that's not going to change things much.
  519. const off_t start = i * nbytes_per_thread;
  520. const off_t end =
  521. i < num_threads - 1 ? (i + 1) * nbytes_per_thread : input_size;
  522. CoocCounter *counter = new CoocCounter(
  523. input_filename, start, end, window_size, token_to_id_map, &coocbuf);
  524. counters.push_back(counter);
  525. pthread_t thread;
  526. pthread_create(&thread, &attr, CoocCounter::Run, counter);
  527. threads.push_back(thread);
  528. }
  529. // Wait for threads to finish and collect marginals.
  530. std::vector<double> marginals(vocab.size());
  531. for (int i = 0; i < num_threads; ++i) {
  532. pthread_join(threads[i], 0);
  533. const std::vector<double>& counter_marginals = counters[i]->Marginals();
  534. for (int j = 0; j < static_cast<int>(vocab.size()); ++j)
  535. marginals[j] += counter_marginals[j];
  536. delete counters[i];
  537. }
  538. std::cout << "writing marginals..." << std::endl;
  539. WriteMarginals(marginals, output_dirname);
  540. std::cout << "writing shards..." << std::endl;
  541. coocbuf.WriteShards();
  542. return 0;
  543. }