From d83ad766791a7519494d73a0d5057a012c9e2f78 Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Tue, 1 Oct 2024 00:50:54 -0700 Subject: [PATCH] Rename one variable in SampleTopK and update TestSampleTopK. PiperOrigin-RevId: 680897787 --- ops/ops-inl.h | 10 +++++----- ops/ops_test.cc | 7 ++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 31b3b00..5b9b007 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -777,7 +777,7 @@ create_distribution(std::array& top_k, float temperature) { template HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( - const float* HWY_RESTRICT logits, size_t vocab_size, + const float* HWY_RESTRICT probabilities, size_t vocab_size, std::mt19937& gen, float temperature, TAcceptToken& accept_token) { static_assert(k != 0, ""); HWY_ASSERT(k <= vocab_size); @@ -787,19 +787,19 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( std::array indices{}; size_t num_accepted = 0; for (size_t i = 0; i < vocab_size; ++i) { - if (logits[i] < top_k[k - 1]) continue; + if (probabilities[i] < top_k[k - 1]) continue; bool accepted = - !accept_token || accept_token(StaticCast(i), logits[i]); + !accept_token || accept_token(StaticCast(i), probabilities[i]); if (!accepted) continue; num_accepted++; for (size_t j = 0; j < k; ++j) { - if (logits[i] > top_k[j]) { + if (probabilities[i] > top_k[j]) { // shift elements by 1, insert the new value, move on to next value for (size_t idx = k - 1; idx > j; --idx) { top_k[idx] = top_k[idx - 1]; indices[idx] = indices[idx - 1]; } - top_k[j] = logits[i]; + top_k[j] = probabilities[i]; indices[j] = StaticCast(i); break; } diff --git a/ops/ops_test.cc b/ops/ops_test.cc index b54ac43..ddce780 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -564,9 +564,9 @@ void TestAllLayerNorm() { void TestSampleTopK() { const size_t kSize = 52; std::vector logits(kSize); - // SampleTopK is typically used on logits, which can be negative. - // Create a vector going from -100 to -100+51=49. 
+ // Create a vector going from -100 to -100+51=-49 and take Softmax. std::iota(logits.begin(), logits.end(), -100.0f); + Softmax(logits.data(), kSize); std::mt19937 gen; gen.seed(0x12345678); float temperature = 1.0f; @@ -580,8 +580,9 @@ void TestSampleTopK() { sample = SampleTopK<1>(logits.data(), kSize, gen, temperature, accept_token); EXPECT_EQ(sample, 50); // Last even index. - // Reset the logits to a positive, increasing sequence. + // Reset the logits to a positive, increasing sequence and take Softmax. std::iota(logits.begin(), logits.end(), 1.0f); + Softmax(logits.data(), kSize); // Sample from the top 3, expect one of the top 3 even indices. for (int i = 0; i < 100; ++i) { sample = SampleTopK<3>(logits.data(), kSize, gen, temperature,