mirror of https://github.com/google/gemma.cpp.git
Add tests for SampleTopK that highlight existing problems and fix those:
- Sampling was not correct for k>1 and temperature=0.
- Sampling was not correct when all logits are negative.

Also restructure the code a bit for better readability and add some asserts for things that shouldn't happen.

PiperOrigin-RevId: 676043267
This commit is contained in:
parent
760a69449e
commit
03f0ee2323
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <type_traits> // std::enable_if_t
|
||||
|
||||
|
|
@ -593,6 +594,12 @@ SampleArgmax(const float* probabilities, size_t vocab_size) {
|
|||
template <size_t k>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
||||
create_distribution(std::array<float, k>& top_k, float temperature) {
|
||||
HWY_ASSERT(temperature >= 0.0f);
|
||||
if (temperature == 0.0f) {
|
||||
// Temperature == 0 is a special case which always returns the argmax (0).
|
||||
// We also want to avoid dividing by zero in the code below.
|
||||
return std::discrete_distribution<int>();
|
||||
}
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
|
||||
|
|
@ -609,32 +616,35 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
|
|||
|
||||
template <size_t k, typename TAcceptToken>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
||||
const float* HWY_RESTRICT probabilities, size_t vocab_size,
|
||||
const float* HWY_RESTRICT logits, size_t vocab_size,
|
||||
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
|
||||
static_assert(k != 0, "");
|
||||
HWY_ASSERT(k <= vocab_size);
|
||||
// TODO: Optimize, potentially using new VQSort PartialSort.
|
||||
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
|
||||
top_k.fill(-std::numeric_limits<float>::infinity());
|
||||
std::array<int, k> indices{};
|
||||
size_t num_accepted = 0;
|
||||
for (size_t i = 0; i < vocab_size; ++i) {
|
||||
if (probabilities[i] < top_k[k - 1] &&
|
||||
(!accept_token || accept_token(StaticCast<int>(i), probabilities[i]))) {
|
||||
continue;
|
||||
}
|
||||
if (logits[i] < top_k[k - 1]) continue;
|
||||
bool accepted =
|
||||
!accept_token || accept_token(StaticCast<int>(i), logits[i]);
|
||||
if (!accepted) continue;
|
||||
num_accepted++;
|
||||
for (size_t j = 0; j < k; ++j) {
|
||||
if (probabilities[i] > top_k[j] &&
|
||||
(!accept_token ||
|
||||
accept_token(StaticCast<int>(i), probabilities[i]))) {
|
||||
if (logits[i] > top_k[j]) {
|
||||
// shift elements by 1, insert the new value, move on to next value
|
||||
for (size_t idx = k - 1; idx > j; --idx) {
|
||||
top_k[idx] = top_k[idx - 1];
|
||||
indices[idx] = indices[idx - 1];
|
||||
}
|
||||
top_k[j] = probabilities[i];
|
||||
top_k[j] = logits[i];
|
||||
indices[j] = StaticCast<int>(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
HWY_ASSERT(k <= num_accepted);
|
||||
return indices[create_distribution<k>(top_k, temperature)(gen)];
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@
|
|||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -559,6 +561,43 @@ void TestAllLayerNorm() {
|
|||
TestLayerNorm<float, BF16, BF16>(rng);
|
||||
}
|
||||
|
||||
// Exercises SampleTopK on strictly increasing logits: k=1 must return the
// largest index (with and without an even-index filter), k=3 must return one
// of the three largest accepted indices, and temperature 0 must return the
// largest accepted index even for k=3.
void TestSampleTopK() {
|
||||
const size_t kSize = 52;
|
||||
std::vector<float> logits(kSize);
|
||||
// SampleTopK is typically used on logits, which can be negative.
|
||||
// Create a vector going from -100 to -100+51=-49 (all values negative).
|
||||
std::iota(logits.begin(), logits.end(), -100.0f);
|
||||
std::mt19937 gen;
|
||||
// Fixed seed so every run draws the same sequence of samples.
gen.seed(0x12345678);
|
||||
float temperature = 1.0f;
|
||||
// SampleTopK<1> should return the argmax.
|
||||
// Left empty here: SampleTopK only calls accept_token when it is non-empty,
// so every token is accepted.
std::function<bool(int, float)> accept_token;
|
||||
int sample = SampleTopK<1>(logits.data(), kSize, gen, temperature,
|
||||
accept_token);
|
||||
EXPECT_EQ(sample, 51); // Last is largest.
|
||||
// Only accept even tokens, expect the last (largest) even index.
|
||||
accept_token = [](int i, float) { return i % 2 == 0; };
|
||||
sample = SampleTopK<1>(logits.data(), kSize, gen, temperature,
|
||||
accept_token);
|
||||
EXPECT_EQ(sample, 50); // Last even index.
|
||||
// Reset the logits to a positive, increasing sequence.
|
||||
std::iota(logits.begin(), logits.end(), 1.0f);
|
||||
// Sample from the top 3, expect one of the top 3 even indices.
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
sample = SampleTopK<3>(logits.data(), kSize, gen, temperature,
|
||||
accept_token);
|
||||
EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46);
|
||||
}
|
||||
// Now set the temperature to 0.0f, which should always return the argmax,
|
||||
// even for k=3.
|
||||
// The even-index filter is still active, so the expected argmax is 50.
temperature = 0.0f;
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
sample = SampleTopK<3>(logits.data(), kSize, gen, temperature,
|
||||
accept_token);
|
||||
EXPECT_EQ(sample, 50);
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
@ -579,6 +618,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
|
|||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK);
|
||||
HWY_AFTER_TEST();
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
Loading…
Reference in New Issue