/*
 coding=utf-8
 Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/
/* Helper methods for fast index mapping builds */

#include <algorithm>
#include <cassert>
#include <iostream>
#include <limits>
#include <math.h>
#include <random>
#include <stdexcept>
#include <vector>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;
using namespace std;

const int32_t LONG_SENTENCE_LEN = 512;
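// Documents containing any sentence longer than this threshold are skipped
// by the mapping builders below. The rationale is not stated in the original
// source; a plausible assumption is that such sentences usually indicate
// faulty sentence segmentation.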
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
                            py::array_t<int64_t>& dataset_sample_index,
                            const py::array_t<double>& weights,
                            const int32_t num_datasets,
                            const int64_t size, const bool verbose) {
    /* Given multiple datasets and a weighting array, build samples
       such that they follow those weights. */

    if (verbose) {
        std::cout << "> building indices for blendable datasets ..." << std::endl;
    }

    // Get pointer access without the bounds checks.
    auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
    auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
    auto weights_ptr = weights.unchecked<1>();

    // Number of samples drawn so far from each dataset. A std::vector is
    // used instead of a variable-length array, which is not standard C++.
    std::vector<int64_t> current_samples(num_datasets, 0);
    // For each sample:
    for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {

        // Determine where the max error in sampling is happening.
        auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
        int64_t max_error_index = 0;
        double max_error = weights_ptr[0] * sample_idx_double -
            static_cast<double>(current_samples[0]);
        for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
            double error = weights_ptr[dataset_idx] * sample_idx_double -
                static_cast<double>(current_samples[dataset_idx]);
            if (error > max_error) {
                max_error = error;
                max_error_index = dataset_idx;
            }
        }

        // Populate the indices.
        dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
        dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

        // Update the total samples.
        current_samples[max_error_index] += 1;
    }
    // Print info.
    if (verbose) {
        std::cout << " > sample ratios:" << std::endl;
        for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
            auto ratio = static_cast<double>(current_samples[dataset_idx]) /
                static_cast<double>(size);
            std::cout << "   dataset " << dataset_idx << ", input: " <<
                weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
        }
    }
}
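// Worked example of the greedy error-minimization above (illustrative
// values, not from the original source): with weights = {0.7, 0.3}, the
// first eight samples are assigned to datasets 0, 1, 0, 0, 1, 0, 0, 1.
// After eight samples the achieved ratios are 5/8 = 0.625 and 3/8 = 0.375,
// and they converge to 0.7 / 0.3 as more samples are drawn.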
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                           const py::array_t<int32_t>& doc_idx_,
                           const int32_t seq_length,
                           const int32_t num_epochs,
                           const int64_t tokens_per_epoch) {
    /* The sample index (sample_idx) is used for GPT-2-style datasets, in
       which the documents are flattened into a single 1-D token array and
       samples are built from that array. It is a 2-D array of size
       [number-of-samples + 1, 2] where [..., 0] contains the index into
       `doc_idx` and [..., 1] is the starting offset in that document. */

    // Consistency checks.
    assert(seq_length > 1);
    assert(num_epochs > 0);
    assert(tokens_per_epoch > 1);

    // Remove bounds checks.
    auto sizes = sizes_.unchecked<1>();
    auto doc_idx = doc_idx_.unchecked<1>();

    // Mapping and its length (1D).
    int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
    int32_t* sample_idx = new int32_t[2 * (num_samples + 1)];
    cout << "    using:" << endl << std::flush;
    cout << "     number of documents:     " <<
        doc_idx_.shape(0) / num_epochs << endl << std::flush;
    cout << "     number of epochs:        " << num_epochs << endl << std::flush;
    cout << "     sequence length:         " << seq_length << endl << std::flush;
    cout << "     total number of samples: " << num_samples << endl << std::flush;

    // Index into sample_idx.
    int64_t sample_index = 0;
    // Index into doc_idx.
    int64_t doc_idx_index = 0;
    // Beginning offset for each document.
    int32_t doc_offset = 0;
    // Start with the first document and no offset.
    sample_idx[2 * sample_index] = doc_idx_index;
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;
    while (sample_index <= num_samples) {
        // Start with a fresh sequence.
        int32_t remaining_seq_length = seq_length + 1;
        while (remaining_seq_length != 0) {
            // Get the document length.
            auto doc_id = doc_idx[doc_idx_index];
            auto doc_length = sizes[doc_id] - doc_offset;
            // And add it to the current sequence.
            remaining_seq_length -= doc_length;
            // If we have more than a full sequence, adjust the offset and
            // set the remaining length to zero so we exit the while loop.
            // Note that -1 here is for the same reason we have -1 in the
            // `num_samples` calculation above.
            if (remaining_seq_length <= 0) {
                doc_offset += (remaining_seq_length + doc_length - 1);
                remaining_seq_length = 0;
            } else {
                // Otherwise, start from the beginning of the next document.
                ++doc_idx_index;
                doc_offset = 0;
            }
        }
        // Record the sequence.
        sample_idx[2 * sample_index] = doc_idx_index;
        sample_idx[2 * sample_index + 1] = doc_offset;
        ++sample_index;
    }
    // Method to deallocate memory.
    py::capsule free_when_done(sample_idx, [](void *mem_) {
        int32_t *mem = reinterpret_cast<int32_t*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(int32_t);
    return py::array(std::vector<int64_t>{num_samples + 1, 2},  // shape
                     {2 * byte_size, byte_size},  // C-style contiguous strides
                     sample_idx,  // the data pointer
                     free_when_done);  // capsule the numpy array references
}
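// Illustrative example of the sample_idx layout (values assumed for the
// sake of the example): with one epoch, sizes = [4, 6, 3], doc_idx =
// [0, 1, 2], seq_length = 5 and tokens_per_epoch = 13, we get
// num_samples = (13 - 1) / 5 = 2 and the rows [0, 0], [1, 1], [2, 0].
// Sample i spans seq_length + 1 = 6 tokens from row i to row i + 1
// inclusive; the extra token is used for the shifted targets in
// language-model training.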
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937& rand32_gen) {
    /* Training sample length. */
    if (short_seq_ratio == 0) {
        return max_length;
    }
    const auto random_number = rand32_gen();
    if ((random_number % short_seq_ratio) == 0) {
        return 2 + random_number % (max_length - 1);
    }
    return max_length;
}
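// Note on the sampling above: a short length is chosen with probability
// roughly 1 / short_seq_ratio, and in that case the returned length lies
// in [2, max_length]. Since the same draw decides both the branch and the
// length, the two are correlated rather than independent; this mirrors the
// original behavior, whose purpose is simply to occasionally produce
// shorter training samples.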
template<typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                             const py::array_t<int32_t>& sizes_,
                             const int32_t num_epochs,
                             const uint64_t max_num_samples,
                             const int32_t max_seq_length,
                             const double short_seq_prob,
                             const int32_t seed,
                             const bool verbose,
                             const int32_t min_num_sent) {
    /* Build a mapping of (start-index, end-index, sequence-length) where
       start and end index are the indices of the sentences in the sample
       and sequence-length is the target sequence length. */

    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
    assert(short_seq_prob >= 0.0);
    assert(short_seq_prob <= 1.0);
    assert(seed > 0);

    // Remove bounds checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();

    // For efficiency, convert the probability to an integer ratio so the
    // short-sequence decision can be made with a modulo operation.
    int32_t short_seq_ratio = 0;
    if (short_seq_prob > 0) {
        short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
    }
    if (verbose) {
        const auto sent_start_index = docs[0];
        const auto sent_end_index = docs[docs_.shape(0) - 1];
        const auto num_sentences = sent_end_index - sent_start_index;
        cout << "    using:" << endl << std::flush;
        cout << "     number of documents:           " << docs_.shape(0) - 1 <<
            endl << std::flush;
        cout << "     sentences range:               [" << sent_start_index <<
            ", " << sent_end_index << ")" << endl << std::flush;
        cout << "     total number of sentences:     " << num_sentences <<
            endl << std::flush;
        cout << "     number of epochs:              " << num_epochs <<
            endl << std::flush;
        cout << "     maximum number of samples:     " << max_num_samples <<
            endl << std::flush;
        cout << "     maximum sequence length:       " << max_seq_length <<
            endl << std::flush;
        cout << "     short sequence probability:    " << short_seq_prob <<
            endl << std::flush;
        cout << "     short sequence ratio (1/prob): " << short_seq_ratio <<
            endl << std::flush;
        cout << "     seed:                          " << seed << endl <<
            std::flush;
    }
    // Mapping and its length (1D).
    int64_t num_samples = -1;
    DocIdx* maps = NULL;

    // Perform two iterations: in the first, compute the size and allocate
    // memory; in the second, populate the map.
    bool second = false;
    for (int32_t iteration = 0; iteration < 2; ++iteration) {

        // Set the seed so both iterations produce the same results.
        std::mt19937 rand32_gen(seed);

        // Set the flag on the second iteration.
        second = (iteration == 1);

        // Counters:
        uint64_t empty_docs = 0;
        uint64_t one_sent_docs = 0;
        uint64_t long_sent_docs = 0;

        // Current map index.
        uint64_t map_index = 0;

        // For each epoch:
        for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
            if (map_index >= max_num_samples) {
                if (verbose && (!second)) {
                    cout << "    reached " << max_num_samples << " samples after "
                         << epoch << " epochs ..." << endl << std::flush;
                }
                break;
            }
            // For each document:
            for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) {

                // Document sentences are in [sent_index_first, sent_index_last).
                const auto sent_index_first = docs[doc];
                const auto sent_index_last = docs[doc + 1];

                // At the beginning of the document, the previous index is
                // the start index.
                auto prev_start_index = sent_index_first;

                // Remaining sentences in the document.
                auto num_remain_sent = sent_index_last - sent_index_first;

                // Some bookkeeping on the first pass.
                if ((epoch == 0) && (!second)) {
                    if (num_remain_sent == 0) {
                        ++empty_docs;
                    }
                    if (num_remain_sent == 1) {
                        ++one_sent_docs;
                    }
                }

                // Detect documents with long sentences.
                bool contains_long_sentence = false;
                if (num_remain_sent > 1) {
                    for (auto sent_index = sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {
                        if (sizes[sent_index] > LONG_SENTENCE_LEN) {
                            if ((epoch == 0) && (!second)) {
                                ++long_sent_docs;
                            }
                            contains_long_sentence = true;
                            break;
                        }
                    }
                }
                // If we have enough sentences and no long sentences:
                if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {

                    // Set values.
                    auto seq_len = int32_t{0};
                    auto num_sent = int32_t{0};
                    auto target_seq_len = get_target_sample_len(short_seq_ratio,
                                                                max_seq_length,
                                                                rand32_gen);

                    // Loop through the sentences.
                    for (auto sent_index = sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {

                        // Add the size and count the sentence.
                        seq_len += sizes[sent_index];
                        ++num_sent;
                        --num_remain_sent;

                        // Cut a sample here if we have reached the target
                        // length, more than one sentence remains in the
                        // document, and we have at least min_num_sent
                        // sentences; or if we have reached the end of the
                        // document.
                        if (((seq_len >= target_seq_len) &&
                             (num_remain_sent > 1) &&
                             (num_sent >= min_num_sent)) || (num_remain_sent == 0)) {

                            // Check for overflow.
                            if ((3 * map_index + 2) >
                                std::numeric_limits<int64_t>::max()) {
                                cout << "number of samples exceeded maximum "
                                     << "allowed by type int64: "
                                     << std::numeric_limits<int64_t>::max()
                                     << endl;
                                throw std::overflow_error("Number of samples");
                            }

                            // Populate the map.
                            if (second) {
                                const auto map_index_0 = 3 * map_index;
                                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                                maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
                            }

                            // Update indices / counters.
                            ++map_index;
                            prev_start_index = sent_index + 1;
                            target_seq_len = get_target_sample_len(short_seq_ratio,
                                                                   max_seq_length,
                                                                   rand32_gen);
                            seq_len = 0;
                            num_sent = 0;
                        }
                    } // for (auto sent_index = sent_index_first; ...)
                } // if ((num_remain_sent >= min_num_sent) && ...)
            } // for (int32_t doc = 0; ...)
        } // for (int32_t epoch = 0; ...)
        if (!second) {
            if (verbose) {
                cout << "   number of empty documents: " << empty_docs <<
                    endl << std::flush;
                cout << "   number of documents with one sentence: " <<
                    one_sent_docs << endl << std::flush;
                cout << "   number of documents with long sentences: " <<
                    long_sent_docs << endl << std::flush;
                cout << "   will create mapping for " << map_index <<
                    " samples" << endl << std::flush;
            }
            assert(maps == NULL);
            assert(num_samples < 0);
            maps = new DocIdx[3 * map_index];
            num_samples = static_cast<int64_t>(map_index);
        }
    } // for (int32_t iteration = 0; iteration < 2; ++iteration)
    // Shuffle.
    // We need a 64-bit random number generator because we might have more
    // than 2 billion samples.
    std::mt19937_64 rand64_gen(seed + 1);
    for (auto i = (num_samples - 1); i > 0; --i) {
        const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
        const auto i0 = 3 * i;
        const auto j0 = 3 * j;
        // Swap values.
        swap(maps[i0], maps[j0]);
        swap(maps[i0 + 1], maps[j0 + 1]);
        swap(maps[i0 + 2], maps[j0 + 2]);
    }
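    // Note: the loop above is an in-place Fisher-Yates shuffle over the rows
    // of the mapping. Using `% (i + 1)` introduces a small modulo bias, but
    // with a 64-bit generator it is negligible for any realistic sample count.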
    // Method to deallocate memory.
    py::capsule free_when_done(maps, [](void *mem_) {
        DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(DocIdx);
    return py::array(std::vector<int64_t>{num_samples, 3},  // shape
                     {3 * byte_size, byte_size},  // C-style contiguous strides
                     maps,  // the data pointer
                     free_when_done);  // capsule the numpy array references
}
py::array build_mapping(const py::array_t<int64_t>& docs_,
                        const py::array_t<int>& sizes_,
                        const int num_epochs,
                        const uint64_t max_num_samples,
                        const int max_seq_length,
                        const double short_seq_prob,
                        const int seed,
                        const bool verbose,
                        const int32_t min_num_sent) {
    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        if (verbose) {
            cout << "    using uint64 for data mapping..." << endl << std::flush;
        }
        return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
                                            max_num_samples, max_seq_length,
                                            short_seq_prob, seed, verbose,
                                            min_num_sent);
    } else {
        if (verbose) {
            cout << "    using uint32 for data mapping..." << endl << std::flush;
        }
        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
                                            max_num_samples, max_seq_length,
                                            short_seq_prob, seed, verbose,
                                            min_num_sent);
    }
}
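// Illustrative Python-side usage (a sketch; assumes the module has been
// built as `helpers` and that the arrays describe two documents containing
// five sentences in total):
//
//   import numpy as np
//   import helpers
//
//   docs = np.array([0, 2, 5], dtype=np.int64)              # document boundaries
//   sizes = np.array([30, 40, 20, 50, 10], dtype=np.int32)  # sentence lengths
//   mapping = helpers.build_mapping(docs, sizes,
//                                   3,      # num_epochs
//                                   100,    # max_num_samples
//                                   128,    # max_seq_length
//                                   0.1,    # short_seq_prob
//                                   1234,   # seed
//                                   True,   # verbose
//                                   2)      # min_num_sent
//   # mapping has shape [num_samples, 3]: (start sentence index,
//   # end sentence index (exclusive), target sequence length)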
template<typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                                    const py::array_t<int32_t>& sizes_,
                                    const py::array_t<int32_t>& titles_sizes_,
                                    const int32_t num_epochs,
                                    const uint64_t max_num_samples,
                                    const int32_t max_seq_length,
                                    const int32_t seed,
                                    const bool verbose,
                                    const bool use_one_sent_blocks) {
    /* Build a mapping of (start-index, end-index, document-index, block-id)
       where start and end index are the indices of the sentences in the
       sample, document-index is the document from which the block comes
       (used for fetching titles), and block-id is the id of the block
       (used for creating block indexes). */

    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
    assert(seed > 0);

    // Remove bounds checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();
    auto titles_sizes = titles_sizes_.unchecked<1>();
    if (verbose) {
        const auto sent_start_index = docs[0];
        const auto sent_end_index = docs[docs_.shape(0) - 1];
        const auto num_sentences = sent_end_index - sent_start_index;
        cout << "    using:" << endl << std::flush;
        cout << "     number of documents:       " << docs_.shape(0) - 1 <<
            endl << std::flush;
        cout << "     sentences range:           [" << sent_start_index <<
            ", " << sent_end_index << ")" << endl << std::flush;
        cout << "     total number of sentences: " << num_sentences <<
            endl << std::flush;
        cout << "     number of epochs:          " << num_epochs <<
            endl << std::flush;
        cout << "     maximum number of samples: " << max_num_samples <<
            endl << std::flush;
        cout << "     maximum sequence length:   " << max_seq_length <<
            endl << std::flush;
        cout << "     seed:                      " << seed << endl <<
            std::flush;
    }
    // Mapping and its length (1D).
    int64_t num_samples = -1;
    DocIdx* maps = NULL;

    // Acceptable number of sentences per block.
    int min_num_sent = 2;
    if (use_one_sent_blocks) {
        min_num_sent = 1;
    }

    // Perform two iterations: in the first, compute the size and allocate
    // memory; in the second, populate the map.
    bool second = false;
    for (int32_t iteration = 0; iteration < 2; ++iteration) {

        // Set the flag on the second iteration.
        second = (iteration == 1);

        // Current map index.
        uint64_t map_index = 0;

        // Counters:
        uint64_t empty_docs = 0;
        uint64_t one_sent_docs = 0;
        uint64_t long_sent_docs = 0;

        // For each epoch:
        for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
            // Assign each block an id that is unique within the epoch.
            int32_t block_id = 0;
            if (map_index >= max_num_samples) {
                if (verbose && (!second)) {
                    cout << "    reached " << max_num_samples << " samples after "
                         << epoch << " epochs ..." << endl << std::flush;
                }
                break;
            }
            // For each document:
            for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) {

                // Document sentences are in [sent_index_first, sent_index_last).
                const auto sent_index_first = docs[doc];
                const auto sent_index_last = docs[doc + 1];

                // The document title consumes part of the sequence budget.
                const auto target_seq_len = max_seq_length - titles_sizes[doc];

                // At the beginning of the document, the previous index is
                // the start index.
                auto prev_start_index = sent_index_first;

                // Remaining sentences in the document.
                auto num_remain_sent = sent_index_last - sent_index_first;

                // Some bookkeeping on the first pass.
                if ((epoch == 0) && (!second)) {
                    if (num_remain_sent == 0) {
                        ++empty_docs;
                    }
                    if (num_remain_sent == 1) {
                        ++one_sent_docs;
                    }
                }

                // Detect documents with long sentences.
                bool contains_long_sentence = false;
                if (num_remain_sent >= min_num_sent) {
                    for (auto sent_index = sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {
                        if (sizes[sent_index] > LONG_SENTENCE_LEN) {
                            if ((epoch == 0) && (!second)) {
                                ++long_sent_docs;
                            }
                            contains_long_sentence = true;
                            break;
                        }
                    }
                }
                // If we have enough sentences and no long sentences:
                if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {

                    // Set values.
                    auto seq_len = int32_t{0};
                    auto num_sent = int32_t{0};

                    // Loop through the sentences.
                    for (auto sent_index = sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {

                        // Add the size and count the sentence.
                        seq_len += sizes[sent_index];
                        ++num_sent;
                        --num_remain_sent;

                        // Cut a block here if we have reached the target
                        // length, an acceptable number of sentences remains,
                        // and we have at least the minimum number of
                        // sentences; or if we have reached the end of the
                        // document.
                        if (((seq_len >= target_seq_len) &&
                             (num_remain_sent >= min_num_sent) &&
                             (num_sent >= min_num_sent)) || (num_remain_sent == 0)) {

                            // Populate the map.
                            if (second) {
                                const auto map_index_0 = 4 * map_index;
                                // Each sample has 4 items: the starting and
                                // ending sentence indices, the index of the
                                // document from which the block comes (used
                                // for fetching titles), and the id of the
                                // block (used for creating block indexes).
                                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                                maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
                                maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
                            }

                            // Update indices / counters.
                            ++map_index;
                            ++block_id;
                            prev_start_index = sent_index + 1;
                            seq_len = 0;
                            num_sent = 0;
                        }
                    } // for (auto sent_index = sent_index_first; ...)
                } // if ((num_remain_sent >= min_num_sent) && ...)
            } // for (int32_t doc = 0; ...)
        } // for (int32_t epoch = 0; ...)
        if (!second) {
            if (verbose) {
                cout << "   number of empty documents: " << empty_docs <<
                    endl << std::flush;
                cout << "   number of documents with one sentence: " <<
                    one_sent_docs << endl << std::flush;
                cout << "   number of documents with long sentences: " <<
                    long_sent_docs << endl << std::flush;
                cout << "   will create mapping for " << map_index <<
                    " samples" << endl << std::flush;
            }
            assert(maps == NULL);
            assert(num_samples < 0);
            maps = new DocIdx[4 * map_index];
            num_samples = static_cast<int64_t>(map_index);
        }
    } // for (int32_t iteration = 0; iteration < 2; ++iteration)
    // Shuffle.
    // We need a 64-bit random number generator because we might have more
    // than 2 billion samples.
    std::mt19937_64 rand64_gen(seed + 1);
    for (auto i = (num_samples - 1); i > 0; --i) {
        const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
        const auto i0 = 4 * i;
        const auto j0 = 4 * j;
        // Swap values.
        swap(maps[i0], maps[j0]);
        swap(maps[i0 + 1], maps[j0 + 1]);
        swap(maps[i0 + 2], maps[j0 + 2]);
        swap(maps[i0 + 3], maps[j0 + 3]);
    }

    // Method to deallocate memory.
    py::capsule free_when_done(maps, [](void *mem_) {
        DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(DocIdx);
    return py::array(std::vector<int64_t>{num_samples, 4},  // shape
                     {4 * byte_size, byte_size},  // C-style contiguous strides
                     maps,  // the data pointer
                     free_when_done);  // capsule the numpy array references
}
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
                               const py::array_t<int>& sizes_,
                               const py::array_t<int>& titles_sizes_,
                               const int num_epochs,
                               const uint64_t max_num_samples,
                               const int max_seq_length,
                               const int seed,
                               const bool verbose,
                               const bool use_one_sent_blocks) {
    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        if (verbose) {
            cout << "    using uint64 for data mapping..." << endl << std::flush;
        }
        return build_blocks_mapping_impl<uint64_t>(
            docs_, sizes_, titles_sizes_, num_epochs, max_num_samples,
            max_seq_length, seed, verbose, use_one_sent_blocks);
    } else {
        if (verbose) {
            cout << "    using uint32 for data mapping..." << endl << std::flush;
        }
        return build_blocks_mapping_impl<uint32_t>(
            docs_, sizes_, titles_sizes_, num_epochs, max_num_samples,
            max_seq_length, seed, verbose, use_one_sent_blocks);
    }
}
PYBIND11_MODULE(helpers, m) {
    m.def("build_mapping", &build_mapping);
    m.def("build_blocks_mapping", &build_blocks_mapping);
    m.def("build_sample_idx", &build_sample_idx);
    m.def("build_blending_indices", &build_blending_indices);
}
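// One possible way to compile this extension (a sketch, assuming pybind11
// is installed in the active Python environment; this is the standard
// pybind11 invocation, not a build recipe taken from this repository):
//
//   c++ -O3 -Wall -shared -std=c++11 -fPIC \
//       $(python3 -m pybind11 --includes) helpers.cpp \
//       -o helpers$(python3-config --extension-suffix)
//
// After building, the functions are importable from Python via
// `import helpers`.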