mirror of https://github.com/google/gemma.cpp.git
Use vectorized TopK using highway VQSelect
PiperOrigin-RevId: 728159153
This commit is contained in:
parent 0e5b59d24d
commit d854471ae2
@@ -108,6 +108,7 @@ cc_library(
         "@highway//:matvec",
         "@highway//:profiler",
         "@highway//:thread_pool",
+        "@highway//hwy/contrib/sort:vqsort",
     ],
 )

ops/ops-inl.h (137 changed lines)

@@ -22,7 +22,7 @@
 #include <stdio.h>

 #include <cmath>
-#include <limits>
+#include <cstdint>
 #include <random>
 #include <type_traits>  // std::enable_if_t
 #include <vector>

@@ -30,6 +30,8 @@
 #include "compression/compress.h"
 #include "util/basics.h"  // TokenAndProb
 #include "hwy/base.h"
+#include "hwy/contrib/sort/order.h"
+#include "hwy/contrib/sort/vqsort.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/detect_targets.h"
 #include "hwy/profiler.h"
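
The two new headers supply the pieces used below: hwy/contrib/sort/order.h defines the ordering tags (hwy::SortAscending, hwy::SortDescending), and hwy/contrib/sort/vqsort.h declares the vectorized sorting entry points. A minimal standalone sketch of the select-then-sort pattern this commit relies on (assuming a Highway version recent enough to provide VQSelect, linked e.g. via the vqsort Bazel target added above):

#include <cstdio>
#include <vector>

#include "hwy/contrib/sort/order.h"   // hwy::SortDescending
#include "hwy/contrib/sort/vqsort.h"  // hwy::VQSort, hwy::VQSelect

int main() {
  std::vector<double> v = {3.0, 1.0, 4.0, 1.5, 9.0, 2.6};
  // Move the 2 largest values to the front, in unspecified order...
  hwy::VQSelect(v.data(), v.size(), 2, hwy::SortDescending());
  // ...then sort just those 2 descending, as the new TopK() below does.
  hwy::VQSort(v.data(), 2, hwy::SortDescending());
  std::printf("%g %g\n", v[0], v[1]);  // prints: 9 4
}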

@@ -54,6 +56,35 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;

+HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
+  // casting prob from float to double just makes some changes to the
+  // exponent bias and pads zeros in the mantissa.
+  double packed = static_cast<double>(prob);
+  int64_t packed_int64;
+  hwy::CopySameSize(&packed, &packed_int64);
+  // stuff the token into the lower 32 bits of packed_int64. (it is an int32_t
+  // anyway)
+  packed_int64 &= 0xFFFFFFFF00000000;
+  packed_int64 |= token;
+  // copy bytes back into packed.
+  hwy::CopySameSize(&packed_int64, &packed);
+  return packed;
+}
+
+HWY_INLINE TokenAndProb UnpackTokenAndProb(double packed) {
+  TokenAndProb tp;
+
+  int64_t packed_int64;
+  hwy::CopySameSize(&packed, &packed_int64);
+  tp.token = static_cast<int>(packed_int64 & 0xFFFFFFFFULL);
+
+  // clear the lower 32 bits of packed_int64 before copying back into packed.
+  packed_int64 &= 0xFFFFFFFF00000000ULL;
+  hwy::CopySameSize(&packed_int64, &packed);
+  tp.prob = static_cast<float>(packed);
+  return tp;
+}
+
 template <typename To, typename From>
 HWY_INLINE constexpr std::enable_if_t<
     std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
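
The packing trick works because widening a float to a double is exact and preserves ordering (sign, exponent, then mantissa, from the most significant bit down), so sorting the packed doubles orders primarily by prob. Overwriting the low 32 bits with the token clobbers only the three least-significant bits of the float's 24-bit significand, a relative error under 1e-6, which is why the new test below checks the round-tripped prob with EXPECT_NEAR(…, 1e-6) rather than EXPECT_EQ. A self-contained sketch of the same bit manipulation, with std::memcpy standing in for hwy::CopySameSize (names mirror the commit, but this is illustrative, not the library code):

#include <cstdint>
#include <cstdio>
#include <cstring>

struct TokenAndProb {
  int token;
  float prob;
};

double PackTokenAndProb(int32_t token, float prob) {
  double packed = static_cast<double>(prob);  // exact widening conversion
  int64_t bits;
  std::memcpy(&bits, &packed, sizeof(bits));
  bits &= static_cast<int64_t>(0xFFFFFFFF00000000ULL);  // clear low 32 bits
  bits |= static_cast<uint32_t>(token);  // token assumed non-negative
  std::memcpy(&packed, &bits, sizeof(packed));
  return packed;
}

TokenAndProb UnpackTokenAndProb(double packed) {
  int64_t bits;
  std::memcpy(&bits, &packed, sizeof(bits));
  TokenAndProb tp;
  tp.token = static_cast<int>(bits & 0xFFFFFFFFLL);
  bits &= static_cast<int64_t>(0xFFFFFFFF00000000ULL);  // drop the token bits
  std::memcpy(&packed, &bits, sizeof(packed));
  tp.prob = static_cast<float>(packed);  // loses only 3 low mantissa bits
  return tp;
}

int main() {
  double a = PackTokenAndProb(10, 0.96f);
  double b = PackTokenAndProb(1000000000, 0.87f);
  // Ordering of the packed doubles is dominated by prob, not by token:
  std::printf("a > b: %d\n", a > b);  // 1, because 0.96 > 0.87
  TokenAndProb tp = UnpackTokenAndProb(a);
  std::printf("token=%d prob=%.7f\n", tp.token, tp.prob);  // 10, ~0.96
}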

@@ -704,38 +735,45 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(
   return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
 }

+template <typename TAcceptToken>
+HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
+    const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k,
+    TAcceptToken& accept_token) {
+  HWY_ASSERT(k != 0);
+  HWY_ASSERT(k <= vocab_size);
+  std::vector<double> packed_token_probs;
+  for (int32_t i = 0; i < vocab_size; ++i) {
+    if (accept_token && !accept_token(StaticCast<int>(i), probabilities[i])) {
+      continue;
+    }
+    packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
+  }
+
+  hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k,
+                hwy::SortDescending());
+  hwy::VQSort(packed_token_probs.data(), k, hwy::SortDescending());
+
+  std::vector<TokenAndProb> token_probs;
+  token_probs.reserve(k);
+  for (int32_t i = 0; i < k; ++i) {
+    token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i]));
+  }
+  return token_probs;
+}
+
 template <typename TAcceptToken>
 HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
     const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
     std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
-  HWY_ASSERT(k != 0);
-  HWY_ASSERT(k <= vocab_size);
-  // TODO: Optimize, potentially using new VQSort PartialSort.
-  // Sorted from highest [0], to lowest [k-1]
-  std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
-  std::vector<int> indices(k);
-  size_t num_accepted = 0;
-  for (size_t i = 0; i < vocab_size; ++i) {
-    if (probabilities[i] < top_k[k - 1]) continue;
-    bool accepted =
-        !accept_token || accept_token(StaticCast<int>(i), probabilities[i]);
-    if (!accepted) continue;
-    num_accepted++;
-    for (size_t j = 0; j < k; ++j) {
-      if (probabilities[i] > top_k[j]) {
-        // shift elements by 1, insert the new value, move on to next value
-        for (size_t idx = k - 1; idx > j; --idx) {
-          top_k[idx] = top_k[idx - 1];
-          indices[idx] = indices[idx - 1];
-        }
-        top_k[j] = probabilities[i];
-        indices[j] = StaticCast<int>(i);
-        break;
-      }
-    }
-  }
-  HWY_ASSERT(k <= num_accepted);
-  return indices[create_distribution(top_k, temperature)(gen)];
+  std::vector<TokenAndProb> token_probs =
+      TopK(probabilities, vocab_size, k, accept_token);
+  std::vector<int> topk_indices(k);
+  std::vector<float> topk_probs(k);
+  for (int i = 0; i < k; ++i) {
+    topk_indices[i] = token_probs[i].token;
+    topk_probs[i] = token_probs[i].prob;
+  }
+  return topk_indices[create_distribution(topk_probs, temperature)(gen)];
 }

 template <typename TAcceptToken>
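
For intuition, here is a scalar standard-library analogue of what the new TopK() does with VQSelect and VQSort: partition so the k largest packed values occupy the front without fully sorting, then sort only those k. A sketch, not the vectorized code path; std::nth_element and std::sort stand in for hwy::VQSelect and hwy::VQSort:

#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

// Partition-then-sort, the same two-phase structure as the new TopK().
void PartialTopK(std::vector<double>& packed, size_t k) {
  // Phase 1: the k largest values move to the front, order unspecified.
  std::nth_element(packed.begin(), packed.begin() + k, packed.end(),
                   std::greater<double>());
  // Phase 2: sort only the front k values descending.
  std::sort(packed.begin(), packed.begin() + k, std::greater<double>());
}

int main() {
  std::vector<double> packed = {0.1, 0.9, 0.4, 0.7, 0.2, 0.8};
  PartialTopK(packed, 3);
  for (size_t i = 0; i < 3; ++i) std::printf("%g ", packed[i]);  // 0.9 0.8 0.7
  std::printf("\n");
}

Compared with the removed insertion loop, which is O(vocab_size * k) in the worst case, this partition-then-sort structure costs O(vocab_size) on average plus O(k log k) for the final sort.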

@@ -745,40 +783,23 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
   // Softmax and sample top-K is equivalent to taking the top-K logits and
   // sampling from the softmax of the top-K logits. The latter is faster as it
   // avoids computing the softmax of all logits.
-  HWY_ASSERT(k != 0);
-  HWY_ASSERT(k <= vocab_size);
-
-  std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
-  std::vector<int> indices(k);
-  size_t num_accepted = 0;
-  for (size_t i = 0; i < vocab_size; ++i) {
-    if (logits[i] < top_k[k - 1]) continue;
-    bool accepted =
-        !accept_token || accept_token(StaticCast<int>(i), logits[i]);
-    if (!accepted) continue;
-    num_accepted++;
-    for (size_t j = 0; j < k; ++j) {
-      if (logits[i] > top_k[j]) {
-        // shift elements by 1, insert the new value, move on to next value
-        for (size_t idx = k - 1; idx > j; --idx) {
-          top_k[idx] = top_k[idx - 1];
-          indices[idx] = indices[idx - 1];
-        }
-        top_k[j] = logits[i];
-        indices[j] = StaticCast<int>(i);
-        break;
-      }
-    }
+  std::vector<TokenAndProb> token_logits =
+      TopK(logits, vocab_size, k, accept_token);
+  std::vector<int> topk_indices(k);
+  std::vector<float> topk_logits(k);
+  for (int i = 0; i < token_logits.size(); ++i) {
+    topk_indices[i] = token_logits[i].token;
+    topk_logits[i] = token_logits[i].prob;
   }

-  size_t mask = k <= num_accepted ? k : num_accepted;
-  Softmax(top_k.data(), mask, temperature);
-  auto distribution = std::discrete_distribution<int>(std::begin(top_k),
-                                                      std::begin(top_k) + mask);
+  size_t mask = token_logits.size();
+  Softmax(topk_logits.data(), mask, temperature);
+  auto distribution = std::discrete_distribution<int>(
+      std::begin(topk_logits), std::begin(topk_logits) + mask);
   int topk_sampled_index = distribution(gen);
-  int sampled_index = indices[topk_sampled_index];
+  int sampled_index = topk_indices[topk_sampled_index];
   return TokenAndProb{.token = sampled_index,
-                      .prob = top_k[topk_sampled_index]};
+                      .prob = topk_logits[topk_sampled_index]};
 }

 // NOLINTNEXTLINE(google-readability-namespace-comments)
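
The comment in the hunk above carries the key reasoning step: softmax followed by sampling restricted to the top k tokens equals top-k selection followed by softmax over just k values, so the full-vocabulary softmax can be skipped. A minimal scalar sketch of that fused step, with a hand-rolled temperature softmax standing in for the repository's Softmax():

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

// Softmax (with temperature) over only the k retained logits, then sample.
int SampleFromTopKLogits(std::vector<float> logits, float temperature,
                         std::mt19937& gen) {
  const float max = *std::max_element(logits.begin(), logits.end());
  double sum = 0.0;
  for (float& l : logits) {
    l = std::exp((l - max) / temperature);  // subtract max for stability
    sum += l;
  }
  for (float& l : logits) l = static_cast<float>(l / sum);
  std::discrete_distribution<int> dist(logits.begin(), logits.end());
  return dist(gen);  // index into the top-k arrays, not the vocabulary
}

int main() {
  std::mt19937 gen(42);
  std::vector<float> topk_logits = {2.0f, 1.0f, 0.5f};  // already the top k
  const int idx = SampleFromTopKLogits(topk_logits, /*temperature=*/1.0f, gen);
  std::printf("sampled top-k index: %d\n", idx);
}

As in the hunk above, the sampled value indexes the top-k arrays and must be mapped back through topk_indices to recover the vocabulary token.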

@@ -600,6 +600,17 @@ void TestSampleTopK() {
   }
 }

+void TestPackTokenAndProb() {
+  double packed1 = PackTokenAndProb(10, 0.96f);
+  TokenAndProb unpacked1 = UnpackTokenAndProb(packed1);
+  EXPECT_EQ(unpacked1.token, 10);
+  EXPECT_NEAR(unpacked1.prob, 0.96f, 1e-6);
+
+  double packed2 = PackTokenAndProb(1000000000, 0.87f);
+
+  EXPECT_LT(packed2, packed1);
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp

@@ -621,6 +632,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestPackTokenAndProb);
 HWY_AFTER_TEST();

 }  // namespace gcpp

@@ -57,10 +57,12 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
 }

 // Shared between gemma.h and ops-inl.h.
+#pragma pack(push, 1)
 struct TokenAndProb {
   int token;
   float prob;
 };
+#pragma pack(pop)

 // Entire size of a 2D array.
 struct Extents2D {
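
A plausible motive for the #pragma pack wrapper: it pins TokenAndProb at exactly 8 bytes, the size of the packed double representation used by TopK. On common ABIs a 4-byte int plus a 4-byte float is already padding-free, so the pragma makes that layout guaranteed rather than incidental. A sketch checking the size assumption:

#include <cstdio>

#pragma pack(push, 1)
struct TokenAndProb {
  int token;
  float prob;
};
#pragma pack(pop)

// 4-byte int + 4-byte float with no padding == 8 bytes, the size of a double.
static_assert(sizeof(TokenAndProb) == sizeof(double),
              "TokenAndProb must match the 8-byte packed representation");

int main() { std::printf("sizeof(TokenAndProb) = %zu\n", sizeof(TokenAndProb)); }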