mirror of https://github.com/google/gemma.cpp.git
Implements FusedSoftmaxAndSampleTopK.
This computes softmax on the top-K logits, instead of computing softmax first and then getting top-K probs. So we end up avoiding renormalizing too. Additionally, modify softmax to do temperature scaling, if temp != 1.0 PiperOrigin-RevId: 727702149
This commit is contained in:
parent
bdf5d25e97
commit
0e5b59d24d
|
|
@ -106,7 +106,6 @@ cc_library(
|
|||
"@highway//:hwy",
|
||||
"@highway//:math",
|
||||
"@highway//:matvec",
|
||||
"@highway//:nanobenchmark", # timer
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -1210,11 +1210,9 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
|||
return [&runtime_config](float* logits,
|
||||
size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE("Gen.Sample general");
|
||||
Softmax(logits, vocab_size);
|
||||
const int token = SampleTopK(
|
||||
return FusedSoftmaxAndSampleTopK(
|
||||
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
|
||||
runtime_config.temperature, runtime_config.accept_token);
|
||||
return TokenAndProb{.token = token, .prob = logits[token]};
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -501,7 +501,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
|||
|
||||
// See below for a specialized version for top-1 sampling.
|
||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||
const size_t mask_pos) {
|
||||
const size_t mask_pos,
|
||||
float temperature = 1.0f) {
|
||||
HWY_DASSERT(size != 0);
|
||||
HWY_DASSERT(mask_pos <= size);
|
||||
|
||||
|
|
@ -528,6 +529,14 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
|||
}
|
||||
});
|
||||
|
||||
if (temperature != 1.0f) {
|
||||
const float temperature_inv = 1.0f / temperature;
|
||||
hn::Transform(d, x, mask_pos,
|
||||
[temperature_inv](const auto d, const V value) HWY_ATTR {
|
||||
return hn::Mul(value, hn::Set(d, temperature_inv));
|
||||
});
|
||||
}
|
||||
|
||||
// Normalize to probability distribution. The exact sum seems like it should
|
||||
// not make a huge difference. It halves the standard deviation of the sum of
|
||||
// the normalized probabilities from 1E-7 to 5E-8, but actually also changes
|
||||
|
|
@ -539,8 +548,9 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
|||
}
|
||||
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
|
||||
const size_t size) {
|
||||
Softmax(x, size, size);
|
||||
const size_t size,
|
||||
float temperature = 1.0f) {
|
||||
Softmax(x, size, size, temperature);
|
||||
}
|
||||
|
||||
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
|
||||
|
|
@ -728,6 +738,49 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
|||
return indices[create_distribution(top_k, temperature)(gen)];
|
||||
}
|
||||
|
||||
template <typename TAcceptToken>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
||||
const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
|
||||
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
|
||||
// Softmax and sample top-K is equivalent to taking the top-K logits and
|
||||
// sampling from the softmax of the top-K logits. The latter is faster as it
|
||||
// avoids computing the softmax of all logits.
|
||||
HWY_ASSERT(k != 0);
|
||||
HWY_ASSERT(k <= vocab_size);
|
||||
|
||||
std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
|
||||
std::vector<int> indices(k);
|
||||
size_t num_accepted = 0;
|
||||
for (size_t i = 0; i < vocab_size; ++i) {
|
||||
if (logits[i] < top_k[k - 1]) continue;
|
||||
bool accepted =
|
||||
!accept_token || accept_token(StaticCast<int>(i), logits[i]);
|
||||
if (!accepted) continue;
|
||||
num_accepted++;
|
||||
for (size_t j = 0; j < k; ++j) {
|
||||
if (logits[i] > top_k[j]) {
|
||||
// shift elements by 1, insert the new value, move on to next value
|
||||
for (size_t idx = k - 1; idx > j; --idx) {
|
||||
top_k[idx] = top_k[idx - 1];
|
||||
indices[idx] = indices[idx - 1];
|
||||
}
|
||||
top_k[j] = logits[i];
|
||||
indices[j] = StaticCast<int>(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t mask = k <= num_accepted ? k : num_accepted;
|
||||
Softmax(top_k.data(), mask, temperature);
|
||||
auto distribution = std::discrete_distribution<int>(std::begin(top_k),
|
||||
std::begin(top_k) + mask);
|
||||
int topk_sampled_index = distribution(gen);
|
||||
int sampled_index = indices[topk_sampled_index];
|
||||
return TokenAndProb{.token = sampled_index,
|
||||
.prob = top_k[topk_sampled_index]};
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
Loading…
Reference in New Issue