mirror of https://github.com/google/gemma.cpp.git
Make top_k a runtime argument (instead of a model argument).
PiperOrigin-RevId: 696170691
This commit is contained in:
parent b94295b6d9
commit 719699f132
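In short: the top_k sampling parameter moves out of the per-model ModelConfig and into RuntimeConfig, so callers can choose it per run (and a new --top_k flag feeds it). A minimal sketch of the resulting call-site usage, assuming the post-change field layout shown in the RuntimeConfig hunk below; the header path and surrounding setup are illustrative, not taken from this diff:

    #include <random>

    #include "gemma/gemma.h"  // assumed header for RuntimeConfig

    void ConfigureSampling(std::mt19937& gen, gcpp::RuntimeConfig& config) {
      config = {
          .max_generated_tokens = 128,
          .temperature = 0.7f,
          .top_k = 40,  // now a per-run knob; was ModelConfig::top_k (kTopK)
          .gen = &gen,
          .verbosity = 0,
      };
    }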
@@ -74,8 +74,8 @@ TEST(OptimizeTest, GradientDescent) {
   RuntimeConfig runtime = {
       .max_generated_tokens = 16,
       .temperature = 1.0f,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
       .stream_token = stream_token,
       .eos_id = ReverseSequenceSampler::kEndToken,
   };
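A note on the swaps that follow: most call-site hunks in this commit are mechanical fallout of the struct reorder, because C++20 designated initializers must appear in declaration order, and gen now precedes verbosity inside RuntimeConfig. A tiny standalone illustration (hypothetical struct, not from this repo):

    struct S { int a; int b; };
    S ok = {.a = 1, .b = 2};      // OK: follows declaration order
    // S bad = {.b = 2, .a = 1};  // ill-formed in C++20: out of declaration order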
@@ -74,8 +74,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   runtime_config_ = {
       .max_generated_tokens = inference.max_generated_tokens,
       .temperature = inference.temperature,
-      .verbosity = app.verbosity,
       .gen = &gen_,
+      .verbosity = app.verbosity,
   };
 }

@@ -139,8 +139,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
   RuntimeConfig runtime = {
       .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
-      .verbosity = verbosity,
       .gen = nullptr,
+      .verbosity = verbosity,
       .stream_token = stream_token,
       .sample_func = sample_token,
   };
@@ -169,8 +169,8 @@ TEST_F(GemmaTest, Multiturn) {
   RuntimeConfig runtime_config{
       .max_generated_tokens = 64,
       .temperature = 0.0f,
-      .verbosity = 2,
       .gen = &s_env->MutableGen(),
+      .verbosity = 2,
       .stream_token = stream_token,
   };
   TimingInfo timing_info{.verbosity = 0};
@@ -127,8 +127,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 30,
       .temperature = 0.0f,
-      .verbosity = env.Verbosity(),
       .gen = &env.MutableGen(),
+      .verbosity = env.Verbosity(),
       .stream_token = stream_token,
   };
   env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
@@ -89,8 +89,8 @@ int main(int argc, char** argv) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 1024,
       .temperature = 1.0,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
       .stream_token = stream_token,
       .accept_token =
           [&](int token, float /* prob */) {
@@ -178,7 +178,6 @@ struct ModelConfig {
   size_t vit_seq_len = 0;
   size_t num_tensor_scales = 0;
   size_t num_vit_scales = 0;
-  size_t top_k = kTopK;
   float att_cap = 0.0f;
   float final_cap = 0.0f;
   bool absolute_pe = false;
@@ -374,7 +374,7 @@ void AssertMatch(const ModelConfig& config) {
   }
   ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
   ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
-  ASSERT_EQ(TConfig::kTopK, config.top_k);
+  // ASSERT_EQ(TConfig::kTopK, config.top_k); - is now a runtime config value.
   ASSERT_EQ(TConfig::kAttCap, config.att_cap);
   ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
   ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
@@ -1196,13 +1196,12 @@ class TokenStreamer {
   hwy::BitSet4096<> is_eos_;
 };

-HWY_INLINE SampleFunc ChooseSampleFunc(int top_k,
-                                       const RuntimeConfig& runtime_config) {
+HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;

   // Fast path for top-1 with no accept_token.
-  if (top_k == 1 && !runtime_config.accept_token) {
+  if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
     return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
       PROFILER_ZONE("Gen.Sample Top1");
       return Top1OfSoftmax(logits, vocab_size);
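For reference, the k == 1 fast path needs no RNG at all: the sampled token is just the argmax, and its softmax probability can be computed without materializing the full distribution. A self-contained sketch of what a Top1OfSoftmax-style helper computes (simplified scalar version; the repo's SIMD implementation differs):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <utility>

    // Argmax token plus its softmax probability: exp(0) / sum_i exp(l_i - max).
    std::pair<int, float> Top1OfSoftmaxSketch(const float* logits, size_t n) {
      const float* max_it = std::max_element(logits, logits + n);
      double sum = 0.0;
      for (size_t i = 0; i < n; ++i) sum += std::exp(logits[i] - *max_it);
      return {static_cast<int>(max_it - logits), static_cast<float>(1.0 / sum)};
    }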
@@ -1210,13 +1209,13 @@ HWY_INLINE SampleFunc ChooseSampleFunc(int top_k,
   }

   // General case: Softmax with top-k sampling.
-  return [top_k, &runtime_config](float* logits,
-                                  size_t vocab_size) HWY_ATTR -> TokenAndProb {
+  return [&runtime_config](float* logits,
+                           size_t vocab_size) HWY_ATTR -> TokenAndProb {
     PROFILER_ZONE("Gen.Sample general");
     Softmax(logits, vocab_size);
-    const int token =
-        SampleTopK(logits, top_k, vocab_size, *runtime_config.gen,
-                   runtime_config.temperature, runtime_config.accept_token);
+    const int token = SampleTopK(
+        logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
+        runtime_config.temperature, runtime_config.accept_token);
     return TokenAndProb{.token = token, .prob = logits[token]};
   };
 }
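The general path now reads top_k (alongside temperature, gen, and accept_token) straight from runtime_config, so the lambda only needs to capture the config. For intuition, a simplified standalone sketch of top-k sampling over already-softmaxed probabilities; the in-repo SampleTopK additionally applies temperature and the accept_token filter:

    #include <algorithm>
    #include <cstddef>
    #include <numeric>
    #include <random>
    #include <vector>

    // Assumes 1 <= top_k <= probs.size() and that probs sums to 1.
    int SampleTopKSketch(const std::vector<float>& probs, size_t top_k,
                         std::mt19937& gen) {
      // Indices of the k most probable tokens.
      std::vector<int> idx(probs.size());
      std::iota(idx.begin(), idx.end(), 0);
      std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(),
                        [&](int a, int b) { return probs[a] > probs[b]; });
      // Draw among the top k, proportional to their renormalized mass.
      std::vector<float> mass(top_k);
      for (size_t i = 0; i < top_k; ++i) mass[i] = probs[idx[i]];
      std::discrete_distribution<int> dist(mass.begin(), mass.end());
      return idx[dist(gen)];
    }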
@@ -1276,8 +1275,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
-  const SampleFunc sample_token =
-      ChooseSampleFunc(weights.weights_config.top_k, runtime_config);
+  const SampleFunc sample_token = ChooseSampleFunc(runtime_config);

   // Prefill stops before min_prompt_size - 1 because the last prompt
   // token is the first input token for generation.
@@ -25,6 +25,7 @@
 #include "compression/io.h"  // Path
 #include "gemma/activations.h"
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "gemma/kv_cache.h"
 #include "gemma/tokenizer.h"
 #include "gemma/weights.h"
@@ -102,9 +103,12 @@ struct RuntimeConfig {
   // Max queries per batch (one token from each) during decode.
   size_t decode_qbatch_size = 16;

-  float temperature;  // Temperature for sampling.
+  // Sampling-related parameters.
+  float temperature;     // Temperature for sampling.
+  size_t top_k = kTopK;  // Top-k for sampling.
+  std::mt19937* gen;     // Random number generator used for sampling.
+
   int verbosity;  // Controls verbosity of printed messages.
-  std::mt19937* gen;  // Random number generator used for sampling.

   // Functions operating on the generated tokens.
   StreamFunc stream_token;
@@ -99,7 +99,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
     HWY_ASSERT(image.ReadPPM(args.image_file.path));
     image.Resize();
     RuntimeConfig runtime_config = {
-        .verbosity = app.verbosity, .gen = &gen, .use_spinning = app.spin};
+        .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
     double image_tokens_start = hwy::platform::Now();
     model.GenerateImageTokens(runtime_config, image, image_tokens);
     if (app.verbosity >= 1) {
@@ -172,8 +172,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   }

   TimingInfo timing_info = {.verbosity = app.verbosity};
-  RuntimeConfig runtime_config = {.verbosity = app.verbosity,
-                                  .gen = &gen,
+  RuntimeConfig runtime_config = {.gen = &gen,
+                                  .verbosity = app.verbosity,
                                   .stream_token = stream_token,
                                   .accept_token = accept_token,
                                   .use_spinning = app.spin};
@@ -56,7 +56,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
   HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
   image.Resize();
-  RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
+  RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
   model.GenerateImageTokens(runtime_config, image, image_tokens_);
 }

@@ -64,8 +64,8 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const {
   Gemma& model = *(s_env->GetModel());
   s_env->MutableGen().seed(0x12345678);
   RuntimeConfig runtime_config = {.max_generated_tokens = 512,
-                                  .verbosity = 0,
-                                  .gen = &s_env->MutableGen()};
+                                  .gen = &s_env->MutableGen(),
+                                  .verbosity = 0};
   runtime_config.image_tokens = &image_tokens_;
   size_t abs_pos = 0;
   std::string mutable_prompt = prompt_text;
@@ -220,6 +220,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   size_t decode_qbatch_size;

   float temperature;
+  size_t top_k;
   bool deterministic;
   bool multiturn;
   Path image_file;
@@ -244,6 +245,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             "Decode: max queries per batch.");

     visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
+    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
+            2);
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
@@ -259,6 +262,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;
     runtime_config.temperature = temperature;
+    runtime_config.top_k = top_k;
   }
 };

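End to end, the new flag flows: --top_k is parsed into InferenceArgs::top_k (default 1, per the visitor hunk above), and the assignment above copies it into RuntimeConfig, which ChooseSampleFunc then reads. A sketch of that wiring, assuming the assignments above live in an InferenceArgs::CopyTo helper and that InferenceArgs has the usual argc/argv constructor (both assumptions, not shown in this diff):

    #include "gemma/gemma.h"  // assumed header for RuntimeConfig
    #include "util/app.h"     // assumed header for InferenceArgs

    int main(int argc, char** argv) {
      gcpp::InferenceArgs inference(argc, argv);  // parses --top_k, --temperature, ...
      gcpp::RuntimeConfig runtime_config;
      inference.CopyTo(runtime_config);  // per the hunk above: runtime_config.top_k = top_k;
      return 0;
    }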