diff --git a/BUILD.bazel b/BUILD.bazel index eb7cf72..0cf1801 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -106,7 +106,6 @@ cc_library( "@highway//:hwy", "@highway//:math", "@highway//:matvec", - "@highway//:nanobenchmark", # timer "@highway//:profiler", "@highway//:thread_pool", ], diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 76c2b1e..b41585e 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -1210,11 +1210,9 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) { return [&runtime_config](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { PROFILER_ZONE("Gen.Sample general"); - Softmax(logits, vocab_size); - const int token = SampleTopK( + return FusedSoftmaxAndSampleTopK( logits, runtime_config.top_k, vocab_size, *runtime_config.gen, runtime_config.temperature, runtime_config.accept_token); - return TokenAndProb{.token = token, .prob = logits[token]}; }; } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 9919abf..aad636b 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -501,7 +501,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( // See below for a specialized version for top-1 sampling. static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, - const size_t mask_pos) { + const size_t mask_pos, + float temperature = 1.0f) { HWY_DASSERT(size != 0); HWY_DASSERT(mask_pos <= size); @@ -528,6 +529,14 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, } }); + if (temperature != 1.0f) { + const float temperature_inv = 1.0f / temperature; + hn::Transform(d, x, mask_pos, + [temperature_inv](const auto d, const V value) HWY_ATTR { + return hn::Mul(value, hn::Set(d, temperature_inv)); + }); + } + // Normalize to probability distribution. The exact sum seems like it should // not make a huge difference. It halves the standard deviation of the sum of // the normalized probabilities from 1E-7 to 5E-8, but actually also changes @@ -539,8 +548,9 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, } static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x, - const size_t size) { - Softmax(x, size, size); + const size_t size, + float temperature = 1.0f) { + Softmax(x, size, size, temperature); } // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / @@ -728,6 +738,49 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( return indices[create_distribution(top_k, temperature)(gen)]; } +template +HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( + const float* HWY_RESTRICT logits, size_t k, size_t vocab_size, + std::mt19937& gen, float temperature, TAcceptToken& accept_token) { + // Softmax and sample top-K is equivalent to taking the top-K logits and + // sampling from the softmax of the top-K logits. The latter is faster as it + // avoids computing the softmax of all logits. + HWY_ASSERT(k != 0); + HWY_ASSERT(k <= vocab_size); + + std::vector top_k(k, -std::numeric_limits::infinity()); + std::vector indices(k); + size_t num_accepted = 0; + for (size_t i = 0; i < vocab_size; ++i) { + if (logits[i] < top_k[k - 1]) continue; + bool accepted = + !accept_token || accept_token(StaticCast(i), logits[i]); + if (!accepted) continue; + num_accepted++; + for (size_t j = 0; j < k; ++j) { + if (logits[i] > top_k[j]) { + // shift elements by 1, insert the new value, move on to next value + for (size_t idx = k - 1; idx > j; --idx) { + top_k[idx] = top_k[idx - 1]; + indices[idx] = indices[idx - 1]; + } + top_k[j] = logits[i]; + indices[j] = StaticCast(i); + break; + } + } + } + + size_t mask = k <= num_accepted ? k : num_accepted; + Softmax(top_k.data(), mask, temperature); + auto distribution = std::discrete_distribution(std::begin(top_k), + std::begin(top_k) + mask); + int topk_sampled_index = distribution(gen); + int sampled_index = indices[topk_sampled_index]; + return TokenAndProb{.token = sampled_index, + .prob = top_k[topk_sampled_index]}; +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp