mirror of https://github.com/google/gemma.cpp.git
Implements FusedSoftmaxAndSampleTopK.
This computes softmax over only the top-K logits, instead of computing softmax over all logits first and then selecting the top-K probabilities, which also avoids a separate renormalization step. Additionally, modifies Softmax to apply temperature scaling when temperature != 1.0. PiperOrigin-RevId: 727702149
This commit is contained in:
parent
bdf5d25e97
commit
0e5b59d24d
|
|
@ -106,7 +106,6 @@ cc_library(
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:math",
|
"@highway//:math",
|
||||||
"@highway//:matvec",
|
"@highway//:matvec",
|
||||||
"@highway//:nanobenchmark", # timer
|
|
||||||
"@highway//:profiler",
|
"@highway//:profiler",
|
||||||
"@highway//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -1210,11 +1210,9 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
||||||
return [&runtime_config](float* logits,
|
return [&runtime_config](float* logits,
|
||||||
size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||||
PROFILER_ZONE("Gen.Sample general");
|
PROFILER_ZONE("Gen.Sample general");
|
||||||
Softmax(logits, vocab_size);
|
return FusedSoftmaxAndSampleTopK(
|
||||||
const int token = SampleTopK(
|
|
||||||
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
|
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
|
||||||
runtime_config.temperature, runtime_config.accept_token);
|
runtime_config.temperature, runtime_config.accept_token);
|
||||||
return TokenAndProb{.token = token, .prob = logits[token]};
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -501,7 +501,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
|
|
||||||
// See below for a specialized version for top-1 sampling.
|
// See below for a specialized version for top-1 sampling.
|
||||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
const size_t mask_pos) {
|
const size_t mask_pos,
|
||||||
|
float temperature = 1.0f) {
|
||||||
HWY_DASSERT(size != 0);
|
HWY_DASSERT(size != 0);
|
||||||
HWY_DASSERT(mask_pos <= size);
|
HWY_DASSERT(mask_pos <= size);
|
||||||
|
|
||||||
|
|
@ -528,6 +529,14 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (temperature != 1.0f) {
|
||||||
|
const float temperature_inv = 1.0f / temperature;
|
||||||
|
hn::Transform(d, x, mask_pos,
|
||||||
|
[temperature_inv](const auto d, const V value) HWY_ATTR {
|
||||||
|
return hn::Mul(value, hn::Set(d, temperature_inv));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Normalize to probability distribution. The exact sum seems like it should
|
// Normalize to probability distribution. The exact sum seems like it should
|
||||||
// not make a huge difference. It halves the standard deviation of the sum of
|
// not make a huge difference. It halves the standard deviation of the sum of
|
||||||
// the normalized probabilities from 1E-7 to 5E-8, but actually also changes
|
// the normalized probabilities from 1E-7 to 5E-8, but actually also changes
|
||||||
|
|
@ -539,8 +548,9 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
|
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
|
||||||
const size_t size) {
|
const size_t size,
|
||||||
Softmax(x, size, size);
|
float temperature = 1.0f) {
|
||||||
|
Softmax(x, size, size, temperature);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
|
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
|
||||||
|
|
@ -728,6 +738,49 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
||||||
return indices[create_distribution(top_k, temperature)(gen)];
|
return indices[create_distribution(top_k, temperature)(gen)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename TAcceptToken>
|
||||||
|
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
||||||
|
const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
|
||||||
|
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
|
||||||
|
// Softmax and sample top-K is equivalent to taking the top-K logits and
|
||||||
|
// sampling from the softmax of the top-K logits. The latter is faster as it
|
||||||
|
// avoids computing the softmax of all logits.
|
||||||
|
HWY_ASSERT(k != 0);
|
||||||
|
HWY_ASSERT(k <= vocab_size);
|
||||||
|
|
||||||
|
std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
|
||||||
|
std::vector<int> indices(k);
|
||||||
|
size_t num_accepted = 0;
|
||||||
|
for (size_t i = 0; i < vocab_size; ++i) {
|
||||||
|
if (logits[i] < top_k[k - 1]) continue;
|
||||||
|
bool accepted =
|
||||||
|
!accept_token || accept_token(StaticCast<int>(i), logits[i]);
|
||||||
|
if (!accepted) continue;
|
||||||
|
num_accepted++;
|
||||||
|
for (size_t j = 0; j < k; ++j) {
|
||||||
|
if (logits[i] > top_k[j]) {
|
||||||
|
// shift elements by 1, insert the new value, move on to next value
|
||||||
|
for (size_t idx = k - 1; idx > j; --idx) {
|
||||||
|
top_k[idx] = top_k[idx - 1];
|
||||||
|
indices[idx] = indices[idx - 1];
|
||||||
|
}
|
||||||
|
top_k[j] = logits[i];
|
||||||
|
indices[j] = StaticCast<int>(i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t mask = k <= num_accepted ? k : num_accepted;
|
||||||
|
Softmax(top_k.data(), mask, temperature);
|
||||||
|
auto distribution = std::discrete_distribution<int>(std::begin(top_k),
|
||||||
|
std::begin(top_k) + mask);
|
||||||
|
int topk_sampled_index = distribution(gen);
|
||||||
|
int sampled_index = indices[topk_sampled_index];
|
||||||
|
return TokenAndProb{.token = sampled_index,
|
||||||
|
.prob = top_k[topk_sampled_index]};
|
||||||
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue