diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 0677d6c..6a9573d 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -1175,8 +1175,8 @@ SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) { // Fast path for top-1 with no accept_token. if (kTopK == 1 && !runtime_config.accept_token) { - PROFILER_ZONE("Gen.Sample Top1"); return [](float* logits, size_t vocab_size) -> TokenAndProb { + PROFILER_ZONE("Gen.Sample Top1"); return Top1OfSoftmax(logits, vocab_size); }; } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index e1d1e0f..29fff5b 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -701,6 +701,7 @@ static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x, // with normalized probabilities. Only equivalent to `Softmax` + `sample_func` // if `kTopK` == 1. This is worthwhile because `num` is typically `kVocabSize` // == 256K, and this avoids writing and then scanning again for the max. +// However, this is not enough to make parallelization worthwhile. static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, const size_t num) { namespace hn = hwy::HWY_NAMESPACE;