mirror of https://github.com/google/gemma.cpp.git
Rename one variable in SampleTopK and update TestSampleTopK.
PiperOrigin-RevId: 680897787
parent 2d14d796e3
commit d83ad76679
@@ -777,7 +777,7 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
 
 template <size_t k, typename TAcceptToken>
 HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
-    const float* HWY_RESTRICT logits, size_t vocab_size,
+    const float* HWY_RESTRICT probabilities, size_t vocab_size,
     std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
   static_assert(k != 0, "");
   HWY_ASSERT(k <= vocab_size);
@@ -787,19 +787,19 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
   std::array<int, k> indices{};
   size_t num_accepted = 0;
   for (size_t i = 0; i < vocab_size; ++i) {
-    if (logits[i] < top_k[k - 1]) continue;
+    if (probabilities[i] < top_k[k - 1]) continue;
     bool accepted =
-        !accept_token || accept_token(StaticCast<int>(i), logits[i]);
+        !accept_token || accept_token(StaticCast<int>(i), probabilities[i]);
     if (!accepted) continue;
     num_accepted++;
     for (size_t j = 0; j < k; ++j) {
-      if (logits[i] > top_k[j]) {
+      if (probabilities[i] > top_k[j]) {
         // shift elements by 1, insert the new value, move on to next value
         for (size_t idx = k - 1; idx > j; --idx) {
           top_k[idx] = top_k[idx - 1];
           indices[idx] = indices[idx - 1];
         }
-        top_k[j] = logits[i];
+        top_k[j] = probabilities[i];
        indices[j] = StaticCast<int>(i);
        break;
      }
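For context on the hunks above: SampleTopK keeps the k largest entries of its input (now named probabilities, since callers pass Softmax output rather than raw logits) together with their token indices, by insertion into a small sorted array. A minimal standalone sketch of that same selection pattern, using plain std containers and an illustrative TopK helper that is not part of gemma.cpp:

// Standalone sketch of the top-k selection used by SampleTopK above.
// TopK is an illustrative name, not a gemma.cpp function.
#include <array>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

template <size_t k>
void TopK(const std::vector<float>& probabilities,
          std::array<float, k>& top_k, std::array<int, k>& indices) {
  top_k.fill(-std::numeric_limits<float>::infinity());
  indices.fill(-1);
  for (size_t i = 0; i < probabilities.size(); ++i) {
    // Cheap reject: anything below the current k-th best cannot enter.
    if (probabilities[i] < top_k[k - 1]) continue;
    for (size_t j = 0; j < k; ++j) {
      if (probabilities[i] > top_k[j]) {
        // Shift the smaller entries down by one, then insert at slot j.
        for (size_t idx = k - 1; idx > j; --idx) {
          top_k[idx] = top_k[idx - 1];
          indices[idx] = indices[idx - 1];
        }
        top_k[j] = probabilities[i];
        indices[j] = static_cast<int>(i);
        break;
      }
    }
  }
}

int main() {
  std::vector<float> probabilities = {0.05f, 0.40f, 0.10f, 0.30f, 0.15f};
  std::array<float, 3> top_k;
  std::array<int, 3> indices;
  TopK<3>(probabilities, top_k, indices);
  for (size_t j = 0; j < 3; ++j) {
    std::printf("token %d  p=%.2f\n", indices[j], top_k[j]);  // tokens 1, 3, 4
  }
  return 0;
}

The early continue against top_k[k - 1] rejects most of the vocabulary with a single comparison; in SampleTopK the surviving entries are additionally filtered by accept_token before insertion.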
@@ -564,9 +564,9 @@ void TestAllLayerNorm() {
 void TestSampleTopK() {
   const size_t kSize = 52;
   std::vector<float> logits(kSize);
-  // SampleTopK is typically used on logits, which can be negative.
-  // Create a vector going from -100 to -100+51=49.
+  // Create a vector going from -100 to -100+51=49 and take Softmax.
   std::iota(logits.begin(), logits.end(), -100.0f);
+  Softmax(logits.data(), kSize);
   std::mt19937 gen;
   gen.seed(0x12345678);
   float temperature = 1.0f;
@@ -580,8 +580,9 @@ void TestSampleTopK() {
   sample = SampleTopK<1>(logits.data(), kSize, gen, temperature,
                          accept_token);
   EXPECT_EQ(sample, 50);  // Last even index.
-  // Reset the logits to a positive, increasing sequence.
+  // Reset the logits to a positive, increasing sequence and take Softmax.
   std::iota(logits.begin(), logits.end(), 1.0f);
+  Softmax(logits.data(), kSize);
   // Sample from the top 3, expect one of the top 3 even indices.
   for (int i = 0; i < 100; ++i) {
     sample = SampleTopK<3>(logits.data(), kSize, gen, temperature,
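The updated test runs Softmax over the logits before calling SampleTopK, so the renamed probabilities parameter matches what the test actually passes. With the even-token accept_token filter and k = 1, the only admissible sample is index 50, the even index with the largest probability. A rough standalone re-creation of that setup (the Softmax helper below is an assumption for illustration, not gemma.cpp's implementation):

// Re-creates the TestSampleTopK input above: softmax an increasing sequence,
// then find the best even-indexed token, which is what SampleTopK<1> with the
// even-token accept_token must return (index 50).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

// Illustrative softmax; gemma.cpp provides its own Softmax().
void Softmax(float* x, size_t size) {
  const float max = *std::max_element(x, x + size);
  float sum = 0.0f;
  for (size_t i = 0; i < size; ++i) {
    x[i] = std::exp(x[i] - max);
    sum += x[i];
  }
  for (size_t i = 0; i < size; ++i) x[i] /= sum;
}

int main() {
  const size_t kSize = 52;
  std::vector<float> logits(kSize);
  std::iota(logits.begin(), logits.end(), -100.0f);  // -100, -99, ..., 49
  Softmax(logits.data(), kSize);                     // now a distribution

  // accept_token in the test only admits even token ids, so the top-1 sample
  // is the even index with the highest probability.
  int best = -1;
  float best_prob = -1.0f;
  for (size_t i = 0; i < kSize; i += 2) {
    if (logits[i] > best_prob) {
      best_prob = logits[i];
      best = static_cast<int>(i);
    }
  }
  std::printf("top-1 even token: %d (p=%.3f)\n", best, best_prob);  // prints 50
  return 0;
}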