Make top_k a runtime argument (instead of a model argument).

PiperOrigin-RevId: 696170691
Daniel Keysers, 2024-11-13 09:48:25 -08:00 (committed by Copybara-Service)
commit 719699f132, parent b94295b6d9
13 changed files with 31 additions and 26 deletions
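
In practice, callers now pick top_k per generation call when filling in RuntimeConfig, rather than it being fixed by ModelConfig::top_k. A minimal sketch of the new call-site pattern (field order follows the RuntimeConfig diff below; gen, stream_token and the subsequent Generate() call are assumed to be set up as in the other call sites in this diff):

    #include <random>  // std::mt19937

    // Sketch only, not part of this commit: stream_token and the later
    // Generate() call are assumed to exist as in the other call sites here.
    std::mt19937 gen(/*seed=*/42);
    gcpp::RuntimeConfig runtime_config = {
        .max_generated_tokens = 256,
        .temperature = 0.7f,
        .top_k = 40,  // previously fixed at ModelConfig::top_k / TConfig::kTopK
        .gen = &gen,
        .verbosity = 0,
        .stream_token = stream_token,
    };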

View File

@@ -74,8 +74,8 @@ TEST(OptimizeTest, GradientDescent) {
   RuntimeConfig runtime = {
       .max_generated_tokens = 16,
       .temperature = 1.0f,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
       .stream_token = stream_token,
       .eos_id = ReverseSequenceSampler::kEndToken,
   };

View File

@@ -74,8 +74,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   runtime_config_ = {
       .max_generated_tokens = inference.max_generated_tokens,
      .temperature = inference.temperature,
-      .verbosity = app.verbosity,
       .gen = &gen_,
+      .verbosity = app.verbosity,
   };
 }

View File

@@ -139,8 +139,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
   RuntimeConfig runtime = {
       .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
-      .verbosity = verbosity,
       .gen = nullptr,
+      .verbosity = verbosity,
       .stream_token = stream_token,
       .sample_func = sample_token,
   };

View File

@@ -169,8 +169,8 @@ TEST_F(GemmaTest, Multiturn) {
   RuntimeConfig runtime_config{
       .max_generated_tokens = 64,
       .temperature = 0.0f,
-      .verbosity = 2,
       .gen = &s_env->MutableGen(),
+      .verbosity = 2,
       .stream_token = stream_token,
   };
   TimingInfo timing_info{.verbosity = 0};

View File

@@ -127,8 +127,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 30,
       .temperature = 0.0f,
-      .verbosity = env.Verbosity(),
       .gen = &env.MutableGen(),
+      .verbosity = env.Verbosity(),
       .stream_token = stream_token,
   };
   env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,

View File

@@ -89,8 +89,8 @@ int main(int argc, char** argv) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 1024,
       .temperature = 1.0,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
       .stream_token = stream_token,
       .accept_token =
           [&](int token, float /* prob */) {

View File

@@ -178,7 +178,6 @@ struct ModelConfig {
   size_t vit_seq_len = 0;
   size_t num_tensor_scales = 0;
   size_t num_vit_scales = 0;
-  size_t top_k = kTopK;
   float att_cap = 0.0f;
   float final_cap = 0.0f;
   bool absolute_pe = false;

View File

@@ -374,7 +374,7 @@ void AssertMatch(const ModelConfig& config) {
   }
   ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
   ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
-  ASSERT_EQ(TConfig::kTopK, config.top_k);
+  // ASSERT_EQ(TConfig::kTopK, config.top_k); - is now a runtime config value.
   ASSERT_EQ(TConfig::kAttCap, config.att_cap);
   ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
   ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);

View File

@@ -1196,13 +1196,12 @@ class TokenStreamer {
   hwy::BitSet4096<> is_eos_;
 };

-HWY_INLINE SampleFunc ChooseSampleFunc(int top_k,
-                                       const RuntimeConfig& runtime_config) {
+HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;

   // Fast path for top-1 with no accept_token.
-  if (top_k == 1 && !runtime_config.accept_token) {
+  if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
     return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
       PROFILER_ZONE("Gen.Sample Top1");
       return Top1OfSoftmax(logits, vocab_size);
@@ -1210,12 +1209,12 @@ HWY_INLINE SampleFunc ChooseSampleFunc(int top_k,
   }

   // General case: Softmax with top-k sampling.
-  return [top_k, &runtime_config](float* logits,
+  return [&runtime_config](float* logits,
                            size_t vocab_size) HWY_ATTR -> TokenAndProb {
     PROFILER_ZONE("Gen.Sample general");
     Softmax(logits, vocab_size);
-    const int token =
-        SampleTopK(logits, top_k, vocab_size, *runtime_config.gen,
+    const int token = SampleTopK(
+        logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
         runtime_config.temperature, runtime_config.accept_token);
     return TokenAndProb{.token = token, .prob = logits[token]};
   };
@@ -1276,8 +1275,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
-  const SampleFunc sample_token =
-      ChooseSampleFunc(weights.weights_config.top_k, runtime_config);
+  const SampleFunc sample_token = ChooseSampleFunc(runtime_config);

   // Prefill stops before min_prompt_size - 1 because the last prompt
   // token is the first input token for generation.
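
Taken together, these changes make the sampling dispatch depend only on RuntimeConfig. A condensed, non-authoritative sketch of the resulting ChooseSampleFunc logic (HWY_INLINE/HWY_ATTR and PROFILER_ZONE omitted for brevity; names as in the diff above):

    SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
      // A caller-provided sample_func always takes precedence.
      if (runtime_config.sample_func) return runtime_config.sample_func;
      // top_k == 1 and no accept_token: greedy decoding, take the softmax top-1.
      if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
        return [](float* logits, size_t vocab_size) -> TokenAndProb {
          return Top1OfSoftmax(logits, vocab_size);
        };
      }
      // General case: softmax over the logits, then sample among the top_k tokens.
      return [&runtime_config](float* logits, size_t vocab_size) -> TokenAndProb {
        Softmax(logits, vocab_size);
        const int token = SampleTopK(logits, runtime_config.top_k, vocab_size,
                                     *runtime_config.gen, runtime_config.temperature,
                                     runtime_config.accept_token);
        return TokenAndProb{.token = token, .prob = logits[token]};
      };
    }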

View File

@@ -25,6 +25,7 @@
 #include "compression/io.h"  // Path
 #include "gemma/activations.h"
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "gemma/kv_cache.h"
 #include "gemma/tokenizer.h"
 #include "gemma/weights.h"
@@ -102,10 +103,13 @@ struct RuntimeConfig {
   // Max queries per batch (one token from each) during decode.
   size_t decode_qbatch_size = 16;

+  // Sampling-related parameters.
   float temperature;     // Temperature for sampling.
-  int verbosity;         // Controls verbosity of printed messages.
+  size_t top_k = kTopK;  // Top-k for sampling.
   std::mt19937* gen;     // Random number generator used for sampling.

+  int verbosity;  // Controls verbosity of printed messages.
+
   // Functions operating on the generated tokens.
   StreamFunc stream_token;
   BatchStreamFunc batch_stream_token;

View File

@@ -99,7 +99,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
     HWY_ASSERT(image.ReadPPM(args.image_file.path));
     image.Resize();
     RuntimeConfig runtime_config = {
-        .verbosity = app.verbosity, .gen = &gen, .use_spinning = app.spin};
+        .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
     double image_tokens_start = hwy::platform::Now();
     model.GenerateImageTokens(runtime_config, image, image_tokens);
     if (app.verbosity >= 1) {
@@ -172,8 +172,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   }

   TimingInfo timing_info = {.verbosity = app.verbosity};
-  RuntimeConfig runtime_config = {.verbosity = app.verbosity,
-                                  .gen = &gen,
+  RuntimeConfig runtime_config = {.gen = &gen,
+                                  .verbosity = app.verbosity,
                                   .stream_token = stream_token,
                                   .accept_token = accept_token,
                                   .use_spinning = app.spin};

View File

@@ -56,7 +56,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
   HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
   image.Resize();
-  RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
+  RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
   model.GenerateImageTokens(runtime_config, image, image_tokens_);
 }
@@ -64,8 +64,8 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
   Gemma& model = *(s_env->GetModel());
   s_env->MutableGen().seed(0x12345678);
   RuntimeConfig runtime_config = {.max_generated_tokens = 512,
-                                  .verbosity = 0,
-                                  .gen = &s_env->MutableGen()};
+                                  .gen = &s_env->MutableGen(),
+                                  .verbosity = 0};
   runtime_config.image_tokens = &image_tokens_;
   size_t abs_pos = 0;
   std::string mutable_prompt = prompt_text;

View File

@@ -220,6 +220,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   size_t decode_qbatch_size;

   float temperature;
+  size_t top_k;
   bool deterministic;
   bool multiturn;
   Path image_file;
@@ -244,6 +245,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             "Decode: max queries per batch.");

     visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
+    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
+            2);
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
@@ -259,6 +262,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;

     runtime_config.temperature = temperature;
+    runtime_config.top_k = top_k;
   }
 };
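
With the flag registered above, top-k is tunable per invocation and InferenceArgs::CopyTo forwards it into runtime_config.top_k. A hedged usage example (only --temperature and --top_k come from this diff; the binary name and the model/weights flags are assumptions about the surrounding repository):

    ./gemma <model/tokenizer/weights flags...> --temperature 0.8 --top_k 5

Note that the flag defaults to size_t{1} (greedy top-1 sampling), while RuntimeConfig::top_k itself defaults to kTopK for callers that construct the config directly.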