diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 4183572..5203b86 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -23,6 +23,7 @@ #include #include +#include #include #include // std::enable_if_t @@ -593,6 +594,12 @@ SampleArgmax(const float* probabilities, size_t vocab_size) { template HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution create_distribution(std::array& top_k, float temperature) { + HWY_ASSERT(temperature >= 0.0f); + if (temperature == 0.0f) { + // Temperature == 0 is a special case which always returns the argmax (0). + // We also want to avoid dividing by zero in the code below. + return std::discrete_distribution(); + } namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; @@ -609,32 +616,35 @@ create_distribution(std::array& top_k, float temperature) { template HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( - const float* HWY_RESTRICT probabilities, size_t vocab_size, + const float* HWY_RESTRICT logits, size_t vocab_size, std::mt19937& gen, float temperature, TAcceptToken& accept_token) { static_assert(k != 0, ""); + HWY_ASSERT(k <= vocab_size); // TODO: Optimize, potentially using new VQSort PartialSort. std::array top_k{}; // sorted from highest [0], to lowest [k-1] + top_k.fill(-std::numeric_limits::infinity()); std::array indices{}; + size_t num_accepted = 0; for (size_t i = 0; i < vocab_size; ++i) { - if (probabilities[i] < top_k[k - 1] && - (!accept_token || accept_token(StaticCast(i), probabilities[i]))) { - continue; - } + if (logits[i] < top_k[k - 1]) continue; + bool accepted = + !accept_token || accept_token(StaticCast(i), logits[i]); + if (!accepted) continue; + num_accepted++; for (size_t j = 0; j < k; ++j) { - if (probabilities[i] > top_k[j] && - (!accept_token || - accept_token(StaticCast(i), probabilities[i]))) { + if (logits[i] > top_k[j]) { // shift elements by 1, insert the new value, move on to next value for (size_t idx = k - 1; idx > j; --idx) { top_k[idx] = top_k[idx - 1]; indices[idx] = indices[idx - 1]; } - top_k[j] = probabilities[i]; + top_k[j] = logits[i]; indices[j] = StaticCast(i); break; } } } + HWY_ASSERT(k <= num_accepted); return indices[create_distribution(top_k, temperature)(gen)]; } diff --git a/ops/ops_test.cc b/ops/ops_test.cc index a993883..b54ac43 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include #include @@ -559,6 +561,43 @@ void TestAllLayerNorm() { TestLayerNorm(rng); } +void TestSampleTopK() { + const size_t kSize = 52; + std::vector logits(kSize); + // SampleTopK is typically used on logits, which can be negative. + // Create a vector going from -100 to -100+51=49. + std::iota(logits.begin(), logits.end(), -100.0f); + std::mt19937 gen; + gen.seed(0x12345678); + float temperature = 1.0f; + // SampleTopK<1> should return the argmax. + std::function accept_token; + int sample = SampleTopK<1>(logits.data(), kSize, gen, temperature, + accept_token); + EXPECT_EQ(sample, 51); // Last is largest. + // Only accept even tokens, expect the last (largest) even index. + accept_token = [](int i, float) { return i % 2 == 0; }; + sample = SampleTopK<1>(logits.data(), kSize, gen, temperature, + accept_token); + EXPECT_EQ(sample, 50); // Last even index. + // Reset the logits to a positive, increasing sequence. + std::iota(logits.begin(), logits.end(), 1.0f); + // Sample from the top 3, expect one of the top 3 even indices. + for (int i = 0; i < 100; ++i) { + sample = SampleTopK<3>(logits.data(), kSize, gen, temperature, + accept_token); + EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46); + } + // Now set the temperature to 0.0f, which should always return the argmax, + // even for k=3. + temperature = 0.0f; + for (int i = 0; i < 100; ++i) { + sample = SampleTopK<3>(logits.data(), kSize, gen, temperature, + accept_token); + EXPECT_EQ(sample, 50); + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp @@ -579,6 +618,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm); HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple); +HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK); HWY_AFTER_TEST(); } // namespace gcpp