helpers.cpp 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718
  1. /*
  2. coding=utf-8
  3. Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  4. Licensed under the Apache License, Version 2.0 (the "License");
  5. you may not use this file except in compliance with the License.
  6. You may obtain a copy of the License at
  7. http://www.apache.org/licenses/LICENSE-2.0
  8. Unless required by applicable law or agreed to in writing, software
  9. distributed under the License is distributed on an "AS IS" BASIS,
  10. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  11. See the License for the specific language governing permissions and
  12. limitations under the License.
  13. */
  14. /* Helper methods for fast index mapping builds */
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <limits>
#include <math.h>
#include <memory>
#include <stdexcept>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>
namespace py = pybind11;
using namespace std;
  25. const int32_t LONG_SENTENCE_LEN = 512;
  26. void build_blending_indices(py::array_t<uint8_t>& dataset_index,
  27. py::array_t<int64_t>& dataset_sample_index,
  28. const py::array_t<double>& weights,
  29. const int32_t num_datasets,
  30. const int64_t size, const bool verbose) {
  31. /* Given multiple datasets and a weighting array, build samples
  32. such that it follows those wieghts.*/
  33. if (verbose) {
  34. std::cout << "> building indices for blendable datasets ..." << std::endl;
  35. }
  36. // Get the pointer access without the checks.
  37. auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
  38. auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
  39. auto weights_ptr = weights.unchecked<1>();
  40. // Initialize buffer for number of samples used for each dataset.
  41. int64_t current_samples[num_datasets];
  42. for(int64_t i = 0; i < num_datasets; ++i) {
  43. current_samples[i] = 0;
  44. }
  45. // For each sample:
  46. for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
  47. // Determine where the max error in sampling is happening.
  48. auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
  49. int64_t max_error_index = 0;
  50. double max_error = weights_ptr[0] * sample_idx_double -
  51. static_cast<double>(current_samples[0]);
  52. for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
  53. double error = weights_ptr[dataset_idx] * sample_idx_double -
  54. static_cast<double>(current_samples[dataset_idx]);
  55. if (error > max_error) {
  56. max_error = error;
  57. max_error_index = dataset_idx;
  58. }
  59. }
  60. // Populate the indices.
  61. dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
  62. dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
  63. // Update the total samples.
  64. current_samples[max_error_index] += 1;
  65. }
  66. // print info
  67. if (verbose) {
  68. std::cout << " > sample ratios:" << std::endl;
  69. for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
  70. auto ratio = static_cast<double>(current_samples[dataset_idx]) /
  71. static_cast<double>(size);
  72. std::cout << " dataset " << dataset_idx << ", input: " <<
  73. weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
  74. }
  75. }
  76. }
  77. py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
  78. const py::array_t<int32_t>& doc_idx_,
  79. const int32_t seq_length,
  80. const int32_t num_epochs,
  81. const int64_t tokens_per_epoch) {
  82. /* Sample index (sample_idx) is used for gpt2 like dataset for which
  83. the documents are flattened and the samples are built based on this
  84. 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
  85. where [..., 0] contains the index into `doc_idx` and [..., 1] is the
  86. starting offset in that document.*/
  87. // Consistency checks.
  88. assert(seq_length > 1);
  89. assert(num_epochs > 0);
  90. assert(tokens_per_epoch > 1);
  91. // Remove bound checks.
  92. auto sizes = sizes_.unchecked<1>();
  93. auto doc_idx = doc_idx_.unchecked<1>();
  94. // Mapping and it's length (1D).
  95. int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
  96. int32_t* sample_idx = new int32_t[2*(num_samples+1)];
  97. cout << " using:" << endl << std::flush;
  98. cout << " number of documents: " <<
  99. doc_idx_.shape(0) / num_epochs << endl << std::flush;
  100. cout << " number of epochs: " << num_epochs <<
  101. endl << std::flush;
  102. cout << " sequence length: " << seq_length <<
  103. endl << std::flush;
  104. cout << " total number of samples: " << num_samples <<
  105. endl << std::flush;
  106. // Index into sample_idx.
  107. int64_t sample_index = 0;
  108. // Index into doc_idx.
  109. int64_t doc_idx_index = 0;
  110. // Begining offset for each document.
  111. int32_t doc_offset = 0;
  112. // Start with first document and no offset.
  113. sample_idx[2 * sample_index] = doc_idx_index;
  114. sample_idx[2 * sample_index + 1] = doc_offset;
  115. ++sample_index;
  116. while (sample_index <= num_samples) {
  117. // Start with a fresh sequence.
  118. int32_t remaining_seq_length = seq_length + 1;
  119. while (remaining_seq_length != 0) {
  120. // Get the document length.
  121. auto doc_id = doc_idx[doc_idx_index];
  122. auto doc_length = sizes[doc_id] - doc_offset;
  123. // And add it to the current sequence.
  124. remaining_seq_length -= doc_length;
  125. // If we have more than a full sequence, adjust offset and set
  126. // remaining length to zero so we return from the while loop.
  127. // Note that -1 here is for the same reason we have -1 in
  128. // `_num_epochs` calculations.
  129. if (remaining_seq_length <= 0) {
  130. doc_offset += (remaining_seq_length + doc_length - 1);
  131. remaining_seq_length = 0;
  132. } else {
  133. // Otherwise, start from the begining of the next document.
  134. ++doc_idx_index;
  135. doc_offset = 0;
  136. }
  137. }
  138. // Record the sequence.
  139. sample_idx[2 * sample_index] = doc_idx_index;
  140. sample_idx[2 * sample_index + 1] = doc_offset;
  141. ++sample_index;
  142. }
  143. // Method to deallocate memory.
  144. py::capsule free_when_done(sample_idx, [](void *mem_) {
  145. int32_t *mem = reinterpret_cast<int32_t*>(mem_);
  146. delete[] mem;
  147. });
  148. // Return the numpy array.
  149. const auto byte_size = sizeof(int32_t);
  150. return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
  151. {2*byte_size, byte_size}, // C-style contiguous strides
  152. sample_idx, // the data pointer
  153. free_when_done); // numpy array references
  154. }
  155. inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
  156. const int32_t max_length,
  157. std::mt19937& rand32_gen) {
  158. /* Training sample length. */
  159. if (short_seq_ratio == 0) {
  160. return max_length;
  161. }
  162. const auto random_number = rand32_gen();
  163. if ((random_number % short_seq_ratio) == 0) {
  164. return 2 + random_number % (max_length - 1);
  165. }
  166. return max_length;
  167. }
  168. template<typename DocIdx>
  169. py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
  170. const py::array_t<int32_t>& sizes_,
  171. const int32_t num_epochs,
  172. const uint64_t max_num_samples,
  173. const int32_t max_seq_length,
  174. const double short_seq_prob,
  175. const int32_t seed,
  176. const bool verbose,
  177. const int32_t min_num_sent) {
  178. /* Build a mapping of (start-index, end-index, sequence-length) where
  179. start and end index are the indices of the sentences in the sample
  180. and sequence-length is the target sequence length.
  181. */
  182. // Consistency checks.
  183. assert(num_epochs > 0);
  184. assert(max_seq_length > 1);
  185. assert(short_seq_prob >= 0.0);
  186. assert(short_seq_prob <= 1.0);
  187. assert(seed > 0);
  188. // Remove bound checks.
  189. auto docs = docs_.unchecked<1>();
  190. auto sizes = sizes_.unchecked<1>();
  191. // For efficiency, convert probability to ratio. Note: rand() generates int.
  192. int32_t short_seq_ratio = 0;
  193. if (short_seq_prob > 0) {
  194. short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
  195. }
  196. if (verbose) {
  197. const auto sent_start_index = docs[0];
  198. const auto sent_end_index = docs[docs_.shape(0) - 1];
  199. const auto num_sentences = sent_end_index - sent_start_index;
  200. cout << " using:" << endl << std::flush;
  201. cout << " number of documents: " << docs_.shape(0) - 1 <<
  202. endl << std::flush;
  203. cout << " sentences range: [" << sent_start_index <<
  204. ", " << sent_end_index << ")" << endl << std::flush;
  205. cout << " total number of sentences: " << num_sentences <<
  206. endl << std::flush;
  207. cout << " number of epochs: " << num_epochs <<
  208. endl << std::flush;
  209. cout << " maximum number of samples: " << max_num_samples <<
  210. endl << std::flush;
  211. cout << " maximum sequence length: " << max_seq_length <<
  212. endl << std::flush;
  213. cout << " short sequence probability: " << short_seq_prob <<
  214. endl << std::flush;
  215. cout << " short sequence ration (1/prob): " << short_seq_ratio <<
  216. endl << std::flush;
  217. cout << " seed: " << seed << endl <<
  218. std::flush;
  219. }
  220. // Mapping and it's length (1D).
  221. int64_t num_samples = -1;
  222. DocIdx* maps = NULL;
  223. // Perform two iterations, in the first iteration get the size
  224. // and allocate memory and in the second iteration populate the map.
  225. bool second = false;
  226. for (int32_t iteration=0; iteration<2; ++iteration) {
  227. // Set the seed so both iterations produce the same results.
  228. std::mt19937 rand32_gen(seed);
  229. // Set the flag on second iteration.
  230. second = (iteration == 1);
  231. // Counters:
  232. uint64_t empty_docs = 0;
  233. uint64_t one_sent_docs = 0;
  234. uint64_t long_sent_docs = 0;
  235. // Current map index.
  236. uint64_t map_index = 0;
  237. // For each epoch:
  238. for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
  239. if (map_index >= max_num_samples) {
  240. if (verbose && (!second)) {
  241. cout << " reached " << max_num_samples << " samples after "
  242. << epoch << " epochs ..." << endl << std::flush;
  243. }
  244. break;
  245. }
  246. // For each document:
  247. for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
  248. // Document sentences are in [sent_index_first, sent_index_last)
  249. const auto sent_index_first = docs[doc];
  250. const auto sent_index_last = docs[doc + 1];
  251. // At the begining of the document previous index is the
  252. // start index.
  253. auto prev_start_index = sent_index_first;
  254. // Remaining documents.
  255. auto num_remain_sent = sent_index_last - sent_index_first;
  256. // Some bookkeeping
  257. if ((epoch == 0) && (!second)) {
  258. if (num_remain_sent == 0) {
  259. ++empty_docs;
  260. }
  261. if (num_remain_sent == 1) {
  262. ++one_sent_docs;
  263. }
  264. }
  265. // Detect documents with long sentences.
  266. bool contains_long_sentence = false;
  267. if (num_remain_sent > 1) {
  268. for (auto sent_index=sent_index_first;
  269. sent_index < sent_index_last; ++sent_index) {
  270. if (sizes[sent_index] > LONG_SENTENCE_LEN){
  271. if ((epoch == 0) && (!second)) {
  272. ++long_sent_docs;
  273. }
  274. contains_long_sentence = true;
  275. break;
  276. }
  277. }
  278. }
  279. // If we have more than two sentences.
  280. if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
  281. // Set values.
  282. auto seq_len = int32_t{0};
  283. auto num_sent = int32_t{0};
  284. auto target_seq_len = get_target_sample_len(short_seq_ratio,
  285. max_seq_length,
  286. rand32_gen);
  287. // Loop through sentences.
  288. for (auto sent_index=sent_index_first;
  289. sent_index < sent_index_last; ++sent_index) {
  290. // Add the size and number of sentences.
  291. seq_len += sizes[sent_index];
  292. ++num_sent;
  293. --num_remain_sent;
  294. // If we have reached the target length.
  295. // and if not only one sentence is left in the document.
  296. // and if we have at least two sentneces.
  297. // and if we have reached end of the document.
  298. if (((seq_len >= target_seq_len) &&
  299. (num_remain_sent > 1) &&
  300. (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
  301. // Check for overflow.
  302. if ((3 * map_index + 2) >
  303. std::numeric_limits<int64_t>::max()) {
  304. cout << "number of samples exceeded maximum "
  305. << "allowed by type int64: "
  306. << std::numeric_limits<int64_t>::max()
  307. << endl;
  308. throw std::overflow_error("Number of samples");
  309. }
  310. // Populate the map.
  311. if (second) {
  312. const auto map_index_0 = 3 * map_index;
  313. maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
  314. maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
  315. maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
  316. }
  317. // Update indices / counters.
  318. ++map_index;
  319. prev_start_index = sent_index + 1;
  320. target_seq_len = get_target_sample_len(short_seq_ratio,
  321. max_seq_length,
  322. rand32_gen);
  323. seq_len = 0;
  324. num_sent = 0;
  325. }
  326. } // for (auto sent_index=sent_index_first; ...
  327. } // if (num_remain_sent > 1) {
  328. } // for (int doc=0; doc < num_docs; ++doc) {
  329. } // for (int epoch=0; epoch < num_epochs; ++epoch) {
  330. if (!second) {
  331. if (verbose) {
  332. cout << " number of empty documents: " << empty_docs <<
  333. endl << std::flush;
  334. cout << " number of documents with one sentence: " <<
  335. one_sent_docs << endl << std::flush;
  336. cout << " number of documents with long sentences: " <<
  337. long_sent_docs << endl << std::flush;
  338. cout << " will create mapping for " << map_index <<
  339. " samples" << endl << std::flush;
  340. }
  341. assert(maps == NULL);
  342. assert(num_samples < 0);
  343. maps = new DocIdx[3*map_index];
  344. num_samples = static_cast<int64_t>(map_index);
  345. }
  346. } // for (int iteration=0; iteration < 2; ++iteration) {
  347. // Shuffle.
  348. // We need a 64 bit random number generator as we might have more
  349. // than 2 billion samples.
  350. std::mt19937_64 rand64_gen(seed + 1);
  351. for (auto i=(num_samples - 1); i > 0; --i) {
  352. const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
  353. const auto i0 = 3 * i;
  354. const auto j0 = 3 * j;
  355. // Swap values.
  356. swap(maps[i0], maps[j0]);
  357. swap(maps[i0 + 1], maps[j0 + 1]);
  358. swap(maps[i0 + 2], maps[j0 + 2]);
  359. }
  360. // Method to deallocate memory.
  361. py::capsule free_when_done(maps, [](void *mem_) {
  362. DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
  363. delete[] mem;
  364. });
  365. // Return the numpy array.
  366. const auto byte_size = sizeof(DocIdx);
  367. return py::array(std::vector<int64_t>{num_samples, 3}, // shape
  368. {3*byte_size, byte_size}, // C-style contiguous strides
  369. maps, // the data pointer
  370. free_when_done); // numpy array references
  371. }
  372. py::array build_mapping(const py::array_t<int64_t>& docs_,
  373. const py::array_t<int>& sizes_,
  374. const int num_epochs,
  375. const uint64_t max_num_samples,
  376. const int max_seq_length,
  377. const double short_seq_prob,
  378. const int seed,
  379. const bool verbose,
  380. const int32_t min_num_sent) {
  381. if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
  382. if (verbose) {
  383. cout << " using uint64 for data mapping..." << endl << std::flush;
  384. }
  385. return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
  386. max_num_samples, max_seq_length,
  387. short_seq_prob, seed, verbose,
  388. min_num_sent);
  389. } else {
  390. if (verbose) {
  391. cout << " using uint32 for data mapping..." << endl << std::flush;
  392. }
  393. return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
  394. max_num_samples, max_seq_length,
  395. short_seq_prob, seed, verbose,
  396. min_num_sent);
  397. }
  398. }
  399. template<typename DocIdx>
  400. py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
  401. const py::array_t<int32_t>& sizes_,
  402. const py::array_t<int32_t>& titles_sizes_,
  403. const int32_t num_epochs,
  404. const uint64_t max_num_samples,
  405. const int32_t max_seq_length,
  406. const int32_t seed,
  407. const bool verbose,
  408. const bool use_one_sent_blocks) {
  409. /* Build a mapping of (start-index, end-index, sequence-length) where
  410. start and end index are the indices of the sentences in the sample
  411. and sequence-length is the target sequence length.
  412. */
  413. // Consistency checks.
  414. assert(num_epochs > 0);
  415. assert(max_seq_length > 1);
  416. assert(seed > 0);
  417. // Remove bound checks.
  418. auto docs = docs_.unchecked<1>();
  419. auto sizes = sizes_.unchecked<1>();
  420. auto titles_sizes = titles_sizes_.unchecked<1>();
  421. if (verbose) {
  422. const auto sent_start_index = docs[0];
  423. const auto sent_end_index = docs[docs_.shape(0) - 1];
  424. const auto num_sentences = sent_end_index - sent_start_index;
  425. cout << " using:" << endl << std::flush;
  426. cout << " number of documents: " << docs_.shape(0) - 1 <<
  427. endl << std::flush;
  428. cout << " sentences range: [" << sent_start_index <<
  429. ", " << sent_end_index << ")" << endl << std::flush;
  430. cout << " total number of sentences: " << num_sentences <<
  431. endl << std::flush;
  432. cout << " number of epochs: " << num_epochs <<
  433. endl << std::flush;
  434. cout << " maximum number of samples: " << max_num_samples <<
  435. endl << std::flush;
  436. cout << " maximum sequence length: " << max_seq_length <<
  437. endl << std::flush;
  438. cout << " seed: " << seed << endl <<
  439. std::flush;
  440. }
  441. // Mapping and its length (1D).
  442. int64_t num_samples = -1;
  443. DocIdx* maps = NULL;
  444. // Acceptable number of sentences per block.
  445. int min_num_sent = 2;
  446. if (use_one_sent_blocks) {
  447. min_num_sent = 1;
  448. }
  449. // Perform two iterations, in the first iteration get the size
  450. // and allocate memory and in the second iteration populate the map.
  451. bool second = false;
  452. for (int32_t iteration=0; iteration<2; ++iteration) {
  453. // Set the flag on second iteration.
  454. second = (iteration == 1);
  455. // Current map index.
  456. uint64_t map_index = 0;
  457. uint64_t empty_docs = 0;
  458. uint64_t one_sent_docs = 0;
  459. uint64_t long_sent_docs = 0;
  460. // For each epoch:
  461. for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
  462. // assign every block a unique id
  463. int32_t block_id = 0;
  464. if (map_index >= max_num_samples) {
  465. if (verbose && (!second)) {
  466. cout << " reached " << max_num_samples << " samples after "
  467. << epoch << " epochs ..." << endl << std::flush;
  468. }
  469. break;
  470. }
  471. // For each document:
  472. for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
  473. // Document sentences are in [sent_index_first, sent_index_last)
  474. const auto sent_index_first = docs[doc];
  475. const auto sent_index_last = docs[doc + 1];
  476. const auto target_seq_len = max_seq_length - titles_sizes[doc];
  477. // At the begining of the document previous index is the
  478. // start index.
  479. auto prev_start_index = sent_index_first;
  480. // Remaining documents.
  481. auto num_remain_sent = sent_index_last - sent_index_first;
  482. // Some bookkeeping
  483. if ((epoch == 0) && (!second)) {
  484. if (num_remain_sent == 0) {
  485. ++empty_docs;
  486. }
  487. if (num_remain_sent == 1) {
  488. ++one_sent_docs;
  489. }
  490. }
  491. // Detect documents with long sentences.
  492. bool contains_long_sentence = false;
  493. if (num_remain_sent >= min_num_sent) {
  494. for (auto sent_index=sent_index_first;
  495. sent_index < sent_index_last; ++sent_index) {
  496. if (sizes[sent_index] > LONG_SENTENCE_LEN){
  497. if ((epoch == 0) && (!second)) {
  498. ++long_sent_docs;
  499. }
  500. contains_long_sentence = true;
  501. break;
  502. }
  503. }
  504. }
  505. // If we have enough sentences and no long sentences.
  506. if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
  507. // Set values.
  508. auto seq_len = int32_t{0};
  509. auto num_sent = int32_t{0};
  510. // Loop through sentences.
  511. for (auto sent_index=sent_index_first;
  512. sent_index < sent_index_last; ++sent_index) {
  513. // Add the size and number of sentences.
  514. seq_len += sizes[sent_index];
  515. ++num_sent;
  516. --num_remain_sent;
  517. // If we have reached the target length.
  518. // and there are an acceptable number of sentences left
  519. // and if we have at least the minimum number of sentences.
  520. // or if we have reached end of the document.
  521. if (((seq_len >= target_seq_len) &&
  522. (num_remain_sent >= min_num_sent) &&
  523. (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
  524. // Populate the map.
  525. if (second) {
  526. const auto map_index_0 = 4 * map_index;
  527. // Each sample has 4 items: the starting sentence index, ending sentence index,
  528. // the index of the document from which the block comes (used for fetching titles)
  529. // and the unique id of the block (used for creating block indexes)
  530. maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
  531. maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
  532. maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
  533. maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
  534. }
  535. // Update indices / counters.
  536. ++map_index;
  537. ++block_id;
  538. prev_start_index = sent_index + 1;
  539. seq_len = 0;
  540. num_sent = 0;
  541. }
  542. } // for (auto sent_index=sent_index_first; ...
  543. } // if (num_remain_sent > 1) {
  544. } // for (int doc=0; doc < num_docs; ++doc) {
  545. } // for (int epoch=0; epoch < num_epochs; ++epoch) {
  546. if (!second) {
  547. if (verbose) {
  548. cout << " number of empty documents: " << empty_docs <<
  549. endl << std::flush;
  550. cout << " number of documents with one sentence: " <<
  551. one_sent_docs << endl << std::flush;
  552. cout << " number of documents with long sentences: " <<
  553. long_sent_docs << endl << std::flush;
  554. cout << " will create mapping for " << map_index <<
  555. " samples" << endl << std::flush;
  556. }
  557. assert(maps == NULL);
  558. assert(num_samples < 0);
  559. maps = new DocIdx[4*map_index];
  560. num_samples = static_cast<int64_t>(map_index);
  561. }
  562. } // for (int iteration=0; iteration < 2; ++iteration) {
  563. // Shuffle.
  564. // We need a 64 bit random number generator as we might have more
  565. // than 2 billion samples.
  566. std::mt19937_64 rand64_gen(seed + 1);
  567. for (auto i=(num_samples - 1); i > 0; --i) {
  568. const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
  569. const auto i0 = 4 * i;
  570. const auto j0 = 4 * j;
  571. // Swap values.
  572. swap(maps[i0], maps[j0]);
  573. swap(maps[i0 + 1], maps[j0 + 1]);
  574. swap(maps[i0 + 2], maps[j0 + 2]);
  575. swap(maps[i0 + 3], maps[j0 + 3]);
  576. }
  577. // Method to deallocate memory.
  578. py::capsule free_when_done(maps, [](void *mem_) {
  579. DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
  580. delete[] mem;
  581. });
  582. // Return the numpy array.
  583. const auto byte_size = sizeof(DocIdx);
  584. return py::array(std::vector<int64_t>{num_samples, 4}, // shape
  585. {4*byte_size, byte_size}, // C-style contiguous strides
  586. maps, // the data pointer
  587. free_when_done); // numpy array references
  588. }
  589. py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
  590. const py::array_t<int>& sizes_,
  591. const py::array_t<int>& titles_sizes_,
  592. const int num_epochs,
  593. const uint64_t max_num_samples,
  594. const int max_seq_length,
  595. const int seed,
  596. const bool verbose,
  597. const bool use_one_sent_blocks) {
  598. if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
  599. if (verbose) {
  600. cout << " using uint64 for data mapping..." << endl << std::flush;
  601. }
  602. return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
  603. num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
  604. } else {
  605. if (verbose) {
  606. cout << " using uint32 for data mapping..." << endl << std::flush;
  607. }
  608. return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
  609. num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
  610. }
  611. }
  612. PYBIND11_MODULE(helpers, m) {
  613. m.def("build_mapping", &build_mapping);
  614. m.def("build_blocks_mapping", &build_blocks_mapping);
  615. m.def("build_sample_idx", &build_sample_idx);
  616. m.def("build_blending_indices", &build_blending_indices);
  617. }