diff --git a/src/common/parallel_iterator.hh b/src/common/parallel_iterator.hh
new file mode 100644
index 0000000000000000000000000000000000000000..fca066ee6e9b4e7cdc18da909124a9b3cde51b29
--- /dev/null
+++ b/src/common/parallel_iterator.hh
@@ -0,0 +1,85 @@
+#pragma once
+
+#include <iterator> // for (only) std::random_access_iterator_tag
+
+#include <clean-core/span.hh>
+
+namespace asr
+{
+/// like std::swap
+template <class T>
+constexpr void real_swap(T& a, T& b) // no ADL-jacking please
+{
+    T tmp = static_cast<T&&>(a);
+    a = static_cast<T&&>(b);
+    b = static_cast<T&&>(tmp);
+}
+
+/// STL-compatible iterator with the main intention of sorting two parallel arrays (same length)
+/// while only comparing values of the first one
+/// (application here: sorting vocabulary hash array [for faster searches], and keeping the string array in sync)
+template <class T1, class T2>
+struct parallel_iterator
+{
+    // these are required for std::iterator_traits to work
+    using difference_type = std::ptrdiff_t; // value received when subtracting two of these
+    using value_type = T1;                  // value that can be "dereferenced" from these
+    using pointer = T1*;
+    using reference = T1&;
+    using iterator_category = std::random_access_iterator_tag;
+
+    parallel_iterator& operator++() noexcept // pre-increment
+    {
+        ++_data1;
+        ++_data2;
+        return *this;
+    }
+
+    parallel_iterator operator++(int) noexcept // post-increment
+    {
+        parallel_iterator res = *this;
+        ++(*this); // calls the pre-increment op above
+        return res;
+    }
+
+    T1& operator*() const noexcept { return *_data1; }
+
+    bool operator==(parallel_iterator const& rhs) const noexcept { return _data1 == rhs._data1; }
+    bool operator!=(parallel_iterator const& rhs) const noexcept { return _data1 != rhs._data1; }
+
+    inline friend difference_type operator-(parallel_iterator const& lhs, parallel_iterator const& rhs) noexcept { return lhs._data1 - rhs._data1; }
+
+
+    void swap_from(parallel_iterator& other) noexcept
+    {
+        real_swap(_data1[0], other._data1[0]);
+        real_swap(_data2[0], other._data2[0]);
+    }
+
+    explicit constexpr parallel_iterator(T1* begin1, T2* begin2) : _data1(begin1), _data2(begin2) {}
+
+private:
+    T1* _data1;
+    T2* _data2;
+};
+
+template <class T1, class T2>
+parallel_iterator<T1, T2> it_parallel_begin(cc::span<T1> arr1, cc::span<T2> arr2)
+{
+    CC_ASSERT(arr1.size() == arr2.size());
+    return parallel_iterator(arr1.begin(), arr2.begin());
+}
+
+template <class T1, class T2>
+parallel_iterator<T1, T2> it_parallel_end(cc::span<T1> arr1, cc::span<T2> arr2)
+{
+    CC_ASSERT(arr1.size() == arr2.size());
+    return parallel_iterator(arr1.end(), arr2.end());
+}
+
+template <class T1, class T2>
+void swap(parallel_iterator<T1, T2>& lhs, parallel_iterator<T1, T2>& rhs) noexcept
+{
+    lhs.swap_from(rhs);
+}
+}
diff --git a/src/main.cc b/src/main.cc
index 3e412282dcee945fcafe820a54dfdd6123824b59..0f5dce1ab5bc6039ff8d8460aa9ac6993a40a724 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -15,6 +15,7 @@
 
 #include "common/file_util.hh"
 #include "common/hashing.hh"
+#include "common/parallel_iterator.hh"
 #include "container/flat_map.hh"
 
 #define ASR_COUNTOF(_arr_) (sizeof(_arr_) / sizeof(_arr_[0]))
@@ -29,7 +30,7 @@
 
 #define ASR_VOCABULARY_FILE "data/sheet1/data_lm_vocabulary.txt"
 
-#define ASR_SKIP_TASK1 0
+#define ASR_SKIP_TASK1 1
 #define ASR_SKIP_SLOW_NAIVE_TASK3 1
 
 namespace
@@ -54,6 +55,34 @@
     std::unordered_map<std::string, token_info> token_occurence; // key: token, value: "info" - occurence
 };
 
