Implements FusedSoftmaxAndSampleTopK.

This computes softmax on the top-K logits, instead of computing softmax first and then getting top-K probs. So we end up avoiding renormalizing too. Additionally, modify softmax to do temperature scaling, if temp != 1.0

PiperOrigin-RevId: 727702149
This commit is contained in:
Apoorv Reddy 2025-02-16 21:29:05 -08:00 committed by Copybara-Service
parent bdf5d25e97
commit 0e5b59d24d
3 changed files with 57 additions and 7 deletions

View File

@ -106,7 +106,6 @@ cc_library(
"@highway//:hwy", "@highway//:hwy",
"@highway//:math", "@highway//:math",
"@highway//:matvec", "@highway//:matvec",
"@highway//:nanobenchmark", # timer
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool", "@highway//:thread_pool",
], ],

View File

@ -1210,11 +1210,9 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
return [&runtime_config](float* logits, return [&runtime_config](float* logits,
size_t vocab_size) HWY_ATTR -> TokenAndProb { size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE("Gen.Sample general"); PROFILER_ZONE("Gen.Sample general");
Softmax(logits, vocab_size); return FusedSoftmaxAndSampleTopK(
const int token = SampleTopK(
logits, runtime_config.top_k, vocab_size, *runtime_config.gen, logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token); runtime_config.temperature, runtime_config.accept_token);
return TokenAndProb{.token = token, .prob = logits[token]};
}; };
} }

View File

@ -501,7 +501,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
// See below for a specialized version for top-1 sampling. // See below for a specialized version for top-1 sampling.
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t mask_pos) { const size_t mask_pos,
float temperature = 1.0f) {
HWY_DASSERT(size != 0); HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size); HWY_DASSERT(mask_pos <= size);
@ -528,6 +529,14 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
} }
}); });
if (temperature != 1.0f) {
const float temperature_inv = 1.0f / temperature;
hn::Transform(d, x, mask_pos,
[temperature_inv](const auto d, const V value) HWY_ATTR {
return hn::Mul(value, hn::Set(d, temperature_inv));
});
}
// Normalize to probability distribution. The exact sum seems like it should // Normalize to probability distribution. The exact sum seems like it should
// not make a huge difference. It halves the standard deviation of the sum of // not make a huge difference. It halves the standard deviation of the sum of
// the normalized probabilities from 1E-7 to 5E-8, but actually also changes // the normalized probabilities from 1E-7 to 5E-8, but actually also changes
@ -539,8 +548,9 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
} }
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x, static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
const size_t size) { const size_t size,
Softmax(x, size, size); float temperature = 1.0f) {
Softmax(x, size, size, temperature);
} }
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
@ -728,6 +738,49 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
return indices[create_distribution(top_k, temperature)(gen)]; return indices[create_distribution(top_k, temperature)(gen)];
} }
template <typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
// Softmax and sample top-K is equivalent to taking the top-K logits and
// sampling from the softmax of the top-K logits. The latter is faster as it
// avoids computing the softmax of all logits.
HWY_ASSERT(k != 0);
HWY_ASSERT(k <= vocab_size);
std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
std::vector<int> indices(k);
size_t num_accepted = 0;
for (size_t i = 0; i < vocab_size; ++i) {
if (logits[i] < top_k[k - 1]) continue;
bool accepted =
!accept_token || accept_token(StaticCast<int>(i), logits[i]);
if (!accepted) continue;
num_accepted++;
for (size_t j = 0; j < k; ++j) {
if (logits[i] > top_k[j]) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1];
}
top_k[j] = logits[i];
indices[j] = StaticCast<int>(i);
break;
}
}
}
size_t mask = k <= num_accepted ? k : num_accepted;
Softmax(top_k.data(), mask, temperature);
auto distribution = std::discrete_distribution<int>(std::begin(top_k),
std::begin(top_k) + mask);
int topk_sampled_index = distribution(gen);
int sampled_index = indices[topk_sampled_index];
return TokenAndProb{.token = sampled_index,
.prob = top_k[topk_sampled_index]};
}
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
} // namespace gcpp } // namespace gcpp