Rename one variable in SampleTopK and update TestSampleTopK.

PiperOrigin-RevId: 680897787
This commit is contained in:
Daniel Keysers 2024-10-01 00:50:54 -07:00 committed by Copybara-Service
parent 2d14d796e3
commit d83ad76679
2 changed files with 9 additions and 8 deletions

View File

@ -777,7 +777,7 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
template <size_t k, typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
const float* HWY_RESTRICT logits, size_t vocab_size,
const float* HWY_RESTRICT probabilities, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
static_assert(k != 0, "");
HWY_ASSERT(k <= vocab_size);
@ -787,19 +787,19 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<int, k> indices{};
size_t num_accepted = 0;
for (size_t i = 0; i < vocab_size; ++i) {
if (logits[i] < top_k[k - 1]) continue;
if (probabilities[i] < top_k[k - 1]) continue;
bool accepted =
!accept_token || accept_token(StaticCast<int>(i), logits[i]);
!accept_token || accept_token(StaticCast<int>(i), probabilities[i]);
if (!accepted) continue;
num_accepted++;
for (size_t j = 0; j < k; ++j) {
if (logits[i] > top_k[j]) {
if (probabilities[i] > top_k[j]) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1];
}
top_k[j] = logits[i];
top_k[j] = probabilities[i];
indices[j] = StaticCast<int>(i);
break;
}

View File

@ -564,9 +564,9 @@ void TestAllLayerNorm() {
void TestSampleTopK() {
const size_t kSize = 52;
std::vector<float> logits(kSize);
// SampleTopK is typically used on logits, which can be negative.
// Create a vector going from -100 to -100+51=49.
// Create a vector going from -100 to -100+51=49 and take Softmax.
std::iota(logits.begin(), logits.end(), -100.0f);
Softmax(logits.data(), kSize);
std::mt19937 gen;
gen.seed(0x12345678);
float temperature = 1.0f;
@ -580,8 +580,9 @@ void TestSampleTopK() {
sample = SampleTopK<1>(logits.data(), kSize, gen, temperature,
accept_token);
EXPECT_EQ(sample, 50); // Last even index.
// Reset the logits to a positive, increasing sequence.
// Reset the logits to a positive, increasing sequence and take Softmax.
std::iota(logits.begin(), logits.end(), 1.0f);
Softmax(logits.data(), kSize);
// Sample from the top 3, expect one of the top 3 even indices.
for (int i = 0; i < 100; ++i) {
sample = SampleTopK<3>(logits.data(), kSize, gen, temperature,