+// turns data = {a, b, c, d} and indices = {3, 2, 0, 1}
+// into {d, c, a, b}
+template <class T>
+void permute_by_indices(cc::span<T> inout_data, cc::span<unsigned const> indices, cc::allocator* scratch_alloc)
+{
+    CC_ASSERT(inout_data.size() == indices.size() && "indices do not match data length");
+
+    cc::alloc_array<T> temp_copy;
+    if constexpr (std::is_trivially_copyable_v<T>)
+    {
+        temp_copy = cc::alloc_array<T>::uninitialized(inout_data.size(), scratch_alloc); // element type must be T, not uint64_t
+        std::memcpy(temp_copy.data(), inout_data.data(), inout_data.size_bytes());
+    }
+    else
+    {
+        temp_copy = cc::alloc_array<T>::defaulted(inout_data.size(), scratch_alloc);
+        for (auto i = 0u; i < temp_copy.size(); ++i)
+        {
+            temp_copy[i] = cc::move(inout_data[i]);
+        }
+    }
+
+    for (auto i = 0u; i < inout_data.size(); ++i)
+    {
+        inout_data[i] = cc::move(temp_copy[indices[i]]);
+    }
+}
+
 struct token_list
 {
     cc::alloc_vector<std::string> ordered_tokens; // ordered tokens, <s> and </s> injected (slow)
@@ -76,6 +105,28 @@
         ordered_tokens.push_back(std::string(token));
         ordered_token_hashes.push_back(hash);
     }
+
+    // sorts hashes to be incrementing over the array, while keeping the string array correctly parallel
+    void sort_parallel(cc::allocator* scratch_alloc = cc::system_allocator)
+    {
+        auto indices_array = cc::alloc_array<unsigned>::uninitialized(ordered_token_hashes.size(), scratch_alloc);
+        for (auto i = 0u; i < indices_array.size(); ++i)
+        {
+            indices_array[i] = i;
+        }
+
+        std::sort(indices_array.begin(), indices_array.end(),
+                  [&](unsigned i_lhs, unsigned i_rhs) { return ordered_token_hashes[i_lhs] < ordered_token_hashes[i_rhs]; });
+
+        permute_by_indices<uint64_t>(ordered_token_hashes, indices_array, scratch_alloc);
+        permute_by_indices<std::string>(ordered_tokens, indices_array, scratch_alloc);
+
+        // for (auto i = 0u; i < ordered_tokens.size(); ++i)
+        // {
+        //     printf("sorted #%u - %zu - %s\n", i, ordered_token_hashes[i], ordered_tokens[i].c_str());
+        // }
+        // fflush(stdout);
+    }
 };
 
 struct corpus_line_indices
@@ -89,6 +140,7 @@
     cc::alloc_vector<line_info> line_infos;
 };
 
+
 // takes a corpus string, writes statistics and <s>/</s>-padded token list
 void compute_corpus_statistics(char const* corpus, corpus_statistics& out_stats, token_list& out_list)
 {
@@ -130,13 +182,19 @@
 float tokenize_corpus(char const* corpus, cc::span<uint64_t const> vocab_hashes, token_list& out_list, corpus_line_indices& out_line_infos)
 {
+#define ASR_USE_VOCAB_BINSEARCH 1
+
     auto f_is_known_vocab = [&](uint64_t token_hash) -> bool {
+#if ASR_USE_VOCAB_BINSEARCH
+        return std::binary_search(vocab_hashes.begin(), vocab_hashes.end(), token_hash, [](uint64_t lhs, uint64_t rhs) { return lhs < rhs; }); // requires sorted vocab_hashes (see sort_parallel)
+#else
         for (auto const hash : vocab_hashes)
         {
             if (hash == token_hash)
                 return true;
         }
         return false;
+#endif
     };
 
     std::istringstream ss(corpus);
 
@@ -151,7 +209,7 @@
     // iterate over lines
     while (std::getline(ss, line))
     {
-        if ((out_line_infos.line_infos.size() & 0b1111111111111) == 0)
+        if ((out_line_infos.line_infos.size() & 0b1111111111111111) == 0)
         {
             printf("[tokenize_corpus] line %zu of corpus\n", out_line_infos.line_infos.size());
             fflush(stdout);
@@ -196,6 +254,8 @@
 
     while (ss >> token)
         out_list.push_token(token.c_str());
+
+    out_list.sort_parallel();
 }
 
 using token_with_info_t = std::pair<std::string, token_info>;
@@ -339,7 +399,7 @@
     auto const loop_end = token_hashes.size() - (n - 1);
     for (auto i = 0u; i < loop_end; ++i)
    {
-        if ((i & 0xFFFFF) == 0)
+        if ((i & 0x3FFFFF) == 0)
         {
             printf("[compute_ngram_stats] %u-tuple %u of %zu\n", n, i, loop_end);
             fflush(stdout);