Rename GetModelConfig->Config

PiperOrigin-RevId: 788506480
This commit is contained in:
Jan Wassenberg 2025-07-29 10:17:14 -07:00 committed by Copybara-Service
parent 33fabd4ed1
commit ac0d751d20
14 changed files with 26 additions and 27 deletions

View File

@ -74,7 +74,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens); size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos, std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens); prompt.begin() + pos + num_tokens);
KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference(), KVCache kv_cache(gemma.Config(), gemma.Inference(),
env.MutableEnv().ctx.allocator); env.MutableEnv().ctx.allocator);
float entropy = float entropy =
ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache, ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,

View File

@ -51,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) const InferenceArgs& inference)
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) { : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
const ModelConfig& config = gemma_.GetModelConfig(); const ModelConfig& config = gemma_.Config();
// Only allocate one for starters because GenerateBatch might not be called. // Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator)); kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
@ -141,7 +141,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
// Ensure we have at least one KVCache per query. // Ensure we have at least one KVCache per query.
while (kv_caches_.size() < num_queries) { while (kv_caches_.size() < num_queries) {
kv_caches_.push_back( kv_caches_.push_back(
KVCache(gemma_.GetModelConfig(), gemma_.Inference(), ctx_.allocator)); KVCache(gemma_.Config(), gemma_.Inference(), ctx_.allocator));
} }
const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries); const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);

View File

@ -73,7 +73,7 @@ class GemmaEnv {
std::vector<int> WrapAndTokenize(std::string& input) const { std::vector<int> WrapAndTokenize(std::string& input) const {
return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(), return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
gemma_.GetModelConfig().wrapping, 0, input); gemma_.Config().wrapping, 0, input);
} }
std::string StringFromTokens(const std::vector<int>& tokens) const { std::string StringFromTokens(const std::vector<int>& tokens) const {

View File

@ -102,7 +102,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
MatMulEnv& env, int verbosity) { MatMulEnv& env, int verbosity) {
const StreamFunc stream_token = [](int, float) { return true; }; const StreamFunc stream_token = [](int, float) { return true; };
const int vocab_size = gemma.GetModelConfig().vocab_size; const int vocab_size = gemma.Config().vocab_size;
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s) float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
size_t pos = 1; size_t pos = 1;

View File

@ -43,7 +43,7 @@ class GemmaTest : public ::testing::Test {
static void InitEnv(int argc, char** argv) { static void InitEnv(int argc, char** argv) {
HWY_ASSERT(s_env == nullptr); // Should only be called once. HWY_ASSERT(s_env == nullptr); // Should only be called once.
s_env = new GemmaEnv(argc, argv); s_env = new GemmaEnv(argc, argv);
const gcpp::ModelConfig& config = s_env->GetGemma()->GetModelConfig(); const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
fprintf(stderr, "Using %s\n", config.Specifier().c_str()); fprintf(stderr, "Using %s\n", config.Specifier().c_str());
} }
@ -98,7 +98,7 @@ TEST_F(GemmaTest, Batched) {
TEST_F(GemmaTest, Multiturn) { TEST_F(GemmaTest, Multiturn) {
const Gemma* model = s_env->GetGemma(); const Gemma* model = s_env->GetGemma();
const ModelConfig& config = model->GetModelConfig(); const ModelConfig& config = model->Config();
size_t abs_pos = 0; size_t abs_pos = 0;
std::string response; std::string response;
auto stream_token = [&](size_t query_idx, size_t pos, int token, float) { auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
@ -149,7 +149,7 @@ TEST_F(GemmaTest, Multiturn) {
TEST_F(GemmaTest, CrossEntropySmall) { TEST_F(GemmaTest, CrossEntropySmall) {
HWY_ASSERT(s_env->GetGemma() != nullptr); HWY_ASSERT(s_env->GetGemma() != nullptr);
const ModelConfig& config = s_env->GetGemma()->GetModelConfig(); const ModelConfig& config = s_env->GetGemma()->Config();
static const char kSmall[] = static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe."; "The capital of Hungary is Budapest which is located in Europe.";
float entropy = s_env->CrossEntropy(kSmall); float entropy = s_env->CrossEntropy(kSmall);

View File

@ -54,7 +54,7 @@ int main(int argc, char** argv) {
gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference)); gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
gcpp::MatMulEnv env(ctx); gcpp::MatMulEnv env(ctx);
gcpp::Gemma gemma(loader, inference, ctx); gcpp::Gemma gemma(loader, inference, ctx);
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator); gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
size_t generated = 0; size_t generated = 0;
// Initialize random number generator // Initialize random number generator
@ -66,7 +66,7 @@ int main(int argc, char** argv) {
std::string prompt = "Write a greeting to the world."; std::string prompt = "Write a greeting to the world.";
const std::vector<int> tokens = const std::vector<int> tokens =
gcpp::WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(), gcpp::WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
gemma.GetModelConfig().wrapping, generated, prompt); gemma.Config().wrapping, generated, prompt);
const size_t prompt_size = tokens.size(); const size_t prompt_size = tokens.size();
// This callback function gets invoked every time a token is generated // This callback function gets invoked every time a token is generated
@ -74,7 +74,7 @@ int main(int argc, char** argv) {
++generated; ++generated;
if (generated < prompt_size) { if (generated < prompt_size) {
// print feedback // print feedback
} else if (!gemma.GetModelConfig().IsEOS(token)) { } else if (!gemma.Config().IsEOS(token)) {
std::string token_text; std::string token_text;
HWY_ASSERT(gemma.Tokenizer().Decode({token}, &token_text)); HWY_ASSERT(gemma.Tokenizer().Decode({token}, &token_text));
std::cout << token_text << std::flush; std::cout << token_text << std::flush;

View File

@ -38,7 +38,7 @@ class SimplifiedGemma {
: ctx_(UpdateArgs(threading, inference)), : ctx_(UpdateArgs(threading, inference)),
env_(ctx_), env_(ctx_),
gemma_(loader, inference, ctx_), gemma_(loader, inference, ctx_),
kv_cache_(gemma_.GetModelConfig(), inference, ctx_.allocator) { kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
// Initialize random number generator // Initialize random number generator
std::random_device rd; std::random_device rd;
gen_.seed(rd()); gen_.seed(rd());
@ -56,7 +56,7 @@ class SimplifiedGemma {
const std::vector<int> tokens = gcpp::WrapAndTokenize( const std::vector<int> tokens = gcpp::WrapAndTokenize(
gemma_.Tokenizer(), gemma_.ChatTemplate(), gemma_.Tokenizer(), gemma_.ChatTemplate(),
gemma_.GetModelConfig().wrapping, generated, prompt); gemma_.Config().wrapping, generated, prompt);
const size_t prompt_size = tokens.size(); const size_t prompt_size = tokens.size();
// This callback function gets invoked every time a token is generated // This callback function gets invoked every time a token is generated
@ -64,7 +64,7 @@ class SimplifiedGemma {
++generated; ++generated;
if (generated < prompt_size) { if (generated < prompt_size) {
// print feedback // print feedback
} else if (!gemma_.GetModelConfig().IsEOS(token)) { } else if (!gemma_.Config().IsEOS(token)) {
std::string token_text; std::string token_text;
HWY_ASSERT(gemma_.Tokenizer().Decode({token}, &token_text)); HWY_ASSERT(gemma_.Tokenizer().Decode({token}, &token_text));
std::cout << token_text << std::flush; std::cout << token_text << std::flush;

View File

@ -112,7 +112,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
LogDebug("Creating initial ConversationData"); LogDebug("Creating initial ConversationData");
// Create the initial ConversationData object using make_shared // Create the initial ConversationData object using make_shared
active_conversation = std::make_shared<ConversationData>( active_conversation = std::make_shared<ConversationData>(
model.GetModelConfig(), inference_args, ctx.allocator); model.Config(), inference_args, ctx.allocator);
LogDebug( LogDebug(
"Storing initial ConversationData in conversation_cache[\"default\"]"); "Storing initial ConversationData in conversation_cache[\"default\"]");
@ -150,7 +150,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
const bool in_prompt = tokens_generated_this_turn < prompt_size; const bool in_prompt = tokens_generated_this_turn < prompt_size;
const bool first_response_token = tokens_generated_this_turn == prompt_size; const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn; ++tokens_generated_this_turn;
if (in_prompt || model.GetModelConfig().IsEOS(token)) { if (in_prompt || model.Config().IsEOS(token)) {
return true; return true;
} }
@ -180,7 +180,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
inference_args.CopyTo(runtime_config); inference_args.CopyTo(runtime_config);
size_t prefix_end = 0; size_t prefix_end = 0;
const ModelConfig& model_config = model.GetModelConfig(); const ModelConfig& model_config = model.Config();
// generate // generate
std::vector<int> prompt; std::vector<int> prompt;

View File

@ -180,7 +180,7 @@ class GemmaContext {
active_conversation->abs_pos = 0; active_conversation->abs_pos = 0;
// Replace the cache within the current ConversationData object // Replace the cache within the current ConversationData object
active_conversation->kv_cache = std::make_unique<KVCache>( active_conversation->kv_cache = std::make_unique<KVCache>(
model.GetModelConfig(), inference_args, ctx.allocator); model.Config(), inference_args, ctx.allocator);
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str()); LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
} else { } else {
@ -198,7 +198,7 @@ class GemmaContext {
LogDebug("Creating new conversation"); LogDebug("Creating new conversation");
// Create a new ConversationData object using make_shared // Create a new ConversationData object using make_shared
conversation_cache[name] = std::make_shared<ConversationData>( conversation_cache[name] = std::make_shared<ConversationData>(
model.GetModelConfig(), inference_args, ctx.allocator); model.Config(), inference_args, ctx.allocator);
return true; return true;
} }

View File

@ -236,8 +236,7 @@ class Gemma {
ThreadingContext& ctx); ThreadingContext& ctx);
~Gemma(); ~Gemma();
// TODO: rename to Config() const ModelConfig& Config() const { return model_.Config(); }
const ModelConfig& GetModelConfig() const { return model_.Config(); }
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); } const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
const WeightsPtrs& Weights() const { return weights_; } const WeightsPtrs& Weights() const { return weights_; }
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; } const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }

View File

@ -97,7 +97,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
size_t abs_pos = 0; // across turns size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0; size_t prompt_size = 0;
const ModelConfig& config = gemma.GetModelConfig(); const ModelConfig& config = gemma.Config();
std::mt19937 gen; std::mt19937 gen;
InitGenerator(inference, gen); InitGenerator(inference, gen);
@ -258,7 +258,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
MatMulEnv env(ctx); MatMulEnv env(ctx);
if (inference.verbosity >= 2) env.print_best = true; if (inference.verbosity >= 2) env.print_best = true;
const Gemma gemma(loader, inference, ctx); const Gemma gemma(loader, inference, ctx);
KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator); KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
if (inference.verbosity >= 1) { if (inference.verbosity >= 1) {
std::string instructions = std::string instructions =
@ -285,7 +285,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
if (inference.IsInteractive()) { if (inference.IsInteractive()) {
std::cout << "\033[2J\033[1;1H" // clear screen std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n"; << kAsciiArtBanner << "\n\n";
ShowConfig(loader, threading, inference, gemma.GetModelConfig(), ctx); ShowConfig(loader, threading, inference, gemma.Config(), ctx);
std::cout << "\n" << instructions << "\n"; std::cout << "\n" << instructions << "\n";
} }
} }

View File

@ -16,7 +16,7 @@ namespace gcpp {
void PaliGemmaHelper::InitVit(const std::string& path) { void PaliGemmaHelper::InitVit(const std::string& path) {
HWY_ASSERT(env_->GetGemma() != nullptr); HWY_ASSERT(env_->GetGemma() != nullptr);
const Gemma& gemma = *(env_->GetGemma()); const Gemma& gemma = *(env_->GetGemma());
const ModelConfig& config = gemma.GetModelConfig(); const ModelConfig& config = gemma.Config();
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA); HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
image_tokens_ = std::make_unique<ImageTokens>( image_tokens_ = std::make_unique<ImageTokens>(

View File

@ -58,7 +58,7 @@ TEST_F(PaliGemmaTest, QueryObjects) {
const char* question = "answer en What objects are in the image?"; const char* question = "answer en What objects are in the image?";
// 3B PT/Mix 224, 10B Mix 224 // 3B PT/Mix 224, 10B Mix 224
const char* expected_substring = "Building, Tower"; const char* expected_substring = "Building, Tower";
const Model model = s_env->GetGemma()->GetModelConfig().model; const Model model = s_env->GetGemma()->Config().model;
if (model == Model::PALIGEMMA2_3B_448) { if (model == Model::PALIGEMMA2_3B_448) {
expected_substring = "Lake."; expected_substring = "Lake.";
} else if (model == Model::PALIGEMMA2_10B_224) { } else if (model == Model::PALIGEMMA2_10B_224) {

View File

@ -167,7 +167,7 @@ class GemmaModel {
void SetImage(const py::array_t<float, py::array::c_style | void SetImage(const py::array_t<float, py::array::c_style |
py::array::forcecast>& image) { py::array::forcecast>& image) {
const gcpp::Gemma& gemma = *env_.GetGemma(); const gcpp::Gemma& gemma = *env_.GetGemma();
const gcpp::ModelConfig& config = gemma.GetModelConfig(); const gcpp::ModelConfig& config = gemma.Config();
if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA && if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) { config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
throw std::invalid_argument("Not a PaliGemma model."); throw std::invalid_argument("Not a PaliGemma model.");