mirror of https://github.com/google/gemma.cpp.git
Rename GetModelConfig->Config
PiperOrigin-RevId: 788506480
This commit is contained in:
parent
33fabd4ed1
commit
ac0d751d20
|
|
@ -74,7 +74,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
|||
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
||||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||
prompt.begin() + pos + num_tokens);
|
||||
KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference(),
|
||||
KVCache kv_cache(gemma.Config(), gemma.Inference(),
|
||||
env.MutableEnv().ctx.allocator);
|
||||
float entropy =
|
||||
ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
|
|||
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference)
|
||||
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
|
||||
const ModelConfig& config = gemma_.GetModelConfig();
|
||||
const ModelConfig& config = gemma_.Config();
|
||||
// Only allocate one for starters because GenerateBatch might not be called.
|
||||
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
|
||||
|
||||
|
|
@ -141,7 +141,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
// Ensure we have at least one KVCache per query.
|
||||
while (kv_caches_.size() < num_queries) {
|
||||
kv_caches_.push_back(
|
||||
KVCache(gemma_.GetModelConfig(), gemma_.Inference(), ctx_.allocator));
|
||||
KVCache(gemma_.Config(), gemma_.Inference(), ctx_.allocator));
|
||||
}
|
||||
const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
|
||||
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class GemmaEnv {
|
|||
|
||||
std::vector<int> WrapAndTokenize(std::string& input) const {
|
||||
return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
|
||||
gemma_.GetModelConfig().wrapping, 0, input);
|
||||
gemma_.Config().wrapping, 0, input);
|
||||
}
|
||||
|
||||
std::string StringFromTokens(const std::vector<int>& tokens) const {
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
|||
MatMulEnv& env, int verbosity) {
|
||||
const StreamFunc stream_token = [](int, float) { return true; };
|
||||
|
||||
const int vocab_size = gemma.GetModelConfig().vocab_size;
|
||||
const int vocab_size = gemma.Config().vocab_size;
|
||||
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
|
||||
size_t pos = 1;
|
||||
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class GemmaTest : public ::testing::Test {
|
|||
static void InitEnv(int argc, char** argv) {
|
||||
HWY_ASSERT(s_env == nullptr); // Should only be called once.
|
||||
s_env = new GemmaEnv(argc, argv);
|
||||
const gcpp::ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
|
||||
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
|
||||
}
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ TEST_F(GemmaTest, Batched) {
|
|||
|
||||
TEST_F(GemmaTest, Multiturn) {
|
||||
const Gemma* model = s_env->GetGemma();
|
||||
const ModelConfig& config = model->GetModelConfig();
|
||||
const ModelConfig& config = model->Config();
|
||||
size_t abs_pos = 0;
|
||||
std::string response;
|
||||
auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
|
||||
|
|
@ -149,7 +149,7 @@ TEST_F(GemmaTest, Multiturn) {
|
|||
|
||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
const ModelConfig& config = s_env->GetGemma()->Config();
|
||||
static const char kSmall[] =
|
||||
"The capital of Hungary is Budapest which is located in Europe.";
|
||||
float entropy = s_env->CrossEntropy(kSmall);
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ int main(int argc, char** argv) {
|
|||
gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
|
||||
gcpp::MatMulEnv env(ctx);
|
||||
gcpp::Gemma gemma(loader, inference, ctx);
|
||||
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
|
||||
gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
|
||||
size_t generated = 0;
|
||||
|
||||
// Initialize random number generator
|
||||
|
|
@ -66,7 +66,7 @@ int main(int argc, char** argv) {
|
|||
std::string prompt = "Write a greeting to the world.";
|
||||
const std::vector<int> tokens =
|
||||
gcpp::WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
|
||||
gemma.GetModelConfig().wrapping, generated, prompt);
|
||||
gemma.Config().wrapping, generated, prompt);
|
||||
const size_t prompt_size = tokens.size();
|
||||
|
||||
// This callback function gets invoked every time a token is generated
|
||||
|
|
@ -74,7 +74,7 @@ int main(int argc, char** argv) {
|
|||
++generated;
|
||||
if (generated < prompt_size) {
|
||||
// print feedback
|
||||
} else if (!gemma.GetModelConfig().IsEOS(token)) {
|
||||
} else if (!gemma.Config().IsEOS(token)) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(gemma.Tokenizer().Decode({token}, &token_text));
|
||||
std::cout << token_text << std::flush;
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class SimplifiedGemma {
|
|||
: ctx_(UpdateArgs(threading, inference)),
|
||||
env_(ctx_),
|
||||
gemma_(loader, inference, ctx_),
|
||||
kv_cache_(gemma_.GetModelConfig(), inference, ctx_.allocator) {
|
||||
kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
|
||||
// Initialize random number generator
|
||||
std::random_device rd;
|
||||
gen_.seed(rd());
|
||||
|
|
@ -56,7 +56,7 @@ class SimplifiedGemma {
|
|||
|
||||
const std::vector<int> tokens = gcpp::WrapAndTokenize(
|
||||
gemma_.Tokenizer(), gemma_.ChatTemplate(),
|
||||
gemma_.GetModelConfig().wrapping, generated, prompt);
|
||||
gemma_.Config().wrapping, generated, prompt);
|
||||
const size_t prompt_size = tokens.size();
|
||||
|
||||
// This callback function gets invoked every time a token is generated
|
||||
|
|
@ -64,7 +64,7 @@ class SimplifiedGemma {
|
|||
++generated;
|
||||
if (generated < prompt_size) {
|
||||
// print feedback
|
||||
} else if (!gemma_.GetModelConfig().IsEOS(token)) {
|
||||
} else if (!gemma_.Config().IsEOS(token)) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(gemma_.Tokenizer().Decode({token}, &token_text));
|
||||
std::cout << token_text << std::flush;
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
|
|||
LogDebug("Creating initial ConversationData");
|
||||
// Create the initial ConversationData object using make_shared
|
||||
active_conversation = std::make_shared<ConversationData>(
|
||||
model.GetModelConfig(), inference_args, ctx.allocator);
|
||||
model.Config(), inference_args, ctx.allocator);
|
||||
|
||||
LogDebug(
|
||||
"Storing initial ConversationData in conversation_cache[\"default\"]");
|
||||
|
|
@ -150,7 +150,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
const bool in_prompt = tokens_generated_this_turn < prompt_size;
|
||||
const bool first_response_token = tokens_generated_this_turn == prompt_size;
|
||||
++tokens_generated_this_turn;
|
||||
if (in_prompt || model.GetModelConfig().IsEOS(token)) {
|
||||
if (in_prompt || model.Config().IsEOS(token)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -180,7 +180,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
inference_args.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
|
||||
const ModelConfig& model_config = model.GetModelConfig();
|
||||
const ModelConfig& model_config = model.Config();
|
||||
|
||||
// generate
|
||||
std::vector<int> prompt;
|
||||
|
|
|
|||
|
|
@ -180,7 +180,7 @@ class GemmaContext {
|
|||
active_conversation->abs_pos = 0;
|
||||
// Replace the cache within the current ConversationData object
|
||||
active_conversation->kv_cache = std::make_unique<KVCache>(
|
||||
model.GetModelConfig(), inference_args, ctx.allocator);
|
||||
model.Config(), inference_args, ctx.allocator);
|
||||
|
||||
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
|
||||
} else {
|
||||
|
|
@ -198,7 +198,7 @@ class GemmaContext {
|
|||
LogDebug("Creating new conversation");
|
||||
// Create a new ConversationData object using make_shared
|
||||
conversation_cache[name] = std::make_shared<ConversationData>(
|
||||
model.GetModelConfig(), inference_args, ctx.allocator);
|
||||
model.Config(), inference_args, ctx.allocator);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -236,8 +236,7 @@ class Gemma {
|
|||
ThreadingContext& ctx);
|
||||
~Gemma();
|
||||
|
||||
// TODO: rename to Config()
|
||||
const ModelConfig& GetModelConfig() const { return model_.Config(); }
|
||||
const ModelConfig& Config() const { return model_.Config(); }
|
||||
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
|
||||
const WeightsPtrs& Weights() const { return weights_; }
|
||||
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
size_t abs_pos = 0; // across turns
|
||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||
size_t prompt_size = 0;
|
||||
const ModelConfig& config = gemma.GetModelConfig();
|
||||
const ModelConfig& config = gemma.Config();
|
||||
|
||||
std::mt19937 gen;
|
||||
InitGenerator(inference, gen);
|
||||
|
|
@ -258,7 +258,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
|
|||
MatMulEnv env(ctx);
|
||||
if (inference.verbosity >= 2) env.print_best = true;
|
||||
const Gemma gemma(loader, inference, ctx);
|
||||
KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
|
||||
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
|
||||
|
||||
if (inference.verbosity >= 1) {
|
||||
std::string instructions =
|
||||
|
|
@ -285,7 +285,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
|
|||
if (inference.IsInteractive()) {
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< kAsciiArtBanner << "\n\n";
|
||||
ShowConfig(loader, threading, inference, gemma.GetModelConfig(), ctx);
|
||||
ShowConfig(loader, threading, inference, gemma.Config(), ctx);
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ namespace gcpp {
|
|||
void PaliGemmaHelper::InitVit(const std::string& path) {
|
||||
HWY_ASSERT(env_->GetGemma() != nullptr);
|
||||
const Gemma& gemma = *(env_->GetGemma());
|
||||
const ModelConfig& config = gemma.GetModelConfig();
|
||||
const ModelConfig& config = gemma.Config();
|
||||
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
|
||||
|
||||
image_tokens_ = std::make_unique<ImageTokens>(
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ TEST_F(PaliGemmaTest, QueryObjects) {
|
|||
const char* question = "answer en What objects are in the image?";
|
||||
// 3B PT/Mix 224, 10B Mix 224
|
||||
const char* expected_substring = "Building, Tower";
|
||||
const Model model = s_env->GetGemma()->GetModelConfig().model;
|
||||
const Model model = s_env->GetGemma()->Config().model;
|
||||
if (model == Model::PALIGEMMA2_3B_448) {
|
||||
expected_substring = "Lake.";
|
||||
} else if (model == Model::PALIGEMMA2_10B_224) {
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ class GemmaModel {
|
|||
void SetImage(const py::array_t<float, py::array::c_style |
|
||||
py::array::forcecast>& image) {
|
||||
const gcpp::Gemma& gemma = *env_.GetGemma();
|
||||
const gcpp::ModelConfig& config = gemma.GetModelConfig();
|
||||
const gcpp::ModelConfig& config = gemma.Config();
|
||||
if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
|
||||
config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
|
||||
throw std::invalid_argument("Not a PaliGemma model.");
|
||||
|
|
|
|||
Loading…
Reference in New Issue