Skip to content
Snippets Groups Projects
Unverified Commit bb40f551 authored by Jonathan Kunstwald's avatar Jonathan Kunstwald
Browse files

Sort vocabulary and use binary over linear search

parent 7d6d70ac
No related branches found
No related tags found
No related merge requests found
#pragma once
#include <iterator> // for (only) std::random_access_iterator_tag
#include <clean-core/span.hh>
namespace asr
{
/// minimal std::swap equivalent that is deliberately never found via ADL
template <class T>
constexpr void real_swap(T& lhs, T& rhs)
{
    T held = static_cast<T&&>(lhs); // static_cast<T&&> == std::move without <utility>
    lhs = static_cast<T&&>(rhs);
    rhs = static_cast<T&&>(held);
}
/// STL-compatible iterator with the main intention of sorting two parallel arrays (same length)
/// while only comparing values of the first one
/// (application here: sorting vocabulary hash array [for faster searches], and keeping the string array in sync)
template <class T1, class T2>
struct parallel_iterator
{
// these are required for std::iterator_traits to work
using difference_type = std::ptrdiff_t; // value received when subtracting two of these
using value_type = T1; // value that can be "dereferenced" from these
using pointer = T1*;
using reference = T1&;
using iterator_category = std::random_access_iterator_tag;
parallel_iterator& operator++() noexcept // pre-increment
{
++_data1;
++_data2;
return *this;
}
parallel_iterator operator++(int) noexcept // post-increment
{
parallel_iterator res = *this;
++(*this); // calls the pre-increment op above
return res;
}
T1& operator*() const noexcept { return *_data1; }
bool operator==(parallel_iterator const& rhs) const noexcept { return _data1 == rhs._data1; }
bool operator!=(parallel_iterator const& rhs) const noexcept { return _data1 != rhs._data1; }
inline friend difference_type operator-(parallel_iterator const& lhs, const parallel_iterator const& rhs) { return lhs._data1 - rhs._data1; }
void swap_from(parallel_iterator& other) noexcept
{
real_swap(_data1[0], other._data1[0]);
real_swap(_data2[0], other._data2[0]);
}
explicit constexpr parallel_iterator(T1* begin1, T2* begin2) : _data1(begin1), _data2(begin2) {}
private:
T1* _data1;
T2* _data2;
};
/// creates a parallel iterator at the first elements of two equally sized spans
template <class T1, class T2>
parallel_iterator<T1, T2> it_parallel_begin(cc::span<T1> arr1, cc::span<T2> arr2)
{
    CC_ASSERT(arr1.size() == arr2.size());
    // span::begin() is a raw pointer, so data() is equivalent here
    return parallel_iterator<T1, T2>(arr1.data(), arr2.data());
}
/// creates a parallel iterator one past the last elements of two equally sized spans
template <class T1, class T2>
parallel_iterator<T1, T2> it_parallel_end(cc::span<T1> arr1, cc::span<T2> arr2)
{
    CC_ASSERT(arr1.size() == arr2.size());
    // span::end() is data() + size(), spelled out explicitly
    return parallel_iterator<T1, T2>(arr1.data() + arr1.size(), arr2.data() + arr2.size());
}
/// ADL-found swap so std:: algorithms exchange the element pair in BOTH parallel arrays
template <class T1, class T2>
void swap(parallel_iterator<T1, T2>& lhs, parallel_iterator<T1, T2>& rhs) noexcept
{
    // swap_from is symmetric, so the call direction does not matter
    rhs.swap_from(lhs);
}
}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "common/file_util.hh" #include "common/file_util.hh"
#include "common/hashing.hh" #include "common/hashing.hh"
#include "common/parallel_iterator.hh"
#include "container/flat_map.hh" #include "container/flat_map.hh"
#define ASR_COUNTOF(_arr_) (sizeof(_arr_) / sizeof(_arr_[0])) #define ASR_COUNTOF(_arr_) (sizeof(_arr_) / sizeof(_arr_[0]))
...@@ -29,7 +30,7 @@ ...@@ -29,7 +30,7 @@
#define ASR_VOCABULARY_FILE "data/sheet1/data_lm_vocabulary.txt" #define ASR_VOCABULARY_FILE "data/sheet1/data_lm_vocabulary.txt"
#define ASR_SKIP_TASK1 0 #define ASR_SKIP_TASK1 1
#define ASR_SKIP_SLOW_NAIVE_TASK3 1 #define ASR_SKIP_SLOW_NAIVE_TASK3 1
namespace namespace
...@@ -54,6 +55,34 @@ struct corpus_statistics ...@@ -54,6 +55,34 @@ struct corpus_statistics
std::unordered_map<std::string, token_info> token_occurence; // key: token, value: "info" - occurence std::unordered_map<std::string, token_info> token_occurence; // key: token, value: "info" - occurence
}; };
// turns data = {a, b, c, d} and indices = {3, 2, 0, 1}
// into {d, c, a, b}
// scratch_alloc provides the temporary O(n) copy required by the out-of-place permutation
template <class T>
void permute_by_indices(cc::span<T> inout_data, cc::span<unsigned const> indices, cc::allocator* scratch_alloc)
{
    CC_ASSERT(inout_data.size() == indices.size() && "indices do not match data length");

    // snapshot the current contents, then write them back in permuted order
    cc::alloc_array<T> temp_copy;
    if constexpr (std::is_trivially_copyable_v<T>)
    {
        // BUGFIX: was cc::alloc_array<uint64_t>::uninitialized(...) - only compiled for T == uint64_t
        temp_copy = cc::alloc_array<T>::uninitialized(inout_data.size(), scratch_alloc);
        std::memcpy(temp_copy.data(), inout_data.data(), inout_data.size_bytes());
    }
    else
    {
        // non-trivial types: default-construct, then move element by element
        temp_copy = cc::alloc_array<T>::defaulted(inout_data.size(), scratch_alloc);
        for (auto i = 0u; i < temp_copy.size(); ++i)
        {
            temp_copy[i] = cc::move(inout_data[i]);
        }
    }

    for (auto i = 0u; i < inout_data.size(); ++i)
    {
        inout_data[i] = cc::move(temp_copy[indices[i]]);
    }
}
struct token_list struct token_list
{ {
cc::alloc_vector<std::string> ordered_tokens; // ordered tokens, <s> and </s> injected (slow) cc::alloc_vector<std::string> ordered_tokens; // ordered tokens, <s> and </s> injected (slow)
...@@ -76,6 +105,28 @@ struct token_list ...@@ -76,6 +105,28 @@ struct token_list
ordered_tokens.push_back(std::string(token)); ordered_tokens.push_back(std::string(token));
ordered_token_hashes.push_back(hash); ordered_token_hashes.push_back(hash);
} }
// sorts hashes to be incrementing over the array, while keeping the string array correctly parallel
void sort_parallel(cc::allocator* scratch_alloc = cc::system_allocator)
{
    // start with the identity permutation 0..n-1
    auto perm = cc::alloc_array<unsigned>::uninitialized(ordered_token_hashes.size(), scratch_alloc);
    for (auto idx = 0u; idx < perm.size(); ++idx)
    {
        perm[idx] = idx;
    }

    // order the permutation by hash value instead of shuffling the payload arrays directly
    auto const hash_less = [this](unsigned lhs, unsigned rhs) { return ordered_token_hashes[lhs] < ordered_token_hashes[rhs]; };
    std::sort(perm.begin(), perm.end(), hash_less);

    // apply the identical permutation to both parallel arrays
    permute_by_indices<uint64_t>(ordered_token_hashes, perm, scratch_alloc);
    permute_by_indices<std::string>(ordered_tokens, perm, scratch_alloc);
}
}; };
struct corpus_line_indices struct corpus_line_indices
...@@ -89,6 +140,7 @@ struct corpus_line_indices ...@@ -89,6 +140,7 @@ struct corpus_line_indices
cc::alloc_vector<line_info> line_infos; cc::alloc_vector<line_info> line_infos;
}; };
// takes a corpus string, writes statistics and <s>/</s>-padded token list // takes a corpus string, writes statistics and <s>/</s>-padded token list
void compute_corpus_statistics(char const* corpus, corpus_statistics& out_stats, token_list& out_list) void compute_corpus_statistics(char const* corpus, corpus_statistics& out_stats, token_list& out_list)
{ {
...@@ -130,13 +182,19 @@ void compute_corpus_statistics(char const* corpus, corpus_statistics& out_stats, ...@@ -130,13 +182,19 @@ void compute_corpus_statistics(char const* corpus, corpus_statistics& out_stats,
float tokenize_corpus(char const* corpus, cc::span<uint64_t const> vocab_hashes, token_list& out_list, corpus_line_indices& out_line_infos) float tokenize_corpus(char const* corpus, cc::span<uint64_t const> vocab_hashes, token_list& out_list, corpus_line_indices& out_line_infos)
{ {
#define ASR_USE_VOCAB_BINSEARCH 1
auto f_is_known_vocab = [&](uint64_t token_hash) -> bool { auto f_is_known_vocab = [&](uint64_t token_hash) -> bool {
#if ASR_USE_VOCAB_BINSEARCH
return std::binary_search(vocab_hashes.begin(), vocab_hashes.end(), token_hash, [](uint64_t lhs, uint64_t rhs) { return lhs < rhs; });
#else
for (auto const hash : vocab_hashes) for (auto const hash : vocab_hashes)
{ {
if (hash == token_hash) if (hash == token_hash)
return true; return true;
} }
return false; return false;
#endif
}; };
std::istringstream ss(corpus); std::istringstream ss(corpus);
...@@ -151,7 +209,7 @@ float tokenize_corpus(char const* corpus, cc::span<uint64_t const> vocab_hashes, ...@@ -151,7 +209,7 @@ float tokenize_corpus(char const* corpus, cc::span<uint64_t const> vocab_hashes,
// iterate over lines // iterate over lines
while (std::getline(ss, line)) while (std::getline(ss, line))
{ {
if ((out_line_infos.line_infos.size() & 0b1111111111111) == 0) if ((out_line_infos.line_infos.size() & 0b1111111111111111) == 0)
{ {
printf("[tokenize_corpus] line %zu of corpus\n", out_line_infos.line_infos.size()); printf("[tokenize_corpus] line %zu of corpus\n", out_line_infos.line_infos.size());
fflush(stdout); fflush(stdout);
...@@ -196,6 +254,8 @@ void tokenize_vocabulary(char const* vocabulary, token_list& out_list) ...@@ -196,6 +254,8 @@ void tokenize_vocabulary(char const* vocabulary, token_list& out_list)
while (ss >> token) while (ss >> token)
out_list.push_token(token.c_str()); out_list.push_token(token.c_str());
out_list.sort_parallel();
} }
using token_with_info_t = std::pair<std::string, token_info>; using token_with_info_t = std::pair<std::string, token_info>;
...@@ -339,7 +399,7 @@ void compute_ngram_stats(cc::span<uint64_t const> token_hashes, size_t vocab_siz ...@@ -339,7 +399,7 @@ void compute_ngram_stats(cc::span<uint64_t const> token_hashes, size_t vocab_siz
auto const loop_end = token_hashes.size() - (n - 1); auto const loop_end = token_hashes.size() - (n - 1);
for (auto i = 0u; i < loop_end; ++i) for (auto i = 0u; i < loop_end; ++i)
{ {
if ((i & 0xFFFFF) == 0) if ((i & 0x3FFFFF) == 0)
{ {
printf("[compute_ngram_stats] %u-tuple %u of %zu\n", n, i, loop_end); printf("[compute_ngram_stats] %u-tuple %u of %zu\n", n, i, loop_end);
fflush(stdout); fflush(stdout);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment