diff --git a/evals/benchmark.cc b/evals/benchmark.cc
index f7c614e..4dec9ee 100644
--- a/evals/benchmark.cc
+++ b/evals/benchmark.cc
@@ -74,7 +74,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference(),
+    KVCache kv_cache(gemma.Config(), gemma.Inference(),
                      env.MutableEnv().ctx.allocator);
     float entropy = ComputeCrossEntropy(*env.GetGemma(), num_tokens,
                                         prompt_slice, kv_cache,
diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 689062d..b4803d0 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -51,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
     : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
-  const ModelConfig& config = gemma_.GetModelConfig();
+  const ModelConfig& config = gemma_.Config();
 
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
@@ -141,7 +141,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   // Ensure we have at least one KVCache per query.
   while (kv_caches_.size() < num_queries) {
     kv_caches_.push_back(
-        KVCache(gemma_.GetModelConfig(), gemma_.Inference(), ctx_.allocator));
+        KVCache(gemma_.Config(), gemma_.Inference(), ctx_.allocator));
   }
   const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
 
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index 8f1a238..a8f0dc8 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -73,7 +73,7 @@ class GemmaEnv {
 
   std::vector<int> WrapAndTokenize(std::string& input) const {
     return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
-                                 gemma_.GetModelConfig().wrapping, 0, input);
+                                 gemma_.Config().wrapping, 0, input);
   }
 
   std::string StringFromTokens(const std::vector<int>& tokens) const {
diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc
index f94ea11..320967c 100644
--- a/evals/cross_entropy.cc
+++ b/evals/cross_entropy.cc
@@ -102,7 +102,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
                           MatMulEnv& env, int verbosity) {
 
   const StreamFunc stream_token = [](int, float) { return true; };
-  const int vocab_size = gemma.GetModelConfig().vocab_size;
+  const int vocab_size = gemma.Config().vocab_size;
   float cross_entropy = std::log(vocab_size);  // first token; == -log(1/v_s)
   size_t pos = 1;
 
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 96ff08b..12080f9 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -43,7 +43,7 @@ class GemmaTest : public ::testing::Test {
   static void InitEnv(int argc, char** argv) {
     HWY_ASSERT(s_env == nullptr);  // Should only be called once.
     s_env = new GemmaEnv(argc, argv);
-    const gcpp::ModelConfig& config = s_env->GetGemma()->GetModelConfig();
+    const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
     fprintf(stderr, "Using %s\n", config.Specifier().c_str());
   }
 
@@ -98,7 +98,7 @@ TEST_F(GemmaTest, Batched) {
 
 TEST_F(GemmaTest, Multiturn) {
   const Gemma* model = s_env->GetGemma();
-  const ModelConfig& config = model->GetModelConfig();
+  const ModelConfig& config = model->Config();
   size_t abs_pos = 0;
   std::string response;
   auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
@@ -149,7 +149,7 @@ TEST_F(GemmaTest, Multiturn) {
 
 TEST_F(GemmaTest, CrossEntropySmall) {
   HWY_ASSERT(s_env->GetGemma() != nullptr);
-  const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
+  const ModelConfig& config = s_env->GetGemma()->Config();
   static const char kSmall[] =
       "The capital of Hungary is Budapest which is located in Europe.";
   float entropy = s_env->CrossEntropy(kSmall);
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 8f1a7b3..9f12407 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -54,7 +54,7 @@ int main(int argc, char** argv) {
   gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
   gcpp::MatMulEnv env(ctx);
   gcpp::Gemma gemma(loader, inference, ctx);
-  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
+  gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
   size_t generated = 0;
 
   // Initialize random number generator
@@ -66,7 +66,7 @@ int main(int argc, char** argv) {
   std::string prompt = "Write a greeting to the world.";
   const std::vector<int> tokens =
       gcpp::WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
-                            gemma.GetModelConfig().wrapping, generated, prompt);
+                            gemma.Config().wrapping, generated, prompt);
   const size_t prompt_size = tokens.size();
 
   // This callback function gets invoked every time a token is generated
@@ -74,7 +74,7 @@ int main(int argc, char** argv) {
     ++generated;
     if (generated < prompt_size) {
       // print feedback
-    } else if (!gemma.GetModelConfig().IsEOS(token)) {
+    } else if (!gemma.Config().IsEOS(token)) {
      std::string token_text;
      HWY_ASSERT(gemma.Tokenizer().Decode({token}, &token_text));
      std::cout << token_text << std::flush;
diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp
index 48290e8..7f6e4c2 100644
--- a/examples/simplified_gemma/gemma.hpp
+++ b/examples/simplified_gemma/gemma.hpp
@@ -38,7 +38,7 @@ class SimplifiedGemma {
       : ctx_(UpdateArgs(threading, inference)),
         env_(ctx_),
        gemma_(loader, inference, ctx_),
-        kv_cache_(gemma_.GetModelConfig(), inference, ctx_.allocator) {
+        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
     // Initialize random number generator
     std::random_device rd;
     gen_.seed(rd());
@@ -56,7 +56,7 @@ class SimplifiedGemma {
 
     const std::vector<int> tokens = gcpp::WrapAndTokenize(
         gemma_.Tokenizer(), gemma_.ChatTemplate(),
-        gemma_.GetModelConfig().wrapping, generated, prompt);
+        gemma_.Config().wrapping, generated, prompt);
     const size_t prompt_size = tokens.size();
 
     // This callback function gets invoked every time a token is generated
@@ -64,7 +64,7 @@ class SimplifiedGemma {
       ++generated;
       if (generated < prompt_size) {
         // print feedback
-      } else if (!gemma_.GetModelConfig().IsEOS(token)) {
+      } else if (!gemma_.Config().IsEOS(token)) {
        std::string token_text;
        HWY_ASSERT(gemma_.Tokenizer().Decode({token}, &token_text));
        std::cout << token_text << std::flush;
diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc
index e8329c2..76ebe1e 100644
--- a/gemma/bindings/context.cc
+++ b/gemma/bindings/context.cc
@@ -112,7 +112,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
   LogDebug("Creating initial ConversationData");
   // Create the initial ConversationData object using make_shared
   active_conversation = std::make_shared<ConversationData>(
-      model.GetModelConfig(), inference_args, ctx.allocator);
+      model.Config(), inference_args, ctx.allocator);
 
   LogDebug(
       "Storing initial ConversationData in conversation_cache[\"default\"]");
@@ -150,7 +150,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
     const bool in_prompt = tokens_generated_this_turn < prompt_size;
     const bool first_response_token = tokens_generated_this_turn == prompt_size;
     ++tokens_generated_this_turn;
-    if (in_prompt || model.GetModelConfig().IsEOS(token)) {
+    if (in_prompt || model.Config().IsEOS(token)) {
       return true;
     }
 
@@ -180,7 +180,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
   inference_args.CopyTo(runtime_config);
 
   size_t prefix_end = 0;
-  const ModelConfig& model_config = model.GetModelConfig();
+  const ModelConfig& model_config = model.Config();
 
   // generate
   std::vector<int> prompt;
diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h
index fcf3529..859a644 100644
--- a/gemma/bindings/context.h
+++ b/gemma/bindings/context.h
@@ -180,7 +180,7 @@ class GemmaContext {
       active_conversation->abs_pos = 0;
       // Replace the cache within the current ConversationData object
       active_conversation->kv_cache = std::make_unique<KVCache>(
-          model.GetModelConfig(), inference_args, ctx.allocator);
+          model.Config(), inference_args, ctx.allocator);
       LogDebug(
           (log_prefix + "Successfully rewound to initial state.").c_str());
     } else {
@@ -198,7 +198,7 @@ class GemmaContext {
     LogDebug("Creating new conversation");
     // Create a new ConversationData object using make_shared
     conversation_cache[name] = std::make_shared<ConversationData>(
-        model.GetModelConfig(), inference_args, ctx.allocator);
+        model.Config(), inference_args, ctx.allocator);
 
     return true;
   }
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 43af21e..b9f4127 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -236,8 +236,7 @@ class Gemma {
         ThreadingContext& ctx);
   ~Gemma();
 
-  // TODO: rename to Config()
-  const ModelConfig& GetModelConfig() const { return model_.Config(); }
+  const ModelConfig& Config() const { return model_.Config(); }
   const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
   const WeightsPtrs& Weights() const { return weights_; }
   const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
diff --git a/gemma/run.cc b/gemma/run.cc
index cd72d63..997d06e 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -97,7 +97,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
   size_t prompt_size = 0;
-  const ModelConfig& config = gemma.GetModelConfig();
+  const ModelConfig& config = gemma.Config();
 
   std::mt19937 gen;
   InitGenerator(inference, gen);
@@ -258,7 +258,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
   MatMulEnv env(ctx);
   if (inference.verbosity >= 2) env.print_best = true;
   const Gemma gemma(loader, inference, ctx);
-  KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
+  KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
 
   if (inference.verbosity >= 1) {
     std::string instructions =
@@ -285,7 +285,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
     if (inference.IsInteractive()) {
       std::cout << "\033[2J\033[1;1H"  // clear screen
                 << kAsciiArtBanner << "\n\n";
-      ShowConfig(loader, threading, inference, gemma.GetModelConfig(), ctx);
+      ShowConfig(loader, threading, inference, gemma.Config(), ctx);
       std::cout << "\n" << instructions << "\n";
     }
   }
diff --git a/paligemma/paligemma_helper.cc b/paligemma/paligemma_helper.cc
index 4872553..2c798b9 100644
--- a/paligemma/paligemma_helper.cc
+++ b/paligemma/paligemma_helper.cc
@@ -16,7 +16,7 @@ namespace gcpp {
 void PaliGemmaHelper::InitVit(const std::string& path) {
   HWY_ASSERT(env_->GetGemma() != nullptr);
   const Gemma& gemma = *(env_->GetGemma());
-  const ModelConfig& config = gemma.GetModelConfig();
+  const ModelConfig& config = gemma.Config();
   HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
 
   image_tokens_ = std::make_unique<ImageTokens>(
diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc
index 56f618f..0a7401a 100644
--- a/paligemma/paligemma_test.cc
+++ b/paligemma/paligemma_test.cc
@@ -58,7 +58,7 @@ TEST_F(PaliGemmaTest, QueryObjects) {
   const char* question = "answer en What objects are in the image?";
   // 3B PT/Mix 224, 10B Mix 224
   const char* expected_substring = "Building, Tower";
-  const Model model = s_env->GetGemma()->GetModelConfig().model;
+  const Model model = s_env->GetGemma()->Config().model;
   if (model == Model::PALIGEMMA2_3B_448) {
     expected_substring = "Lake.";
   } else if (model == Model::PALIGEMMA2_10B_224) {
diff --git a/python/gemma_py.cc b/python/gemma_py.cc
index 238b546..9af07b3 100644
--- a/python/gemma_py.cc
+++ b/python/gemma_py.cc
@@ -167,7 +167,7 @@ class GemmaModel {
 
   void SetImage(const py::array_t<float>& image) {
     const gcpp::Gemma& gemma = *env_.GetGemma();
-    const gcpp::ModelConfig& config = gemma.GetModelConfig();
+    const gcpp::ModelConfig& config = gemma.Config();
     if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
         config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
       throw std::invalid_argument("Not a PaliGemma model.");