Add tests for SampleTopK that highlight existing problems and fix those:

- Sampling was not correct for k>1 and temperature=0.
- Sampling was not correct for only negative logits.

Also restructure the code a bit for better readability and add some asserts for things that shouldn't happen.

PiperOrigin-RevId: 676043267
This commit is contained in:
Daniel Keysers 2024-09-18 10:30:52 -07:00 committed by Copybara-Service
parent 760a69449e
commit 03f0ee2323
2 changed files with 59 additions and 9 deletions

View File

@ -23,6 +23,7 @@
#include <array> #include <array>
#include <cmath> #include <cmath>
#include <limits>
#include <random> #include <random>
#include <type_traits> // std::enable_if_t #include <type_traits> // std::enable_if_t
@ -593,6 +594,12 @@ SampleArgmax(const float* probabilities, size_t vocab_size) {
template <size_t k> template <size_t k>
HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
create_distribution(std::array<float, k>& top_k, float temperature) { create_distribution(std::array<float, k>& top_k, float temperature) {
HWY_ASSERT(temperature >= 0.0f);
if (temperature == 0.0f) {
// Temperature == 0 is a special case which always returns the argmax (0).
// We also want to avoid dividing by zero in the code below.
return std::discrete_distribution<int>();
}
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>; using D = hn::ScalableTag<float>;
@ -609,32 +616,35 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
template <size_t k, typename TAcceptToken> template <size_t k, typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
const float* HWY_RESTRICT probabilities, size_t vocab_size, const float* HWY_RESTRICT logits, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token) { std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
static_assert(k != 0, ""); static_assert(k != 0, "");
HWY_ASSERT(k <= vocab_size);
// TODO: Optimize, potentially using new VQSort PartialSort. // TODO: Optimize, potentially using new VQSort PartialSort.
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1] std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
top_k.fill(-std::numeric_limits<float>::infinity());
std::array<int, k> indices{}; std::array<int, k> indices{};
size_t num_accepted = 0;
for (size_t i = 0; i < vocab_size; ++i) { for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && if (logits[i] < top_k[k - 1]) continue;
(!accept_token || accept_token(StaticCast<int>(i), probabilities[i]))) { bool accepted =
continue; !accept_token || accept_token(StaticCast<int>(i), logits[i]);
} if (!accepted) continue;
num_accepted++;
for (size_t j = 0; j < k; ++j) { for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && if (logits[i] > top_k[j]) {
(!accept_token ||
accept_token(StaticCast<int>(i), probabilities[i]))) {
// shift elements by 1, insert the new value, move on to next value // shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) { for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1]; top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1]; indices[idx] = indices[idx - 1];
} }
top_k[j] = probabilities[i]; top_k[j] = logits[i];
indices[j] = StaticCast<int>(i); indices[j] = StaticCast<int>(i);
break; break;
} }
} }
} }
HWY_ASSERT(k <= num_accepted);
return indices[create_distribution<k>(top_k, temperature)(gen)]; return indices[create_distribution<k>(top_k, temperature)(gen)];
} }

View File

@ -24,6 +24,8 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <cmath> #include <cmath>
#include <functional>
#include <numeric>
#include <random> #include <random>
#include <vector> #include <vector>
@ -559,6 +561,43 @@ void TestAllLayerNorm() {
TestLayerNorm<float, BF16, BF16>(rng); TestLayerNorm<float, BF16, BF16>(rng);
} }
void TestSampleTopK() {
const size_t kSize = 52;
std::vector<float> logits(kSize);
// SampleTopK is typically used on logits, which can be negative.
// Create a vector going from -100 to -100+51=49.
std::iota(logits.begin(), logits.end(), -100.0f);
std::mt19937 gen;
gen.seed(0x12345678);
float temperature = 1.0f;
// SampleTopK<1> should return the argmax.
std::function<bool(int, float)> accept_token;
int sample = SampleTopK<1>(logits.data(), kSize, gen, temperature,
accept_token);
EXPECT_EQ(sample, 51); // Last is largest.
// Only accept even tokens, expect the last (largest) even index.
accept_token = [](int i, float) { return i % 2 == 0; };
sample = SampleTopK<1>(logits.data(), kSize, gen, temperature,
accept_token);
EXPECT_EQ(sample, 50); // Last even index.
// Reset the logits to a positive, increasing sequence.
std::iota(logits.begin(), logits.end(), 1.0f);
// Sample from the top 3, expect one of the top 3 even indices.
for (int i = 0; i < 100; ++i) {
sample = SampleTopK<3>(logits.data(), kSize, gen, temperature,
accept_token);
EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46);
}
// Now set the temperature to 0.0f, which should always return the argmax,
// even for k=3.
temperature = 0.0f;
for (int i = 0; i < 100; ++i) {
sample = SampleTopK<3>(logits.data(), kSize, gen, temperature,
accept_token);
EXPECT_EQ(sample, 50);
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
} // namespace gcpp } // namespace gcpp
@ -579,6 +618,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple); HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK);
HWY_AFTER_TEST(); HWY_AFTER_TEST();
} // namespace gcpp } // namespace gcpp