fastprep.cc 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. /* -*- Mode: C++ -*- */
  2. /*
  3. * Copyright 2016 Google Inc. All Rights Reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. /*
  18. * This program starts with a text file (and optionally a vocabulary file) and
  19. * computes co-occurrence statistics. It emits output in a format that can be
  20. * consumed by the "swivel" program. It's functionally equivalent to "prep.py",
  21. * but works much more quickly.
  22. */
  #include <assert.h>
  #include <fcntl.h>
  #include <stdio.h>
  #include <sys/mman.h>
  #include <sys/stat.h>
  #include <unistd.h>

  #include <algorithm>
  #include <cctype>
  #include <cerrno>
  #include <cstdlib>
  #include <fstream>
  #include <iomanip>
  #include <iostream>
  #include <iterator>
  #include <map>
  #include <memory>
  #include <mutex>
  #include <string>
  #include <thread>
  #include <tuple>
  #include <unordered_map>
  #include <utility>
  #include <vector>

  #include "google/protobuf/io/zero_copy_stream_impl.h"
  #include "tensorflow/core/example/example.pb.h"
  #include "tensorflow/core/example/feature.pb.h"
  // Help text printed by main() in response to --help. The documented defaults
  // mirror the initial values of the corresponding variables in main().
  // NOTE(review): the text says the output directory "must exist", but main()
  // attempts mkdir() when it is missing.
  43. static const char usage[] = R"(
  44. Prepares a corpus for processing by Swivel.
  45. Usage:
  46. prep --output_dir <output-dir> --input <text-file>
  47. Options:
  48. --input <filename>
  49. The input text.
  50. --output_dir <directory>
  51. Specifies the output directory where the various Swivel data
  52. files should be placed. This directory must exist.
  53. --shard_size <int>
  54. Specifies the shard size; default 4096.
  55. --min_count <int>
  56. The minimum number of times a word should appear to be included in the
  57. generated vocabulary; default 5. (Ignored if --vocab is used.)
  58. --max_vocab <int>
  59. The maximum vocabulary size to generate from the input corpus; default
  60. 102,400. (Ignored if --vocab is used.)
  61. --vocab <filename>
  62. Use the specified unigram vocabulary instead of generating
  63. it from the corpus.
  64. --window_size <int>
  65. Specifies the window size for computing co-occurrence stats;
  66. default 10.
  67. --num_threads <int>
  68. The number of workers to calculate the co-occurrence matrix;
  69. default 4.
  70. )";
  // One co-occurrence record as spooled to the temporary shard buffer files:
  // the local (within-shard) row and column offsets plus the weight to add.
  // Records are written raw with write() and read back through mmap(), so this
  // struct's layout *is* the temporary on-disk format.
  struct cooc_t {
    int row;
    int col;
    float cnt;
  };

  // In-memory co-occurrence accumulator keyed by (row_id << 32) | col_id.
  //
  // The values are double rather than float: a frequent word pair can receive
  // a very large number of fractional (1 / distance) increments before the map
  // is flushed, and a float sum stops absorbing small increments once it grows
  // large (16777216.0f + 1.0f == 16777216.0f). On typical 64-bit ABIs the map
  // node does not grow, because the 8-byte key already pads the value to 8.
  typedef std::map<long long, double> cooc_counts_t;
  // Retrieves the next word from the input stream, treating words as simply
  // being delimited by whitespace. Stores the word in `*word` (empty if the
  // input is exhausted) and returns true if this word ends a "sentence", i.e.
  // it is followed by a newline or by the end of the input.
  //
  // Bytes are passed to std::isspace as unsigned char, so non-ASCII input
  // (e.g. UTF-8) is well-defined and stays inside words. A newline that appears
  // after trailing blanks (e.g. "word \n" or "word\r\n") still ends the
  // sentence. A stream that fails without reaching EOF is treated like EOF
  // rather than being read forever.
  bool NextWord(std::ifstream &fin, std::string* word) {
    const int kEof = std::char_traits<char>::eof();
    auto is_space = [](int ch) {
      return std::isspace(static_cast<unsigned char>(ch)) != 0;
    };

    word->clear();
    if (fin.eof()) return true;

    // Skip leading whitespace (including blank lines).
    int c;
    do {
      c = fin.get();
    } while (c != kEof && is_space(c));
    if (c == kEof) return true;

    // Read the next word.
    do {
      word->push_back(static_cast<char>(c));
      c = fin.get();
    } while (c != kEof && !is_space(c));
    if (c == kEof || c == '\n') return true;

    // Skip trailing whitespace, stopping just after a newline so the caller
    // sees the sentence boundary and the stream is left at the next line.
    do {
      c = fin.get();
      if (c == '\n') return true;
    } while (c != kEof && is_space(c));
    if (c == kEof) return true;

    // We consumed the first character of the next word; put it back.
    fin.unget();
    return false;
  }
  110. // Creates a vocabulary from the most frequent terms in the input file.
  111. std::vector<std::string> CreateVocabulary(const std::string input_filename,
  112. const int shard_size,
  113. const int min_vocab_count,
  114. const int max_vocab_size) {
  115. std::vector<std::string> vocab;
  116. // Count all the distinct tokens in the file. (XXX this will eventually
  117. // consume all memory and should be re-written to periodically trim the data.)
  118. std::unordered_map<std::string, long long> counts;
  119. std::ifstream fin(input_filename, std::ifstream::ate);
  120. if (!fin) {
  121. std::cerr << "couldn't read input file '" << input_filename << "'"
  122. << std::endl;
  123. return vocab;
  124. }
  125. const auto input_size = fin.tellg();
  126. fin.seekg(0);
  127. long long ntokens = 0;
  128. while (!fin.eof()) {
  129. std::string word;
  130. NextWord(fin, &word);
  131. counts[word] += 1;
  132. if (++ntokens % 1000000 == 0) {
  133. const float pct = 100.0 * static_cast<float>(fin.tellg()) / input_size;
  134. fprintf(stdout, "\rComputing vocabulary: %0.1f%% complete...", pct);
  135. std::flush(std::cout);
  136. }
  137. }
  138. std::cout << counts.size() << " distinct tokens" << std::endl;
  139. // Sort the vocabulary from most frequent to least frequent.
  140. std::vector<std::pair<std::string, long long>> buf;
  141. std::copy(counts.begin(), counts.end(), std::back_inserter(buf));
  142. std::sort(buf.begin(), buf.end(),
  143. [](const std::pair<std::string, long long> &a,
  144. const std::pair<std::string, long long> &b) {
  145. return b.second < a.second;
  146. });
  147. // Truncate to the maximum vocabulary size
  148. if (static_cast<int>(buf.size()) > max_vocab_size) buf.resize(max_vocab_size);
  149. if (buf.empty()) return vocab;
  150. // Eliminate rare tokens and truncate to a size modulo the shard size.
  151. int vocab_size = buf.size();
  152. while (vocab_size > 0 && buf[vocab_size - 1].second < min_vocab_count)
  153. --vocab_size;
  154. vocab_size -= vocab_size % shard_size;
  155. if (static_cast<int>(buf.size()) > vocab_size) buf.resize(vocab_size);
  156. // Copy out the tokens.
  157. for (const auto& pair : buf) vocab.push_back(pair.first);
  158. return vocab;
  159. }
  // Reads a vocabulary file with one token per line, in vocabulary order.
  //
  // A line may carry extra tab-separated columns (e.g. "token\tcount"); only
  // the leading column is the token. A trailing '\r' (Windows line ending) is
  // dropped, since NextWord() treats '\r' as whitespace and such a token could
  // never match. Returns an empty vector, after printing a diagnostic, if the
  // file cannot be opened.
  std::vector<std::string> ReadVocabulary(const std::string &vocab_filename) {
    std::vector<std::string> vocab;
    std::ifstream fin(vocab_filename);
    if (!fin) {
      std::cerr << "couldn't read vocabulary file '" << vocab_filename << "'"
                << std::endl;
      return vocab;
    }
    for (std::string line; std::getline(fin, line);) {
      if (!line.empty() && line.back() == '\r') line.pop_back();
      const std::string::size_type tab = line.find('\t');
      if (tab != std::string::npos) line.erase(tab);  // Keep the token only.
      vocab.push_back(std::move(line));
    }
    return vocab;
  }
  // Writes `vocab`, one token per line, to both row_vocab.txt and col_vocab.txt
  // inside `output_dirname` (the co-occurrence matrix is square, so the row and
  // column vocabularies are identical).
  //
  // Lines end with '\n' rather than std::endl to avoid a flush per token. A
  // diagnostic is printed if a file cannot be opened or written.
  void WriteVocabulary(const std::vector<std::string> &vocab,
                       const std::string &output_dirname) {
    for (const char *filename : {"row_vocab.txt", "col_vocab.txt"}) {
      const std::string path = output_dirname + "/" + filename;
      std::ofstream fout(path);
      for (const auto &token : vocab) fout << token << '\n';
      fout.close();  // Flush now so write errors are visible below.
      if (!fout) std::cerr << "error writing '" << path << "'" << std::endl;
    }
  }
  178. // Manages accumulation of co-occurrence data into temporary disk buffer files.
  179. class CoocBuffer {
  180. public:
  181. CoocBuffer(const std::string &output_dirname, const int num_shards,
  182. const int shard_size);
  183. // Accumulate the co-occurrence counts to the buffer.
  184. void AccumulateCoocs(const cooc_counts_t &coocs);
  185. // Read the buffer to produce shard files.
  186. void WriteShards();
  187. protected:
  188. // The output directory. Also used for temporary buffer files.
  189. const std::string output_dirname_;
  190. // The number of row/column shards.
  191. const int num_shards_;
  192. // The number of elements per shard.
  193. const int shard_size_;
  194. // Parallel arrays of temporary file paths and file descriptors.
  195. std::vector<std::string> paths_;
  196. std::vector<int> fds_;
  197. // Ensures that only one buffer file is getting written at a time.
  198. std::mutex writer_mutex_;
  199. };
  200. CoocBuffer::CoocBuffer(const std::string &output_dirname, const int num_shards,
  201. const int shard_size)
  202. : output_dirname_(output_dirname),
  203. num_shards_(num_shards),
  204. shard_size_(shard_size) {
  205. for (int row = 0; row < num_shards_; ++row) {
  206. for (int col = 0; col < num_shards_; ++col) {
  207. char filename[256];
  208. sprintf(filename, "shard-%03d-%03d.tmp", row, col);
  209. std::string path = output_dirname + "/" + filename;
  210. int fd = open(path.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0666);
  211. assert(fd > 0);
  212. paths_.push_back(path);
  213. fds_.push_back(fd);
  214. }
  215. }
  216. }
  217. void CoocBuffer::AccumulateCoocs(const cooc_counts_t &coocs) {
  218. std::vector<std::vector<cooc_t>> bufs(fds_.size());
  219. for (const auto &cooc : coocs) {
  220. const int row_id = cooc.first >> 32;
  221. const int col_id = cooc.first & 0xffffffff;
  222. const float cnt = cooc.second;
  223. const int row_shard = row_id % num_shards_;
  224. const int row_off = row_id / num_shards_;
  225. const int col_shard = col_id % num_shards_;
  226. const int col_off = col_id / num_shards_;
  227. const int top_shard_idx = row_shard * num_shards_ + col_shard;
  228. bufs[top_shard_idx].push_back(cooc_t{row_off, col_off, cnt});
  229. const int bot_shard_idx = col_shard * num_shards_ + row_shard;
  230. bufs[bot_shard_idx].push_back(cooc_t{col_off, row_off, cnt});
  231. }
  232. for (int i = 0; i < static_cast<int>(fds_.size()); ++i) {
  233. std::lock_guard<std::mutex> rv(writer_mutex_);
  234. const int nbytes = bufs[i].size() * sizeof(cooc_t);
  235. int nwritten = write(fds_[i], bufs[i].data(), nbytes);
  236. assert(nwritten == nbytes);
  237. }
  238. }
  239. void CoocBuffer::WriteShards() {
  240. for (int shard = 0; shard < static_cast<int>(fds_.size()); ++shard) {
  241. const int row_shard = shard / num_shards_;
  242. const int col_shard = shard % num_shards_;
  243. std::cout << "\rwriting shard " << (shard + 1) << "/"
  244. << (num_shards_ * num_shards_);
  245. std::flush(std::cout);
  246. // Construct the tf::Example proto. First, we add the global rows and
  247. // column that are present in the shard.
  248. tensorflow::Example example;
  249. auto &feature = *example.mutable_features()->mutable_feature();
  250. auto global_row = feature["global_row"].mutable_int64_list();
  251. auto global_col = feature["global_col"].mutable_int64_list();
  252. for (int i = 0; i < shard_size_; ++i) {
  253. global_row->add_value(row_shard + i * num_shards_);
  254. global_col->add_value(col_shard + i * num_shards_);
  255. }
  256. // Next we add co-occurrences as a sparse representation. Map the
  257. // co-occurrence counts that we've spooled off to disk: these are in
  258. // arbitrary order and may contain duplicates.
  259. const off_t nbytes = lseek(fds_[shard], 0, SEEK_END);
  260. cooc_t *coocs = static_cast<cooc_t*>(
  261. mmap(0, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, fds_[shard], 0));
  262. const int ncoocs = nbytes / sizeof(cooc_t);
  263. cooc_t* cur = coocs;
  264. cooc_t* end = coocs + ncoocs;
  265. auto sparse_value = feature["sparse_value"].mutable_float_list();
  266. auto sparse_local_row = feature["sparse_local_row"].mutable_int64_list();
  267. auto sparse_local_col = feature["sparse_local_col"].mutable_int64_list();
  268. std::sort(cur, end, [](const cooc_t &a, const cooc_t &b) {
  269. return a.row < b.row || (a.row == b.row && a.col < b.col);
  270. });
  271. // Accumulate the counts into the protocol buffer.
  272. int last_row = -1, last_col = -1;
  273. float count = 0;
  274. for (; cur != end; ++cur) {
  275. if (cur->row != last_row || cur->col != last_col) {
  276. if (last_row >= 0 && last_col >= 0) {
  277. sparse_local_row->add_value(last_row);
  278. sparse_local_col->add_value(last_col);
  279. sparse_value->add_value(count);
  280. }
  281. last_row = cur->row;
  282. last_col = cur->col;
  283. count = 0;
  284. }
  285. count += cur->cnt;
  286. }
  287. if (last_row >= 0 && last_col >= 0) {
  288. sparse_local_row->add_value(last_row);
  289. sparse_local_col->add_value(last_col);
  290. sparse_value->add_value(count);
  291. }
  292. munmap(coocs, nbytes);
  293. close(fds_[shard]);
  294. if (sparse_local_row->value_size() * 8 >= (64 << 20)) {
  295. std::cout << "Warning: you are likely to catch protobuf parsing errors "
  296. "in TF 1.0 and older because the shard is too fat (>= 64MiB); see "
  297. << std::endl <<
  298. "kDefaultTotalBytesLimit in src/google/protobuf/io/coded_stream.h "
  299. " changed in protobuf/commit/5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74"
  300. << std::endl <<
  301. "https://github.com/tensorflow/tensorflow/issues/7311"
  302. << std::endl <<
  303. "Consider increasing the number of shards.";
  304. }
  305. // Write the protocol buffer as a binary blob to disk.
  306. const int filename_max_size = 4096;
  307. std::unique_ptr<char[]> filename(new char[filename_max_size]);
  308. snprintf(filename.get(), filename_max_size, "shard-%03d-%03d.pb", row_shard,
  309. col_shard);
  310. const std::string path = output_dirname_ + "/" + filename.get();
  311. int fd = open(path.c_str(), O_WRONLY | O_TRUNC | O_CREAT, 0666);
  312. assert(fd != -1);
  313. google::protobuf::io::FileOutputStream fout(fd);
  314. example.SerializeToZeroCopyStream(&fout);
  315. fout.Close();
  316. // Remove the temporary file.
  317. unlink(paths_[shard].c_str());
  318. }
  319. std::cout << std::endl;
  320. }
  321. // Counts the co-occurrences in part of the file.
  322. class CoocCounter {
  323. public:
  324. CoocCounter(const std::string &input_filename, const off_t start,
  325. const off_t end, const int window_size,
  326. const std::unordered_map<std::string, int> &token_to_id_map,
  327. CoocBuffer *coocbuf)
  328. : fin_(input_filename, std::ifstream::ate),
  329. start_(start),
  330. end_(end),
  331. window_size_(window_size),
  332. token_to_id_map_(token_to_id_map),
  333. coocbuf_(coocbuf),
  334. marginals_(token_to_id_map.size()) {}
  335. // PTthreads-friendly thunk to Count.
  336. static void* Run(void* param) {
  337. CoocCounter* self = static_cast<CoocCounter*>(param);
  338. self->Count();
  339. return nullptr;
  340. }
  341. // Counts the co-occurrences.
  342. void Count();
  343. const std::vector<double>& Marginals() const { return marginals_; }
  344. protected:
  345. // The input stream.
  346. std::ifstream fin_;
  347. // The range of the file to which this counter should attend.
  348. const off_t start_;
  349. const off_t end_;
  350. // The window size for computing co-occurrences.
  351. const int window_size_;
  352. // A reference to the mapping from tokens to IDs.
  353. const std::unordered_map<std::string, int> &token_to_id_map_;
  354. // The buffer into which counts are to be accumulated.
  355. CoocBuffer* coocbuf_;
  356. // The marginal counts accumulated by this counter.
  357. std::vector<double> marginals_;
  358. };
  359. void CoocCounter::Count() {
  360. const int max_coocs_size = 16 * 1024 * 1024;
  361. // A buffer of co-occurrence counts that we'll periodically sort into
  362. // shards.
  363. cooc_counts_t coocs;
  364. fin_.seekg(start_);
  365. int nlines = 0;
  366. for (off_t filepos = start_; filepos < end_ && !fin_.eof(); filepos = fin_.tellg()) {
  367. // Buffer a single sentence.
  368. std::vector<int> sentence;
  369. bool eos;
  370. do {
  371. std::string word;
  372. eos = NextWord(fin_, &word);
  373. auto it = token_to_id_map_.find(word);
  374. if (it != token_to_id_map_.end()) sentence.push_back(it->second);
  375. } while (!eos);
  376. // Generate the co-occurrences for the sentence.
  377. for (int pos = 0; pos < static_cast<int>(sentence.size()); ++pos) {
  378. const int left_id = sentence[pos];
  379. const int window_extent =
  380. std::min(static_cast<int>(sentence.size()) - pos, 1 + window_size_);
  381. for (int off = 1; off < window_extent; ++off) {
  382. const int right_id = sentence[pos + off];
  383. const double count = 1.0 / static_cast<double>(off);
  384. const long long lo = std::min(left_id, right_id);
  385. const long long hi = std::max(left_id, right_id);
  386. const long long key = (hi << 32) | lo;
  387. coocs[key] += count;
  388. marginals_[left_id] += count;
  389. marginals_[right_id] += count;
  390. }
  391. marginals_[left_id] += 1.0;
  392. const long long key = (static_cast<long long>(left_id) << 32) |
  393. static_cast<long long>(left_id);
  394. coocs[key] += 0.5;
  395. }
  396. // Periodically flush the co-occurrences to disk.
  397. if (coocs.size() > max_coocs_size) {
  398. coocbuf_->AccumulateCoocs(coocs);
  399. coocs.clear();
  400. }
  401. if (start_ == 0 && ++nlines % 1000 == 0) {
  402. const double pct = 100.0 * filepos / end_;
  403. fprintf(stdout, "\rComputing co-occurrences: %0.1f%% complete...", pct);
  404. std::flush(std::cout);
  405. }
  406. }
  407. // Accumulate anything we haven't flushed yet.
  408. coocbuf_->AccumulateCoocs(coocs);
  409. if (start_ == 0) std::cout << "done." << std::endl;
  410. }
  // Writes the marginal sums, one fixed-point value per line in vocabulary
  // order, to both row_sums.txt and col_sums.txt inside `output_dirname`.
  //
  // Lines end with '\n' rather than std::endl to avoid a flush per value. A
  // diagnostic is printed if a file cannot be opened or written.
  void WriteMarginals(const std::vector<double> &marginals,
                      const std::string &output_dirname) {
    for (const char *filename : {"row_sums.txt", "col_sums.txt"}) {
      const std::string path = output_dirname + "/" + filename;
      std::ofstream fout(path);
      fout.setf(std::ios::fixed);
      for (const double sum : marginals) fout << sum << '\n';
      fout.close();  // Flush now so write errors are visible below.
      if (!fout) std::cerr << "error writing '" << path << "'" << std::endl;
    }
  }
  419. int main(int argc, char *argv[]) {
  420. std::string input_filename;
  421. std::string vocab_filename;
  422. std::string output_dirname;
  423. bool generate_vocab = true;
  424. int max_vocab_size = 100 * 1024;
  425. int min_vocab_count = 5;
  426. int window_size = 10;
  427. int shard_size = 4096;
  428. int num_threads = 4;
  429. for (int i = 1; i < argc; ++i) {
  430. std::string arg(argv[i]);
  431. if (arg == "--vocab") {
  432. if (++i >= argc) goto argmissing;
  433. generate_vocab = false;
  434. vocab_filename = argv[i];
  435. } else if (arg == "--max_vocab") {
  436. if (++i >= argc) goto argmissing;
  437. if ((max_vocab_size = atoi(argv[i])) <= 0) goto badarg;
  438. } else if (arg == "--min_count") {
  439. if (++i >= argc) goto argmissing;
  440. if ((min_vocab_count = atoi(argv[i])) <= 0) goto badarg;
  441. } else if (arg == "--window_size") {
  442. if (++i >= argc) goto argmissing;
  443. if ((window_size = atoi(argv[i])) <= 0) goto badarg;
  444. } else if (arg == "--input") {
  445. if (++i >= argc) goto argmissing;
  446. input_filename = argv[i];
  447. } else if (arg == "--output_dir") {
  448. if (++i >= argc) goto argmissing;
  449. output_dirname = argv[i];
  450. } else if (arg == "--shard_size") {
  451. if (++i >= argc) goto argmissing;
  452. shard_size = atoi(argv[i]);
  453. } else if (arg == "--num_threads") {
  454. if (++i >= argc) goto argmissing;
  455. num_threads = atoi(argv[i]);
  456. } else if (arg == "--help") {
  457. std::cout << usage << std::endl;
  458. return 0;
  459. } else {
  460. std::cerr << "unknown arg '" << arg << "'; try --help?" << std::endl;
  461. return 2;
  462. }
  463. continue;
  464. badarg:
  465. std::cerr << "'" << argv[i] << "' is not a valid value for '" << arg
  466. << "'; try --help?" << std::endl;
  467. return 2;
  468. argmissing:
  469. std::cerr << arg << " requires an argument; try --help?" << std::endl;
  470. }
  471. if (input_filename.empty()) {
  472. std::cerr << "please specify the input text with '--input'; try --help?"
  473. << std::endl;
  474. return 2;
  475. }
  476. if (output_dirname.empty()) {
  477. std::cerr << "please specify the output directory with '--output_dir'"
  478. << std::endl;
  479. return 2;
  480. }
  481. struct stat sb;
  482. if (lstat(output_dirname.c_str(), &sb) != 0 || !S_ISDIR(sb.st_mode)) {
  483. if (mkdir(output_dirname.c_str(), 0755) != 0) {
  484. std::cerr << "output directory '" << output_dirname
  485. << "' does not exist or is not a directory." << std::endl;
  486. return 1;
  487. }
  488. }
  489. if (lstat(input_filename.c_str(), &sb) != 0 || !S_ISREG(sb.st_mode)) {
  490. std::cerr << "input file '" << input_filename
  491. << "' does not exist or is not a file." << std::endl;
  492. return 1;
  493. }
  494. // The total size of the input.
  495. const off_t input_size = sb.st_size;
  496. const std::vector<std::string> vocab =
  497. generate_vocab ? CreateVocabulary(input_filename, shard_size,
  498. min_vocab_count, max_vocab_size)
  499. : ReadVocabulary(vocab_filename);
  500. if (!vocab.size()) {
  501. std::cerr << "Empty vocabulary." << std::endl;
  502. return 1;
  503. }
  504. std::cout << "Generating Swivel co-occurrence data into " << output_dirname
  505. << std::endl;
  506. std::cout << "Shard size: " << shard_size << "x" << shard_size << std::endl;
  507. std::cout << "Vocab size: " << vocab.size() << std::endl;
  508. // Write the vocabulary files into the output directory.
  509. WriteVocabulary(vocab, output_dirname);
  510. const int num_shards = vocab.size() / shard_size;
  511. CoocBuffer coocbuf(output_dirname, num_shards, shard_size);
  512. // Build a mapping from the token to its position in the vocabulary file.
  513. std::unordered_map<std::string, int> token_to_id_map;
  514. for (int i = 0; i < static_cast<int>(vocab.size()); ++i)
  515. token_to_id_map[vocab[i]] = i;
  516. // Compute the co-occurrences
  517. std::vector<std::thread> threads;
  518. threads.reserve(num_threads);
  519. std::vector<CoocCounter*> counters;
  520. const off_t nbytes_per_thread = input_size / num_threads;
  521. std::cout << "Running " << num_threads << " threads, each on "
  522. << nbytes_per_thread << " bytes" << std::endl;
  523. for (int i = 0; i < num_threads; ++i) {
  524. // We could make this smarter and look around for newlines. But
  525. // realistically that's not going to change things much.
  526. const off_t start = i * nbytes_per_thread;
  527. const off_t end =
  528. i < num_threads - 1 ? (i + 1) * nbytes_per_thread : input_size;
  529. CoocCounter *counter = new CoocCounter(
  530. input_filename, start, end, window_size, token_to_id_map, &coocbuf);
  531. counters.push_back(counter);
  532. threads.emplace_back(CoocCounter::Run, counter);
  533. }
  534. // Wait for threads to finish and collect marginals.
  535. std::vector<double> marginals(vocab.size());
  536. for (int i = 0; i < num_threads; ++i) {
  537. if (i > 0) {
  538. std::cout << "joining thread #" << (i + 1) << std::endl;
  539. }
  540. threads[i].join();
  541. const std::vector<double>& counter_marginals = counters[i]->Marginals();
  542. for (int j = 0; j < static_cast<int>(vocab.size()); ++j)
  543. marginals[j] += counter_marginals[j];
  544. delete counters[i];
  545. }
  546. std::cout << "writing marginals..." << std::endl;
  547. WriteMarginals(marginals, output_dirname);
  548. std::cout << "writing shards..." << std::endl;
  549. coocbuf.WriteShards();
  550. return 0;
  551. }