Use vectorized TopK via Highway VQSelect

PiperOrigin-RevId: 728159153
Apoorv Reddy 2025-02-18 05:00:53 -08:00 committed by Copybara-Service
parent 0e5b59d24d
commit d854471ae2
4 changed files with 94 additions and 58 deletions
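The change replaces the hand-rolled insertion-sort top-k inside SampleTopK and FusedSoftmaxAndSampleTopK with a shared TopK helper: each (token, probability) pair is packed into a single double, and Highway's vectorized VQSelect/VQSort pick and order the k largest. Below is a minimal standalone sketch, not part of the commit, using only the standard library and hypothetical names (Pack, main), of why the packed doubles order by probability: the probability occupies the sign, exponent, and upper mantissa bits, while the token id sits in the low 32 bits, which are zero padding after the float-to-double conversion for token ids below 2^29.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical stand-in for the commit's PackTokenAndProb: the probability
// lands in the upper bits of the double, the token id in the low 32 bits.
double Pack(int32_t token, float prob) {
  double d = static_cast<double>(prob);  // exact; the low 29 mantissa bits are zero padding
  uint64_t bits;
  std::memcpy(&bits, &d, sizeof(bits));
  bits = (bits & 0xFFFFFFFF00000000ull) | static_cast<uint32_t>(token);
  std::memcpy(&d, &bits, sizeof(d));
  return d;
}

int main() {
  // Comparison is driven by the probability, not the token id, so sorting the
  // packed doubles in descending order yields tokens by descending probability.
  std::printf("%d\n", Pack(1000000000, 0.87f) < Pack(10, 0.96f));  // prints 1
  return 0;
}

This is also what the new TestPackTokenAndProb below checks: a pair with a larger token id but smaller probability still packs to a smaller double.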


@@ -108,6 +108,7 @@ cc_library(
         "@highway//:matvec",
         "@highway//:profiler",
         "@highway//:thread_pool",
+        "@highway//hwy/contrib/sort:vqsort",
     ],
 )


@@ -22,7 +22,7 @@
 #include <stdio.h>
 #include <cmath>
-#include <limits>
+#include <cstdint>
 #include <random>
 #include <type_traits>  // std::enable_if_t
 #include <vector>
@@ -30,6 +30,8 @@
 #include "compression/compress.h"
 #include "util/basics.h"  // TokenAndProb
 #include "hwy/base.h"
+#include "hwy/contrib/sort/order.h"
+#include "hwy/contrib/sort/vqsort.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/detect_targets.h"
 #include "hwy/profiler.h"
@@ -54,6 +56,35 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
+HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
+  // casting prob from float to double just makes some changes to the
+  // exponent bias and pads zeros in the mantissa.
+  double packed = static_cast<double>(prob);
+  int64_t packed_int64;
+  hwy::CopySameSize(&packed, &packed_int64);
+  // stuff the token into the lower 32 bits of packed_int64. (it is an int32_t
+  // anyway)
+  packed_int64 &= 0xFFFFFFFF00000000;
+  packed_int64 |= token;
+  // copy bytes back into packed.
+  hwy::CopySameSize(&packed_int64, &packed);
+  return packed;
+}
+
+HWY_INLINE TokenAndProb UnpackTokenAndProb(double packed) {
+  TokenAndProb tp;
+  int64_t packed_int64;
+  hwy::CopySameSize(&packed, &packed_int64);
+  tp.token = static_cast<int>(packed_int64 & 0xFFFFFFFFULL);
+  // clear the lower 32 bits of packed_int64 before copying back into packed.
+  packed_int64 &= 0xFFFFFFFF00000000ULL;
+  hwy::CopySameSize(&packed_int64, &packed);
+  tp.prob = static_cast<float>(packed);
+  return tp;
+}
+
 template <typename To, typename From>
 HWY_INLINE constexpr std::enable_if_t<
     std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
@@ -704,38 +735,45 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(
   return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
 }
 
+template <typename TAcceptToken>
+HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
+    const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k,
+    TAcceptToken& accept_token) {
+  HWY_ASSERT(k != 0);
+  HWY_ASSERT(k <= vocab_size);
+  std::vector<double> packed_token_probs;
+  for (int32_t i = 0; i < vocab_size; ++i) {
+    if (accept_token && !accept_token(StaticCast<int>(i), probabilities[i])) {
+      continue;
+    }
+    packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
+  }
+
+  hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k,
+                hwy::SortDescending());
+  hwy::VQSort(packed_token_probs.data(), k, hwy::SortDescending());
+
+  std::vector<TokenAndProb> token_probs;
+  token_probs.reserve(k);
+  for (int32_t i = 0; i < k; ++i) {
+    token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i]));
+  }
+  return token_probs;
+}
+
 template <typename TAcceptToken>
 HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
     const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
     std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
-  HWY_ASSERT(k != 0);
-  HWY_ASSERT(k <= vocab_size);
-  // TODO: Optimize, potentially using new VQSort PartialSort.
-  // Sorted from highest [0], to lowest [k-1]
-  std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
-  std::vector<int> indices(k);
-  size_t num_accepted = 0;
-  for (size_t i = 0; i < vocab_size; ++i) {
-    if (probabilities[i] < top_k[k - 1]) continue;
-    bool accepted =
-        !accept_token || accept_token(StaticCast<int>(i), probabilities[i]);
-    if (!accepted) continue;
-    num_accepted++;
-    for (size_t j = 0; j < k; ++j) {
-      if (probabilities[i] > top_k[j]) {
-        // shift elements by 1, insert the new value, move on to next value
-        for (size_t idx = k - 1; idx > j; --idx) {
-          top_k[idx] = top_k[idx - 1];
-          indices[idx] = indices[idx - 1];
-        }
-        top_k[j] = probabilities[i];
-        indices[j] = StaticCast<int>(i);
-        break;
-      }
-    }
-  }
-  HWY_ASSERT(k <= num_accepted);
-  return indices[create_distribution(top_k, temperature)(gen)];
+  std::vector<TokenAndProb> token_probs =
+      TopK(probabilities, vocab_size, k, accept_token);
+  std::vector<int> topk_indices(k);
+  std::vector<float> topk_probs(k);
+  for (int i = 0; i < k; ++i) {
+    topk_indices[i] = token_probs[i].token;
+    topk_probs[i] = token_probs[i].prob;
+  }
+  return topk_indices[create_distribution(topk_probs, temperature)(gen)];
 }
 
 template <typename TAcceptToken>
@@ -745,40 +783,23 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
   // Softmax and sample top-K is equivalent to taking the top-K logits and
   // sampling from the softmax of the top-K logits. The latter is faster as it
   // avoids computing the softmax of all logits.
-  HWY_ASSERT(k != 0);
-  HWY_ASSERT(k <= vocab_size);
-  std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
-  std::vector<int> indices(k);
-  size_t num_accepted = 0;
-  for (size_t i = 0; i < vocab_size; ++i) {
-    if (logits[i] < top_k[k - 1]) continue;
-    bool accepted =
-        !accept_token || accept_token(StaticCast<int>(i), logits[i]);
-    if (!accepted) continue;
-    num_accepted++;
-    for (size_t j = 0; j < k; ++j) {
-      if (logits[i] > top_k[j]) {
-        // shift elements by 1, insert the new value, move on to next value
-        for (size_t idx = k - 1; idx > j; --idx) {
-          top_k[idx] = top_k[idx - 1];
-          indices[idx] = indices[idx - 1];
-        }
-        top_k[j] = logits[i];
-        indices[j] = StaticCast<int>(i);
-        break;
-      }
-    }
-  }
-  size_t mask = k <= num_accepted ? k : num_accepted;
-  Softmax(top_k.data(), mask, temperature);
-  auto distribution = std::discrete_distribution<int>(std::begin(top_k),
-                                                      std::begin(top_k) + mask);
+  std::vector<TokenAndProb> token_logits =
+      TopK(logits, vocab_size, k, accept_token);
+  std::vector<int> topk_indices(k);
+  std::vector<float> topk_logits(k);
+  for (int i = 0; i < token_logits.size(); ++i) {
+    topk_indices[i] = token_logits[i].token;
+    topk_logits[i] = token_logits[i].prob;
+  }
+  size_t mask = token_logits.size();
+  Softmax(topk_logits.data(), mask, temperature);
+  auto distribution = std::discrete_distribution<int>(
+      std::begin(topk_logits), std::begin(topk_logits) + mask);
   int topk_sampled_index = distribution(gen);
-  int sampled_index = indices[topk_sampled_index];
+  int sampled_index = topk_indices[topk_sampled_index];
   return TokenAndProb{.token = sampled_index,
-                      .prob = top_k[topk_sampled_index]};
+                      .prob = topk_logits[topk_sampled_index]};
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)
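For reference, a minimal standalone sketch, not from the commit, of the select-then-sort pattern used in the new TopK above, applied to plain doubles. It assumes the Highway vqsort library (the @highway//hwy/contrib/sort:vqsort dependency added earlier) is available: VQSelect partitions the k largest keys to the front without ordering them, and VQSort then sorts only those k.

#include <cstddef>
#include <cstdio>
#include <vector>

#include "hwy/contrib/sort/order.h"
#include "hwy/contrib/sort/vqsort.h"

int main() {
  std::vector<double> keys = {0.1, 0.7, 0.3, 0.9, 0.2, 0.5};
  const size_t k = 3;
  // Move the 3 largest keys into keys[0..2], in unspecified order.
  hwy::VQSelect(keys.data(), keys.size(), k, hwy::SortDescending());
  // Sort only those 3 keys, descending.
  hwy::VQSort(keys.data(), k, hwy::SortDescending());
  for (size_t i = 0; i < k; ++i) std::printf("%g ", keys[i]);  // 0.9 0.7 0.5
  std::printf("\n");
  return 0;
}

Selecting before sorting keeps the pass over the whole array to a partition step, so only k elements ever need a full sort.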


@@ -600,6 +600,17 @@ void TestSampleTopK() {
   }
 }
 
+void TestPackTokenAndProb() {
+  double packed1 = PackTokenAndProb(10, 0.96f);
+  TokenAndProb unpacked1 = UnpackTokenAndProb(packed1);
+  EXPECT_EQ(unpacked1.token, 10);
+  EXPECT_NEAR(unpacked1.prob, 0.96f, 1e-6);
+
+  double packed2 = PackTokenAndProb(1000000000, 0.87f);
+  EXPECT_LT(packed2, packed1);
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
@@ -621,6 +632,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestPackTokenAndProb);
 HWY_AFTER_TEST();
 }  // namespace gcpp


@@ -57,10 +57,12 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
 }
 
 // Shared between gemma.h and ops-inl.h.
+#pragma pack(push, 1)
 struct TokenAndProb {
   int token;
   float prob;
 };
+#pragma pack(pop)
 
 // Entire size of a 2D array.
 struct Extents2D {
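The pack/unpack helpers treat a (token, prob) pair as a single 64-bit value, and the #pragma pack(push, 1)/pop added above guarantees the struct carries no padding. A hypothetical compile-time check, not part of the commit, that documents this invariant:

#pragma pack(push, 1)
struct TokenAndProbSketch {  // hypothetical stand-in for TokenAndProb above
  int token;
  float prob;
};
#pragma pack(pop)

// The packed-double representation relies on the pair fitting in 64 bits.
static_assert(sizeof(TokenAndProbSketch) == sizeof(double),
              "token + prob must fit in one 64-bit packed value");

int main() { return 0; }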