mirror of https://github.com/google/gemma.cpp.git
Rename GetModelConfig->Config
PiperOrigin-RevId: 788506480
Parent commit: 33fabd4ed1
This commit:   ac0d751d20
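In short: every call site that previously fetched the model configuration via GetModelConfig() now calls the shorter Config() accessor; the returned ModelConfig and its fields are unchanged. A minimal sketch of the new spelling, assembled from the call sites in the hunks below (loader, threading, and inference stand for the parsed LoaderArgs, ThreadingArgs, and InferenceArgs objects; includes are omitted, so this is illustrative rather than a complete program):

    gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
    gcpp::MatMulEnv env(ctx);
    gcpp::Gemma gemma(loader, inference, ctx);
    // Was: gemma.GetModelConfig()
    gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
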
@@ -74,7 +74,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference(),
+    KVCache kv_cache(gemma.Config(), gemma.Inference(),
                      env.MutableEnv().ctx.allocator);
     float entropy =
         ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,

@@ -51,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
     : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
-  const ModelConfig& config = gemma_.GetModelConfig();
+  const ModelConfig& config = gemma_.Config();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
 
@@ -141,7 +141,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   // Ensure we have at least one KVCache per query.
   while (kv_caches_.size() < num_queries) {
     kv_caches_.push_back(
-        KVCache(gemma_.GetModelConfig(), gemma_.Inference(), ctx_.allocator));
+        KVCache(gemma_.Config(), gemma_.Inference(), ctx_.allocator));
   }
   const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
 
@@ -73,7 +73,7 @@ class GemmaEnv {
 
   std::vector<int> WrapAndTokenize(std::string& input) const {
     return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
-                                 gemma_.GetModelConfig().wrapping, 0, input);
+                                 gemma_.Config().wrapping, 0, input);
   }
 
   std::string StringFromTokens(const std::vector<int>& tokens) const {

@@ -102,7 +102,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
                           MatMulEnv& env, int verbosity) {
   const StreamFunc stream_token = [](int, float) { return true; };
 
-  const int vocab_size = gemma.GetModelConfig().vocab_size;
+  const int vocab_size = gemma.Config().vocab_size;
   float cross_entropy = std::log(vocab_size);  // first token; == -log(1/v_s)
   size_t pos = 1;
 
@@ -43,7 +43,7 @@ class GemmaTest : public ::testing::Test {
   static void InitEnv(int argc, char** argv) {
     HWY_ASSERT(s_env == nullptr);  // Should only be called once.
     s_env = new GemmaEnv(argc, argv);
-    const gcpp::ModelConfig& config = s_env->GetGemma()->GetModelConfig();
+    const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
     fprintf(stderr, "Using %s\n", config.Specifier().c_str());
   }
 
@@ -98,7 +98,7 @@ TEST_F(GemmaTest, Batched) {
 
 TEST_F(GemmaTest, Multiturn) {
   const Gemma* model = s_env->GetGemma();
-  const ModelConfig& config = model->GetModelConfig();
+  const ModelConfig& config = model->Config();
   size_t abs_pos = 0;
   std::string response;
   auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {

@@ -149,7 +149,7 @@ TEST_F(GemmaTest, Multiturn) {
 
 TEST_F(GemmaTest, CrossEntropySmall) {
   HWY_ASSERT(s_env->GetGemma() != nullptr);
-  const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
+  const ModelConfig& config = s_env->GetGemma()->Config();
   static const char kSmall[] =
       "The capital of Hungary is Budapest which is located in Europe.";
   float entropy = s_env->CrossEntropy(kSmall);

@@ -54,7 +54,7 @@ int main(int argc, char** argv) {
   gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
   gcpp::MatMulEnv env(ctx);
   gcpp::Gemma gemma(loader, inference, ctx);
-  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
+  gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
   size_t generated = 0;
 
   // Initialize random number generator

@@ -66,7 +66,7 @@ int main(int argc, char** argv) {
   std::string prompt = "Write a greeting to the world.";
   const std::vector<int> tokens =
       gcpp::WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
-                            gemma.GetModelConfig().wrapping, generated, prompt);
+                            gemma.Config().wrapping, generated, prompt);
   const size_t prompt_size = tokens.size();
 
   // This callback function gets invoked every time a token is generated

@@ -74,7 +74,7 @@ int main(int argc, char** argv) {
     ++generated;
    if (generated < prompt_size) {
      // print feedback
-    } else if (!gemma.GetModelConfig().IsEOS(token)) {
+    } else if (!gemma.Config().IsEOS(token)) {
      std::string token_text;
      HWY_ASSERT(gemma.Tokenizer().Decode({token}, &token_text));
      std::cout << token_text << std::flush;

@@ -38,7 +38,7 @@ class SimplifiedGemma {
       : ctx_(UpdateArgs(threading, inference)),
         env_(ctx_),
         gemma_(loader, inference, ctx_),
-        kv_cache_(gemma_.GetModelConfig(), inference, ctx_.allocator) {
+        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
     // Initialize random number generator
     std::random_device rd;
     gen_.seed(rd());

@@ -56,7 +56,7 @@ class SimplifiedGemma {
 
     const std::vector<int> tokens = gcpp::WrapAndTokenize(
         gemma_.Tokenizer(), gemma_.ChatTemplate(),
-        gemma_.GetModelConfig().wrapping, generated, prompt);
+        gemma_.Config().wrapping, generated, prompt);
     const size_t prompt_size = tokens.size();
 
     // This callback function gets invoked every time a token is generated

@@ -64,7 +64,7 @@ class SimplifiedGemma {
       ++generated;
      if (generated < prompt_size) {
        // print feedback
-      } else if (!gemma_.GetModelConfig().IsEOS(token)) {
+      } else if (!gemma_.Config().IsEOS(token)) {
        std::string token_text;
        HWY_ASSERT(gemma_.Tokenizer().Decode({token}, &token_text));
        std::cout << token_text << std::flush;

@@ -112,7 +112,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
   LogDebug("Creating initial ConversationData");
   // Create the initial ConversationData object using make_shared
   active_conversation = std::make_shared<ConversationData>(
-      model.GetModelConfig(), inference_args, ctx.allocator);
+      model.Config(), inference_args, ctx.allocator);
 
   LogDebug(
       "Storing initial ConversationData in conversation_cache[\"default\"]");

@@ -150,7 +150,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
     const bool in_prompt = tokens_generated_this_turn < prompt_size;
     const bool first_response_token = tokens_generated_this_turn == prompt_size;
     ++tokens_generated_this_turn;
-    if (in_prompt || model.GetModelConfig().IsEOS(token)) {
+    if (in_prompt || model.Config().IsEOS(token)) {
       return true;
     }
 
@@ -180,7 +180,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
   inference_args.CopyTo(runtime_config);
   size_t prefix_end = 0;
 
-  const ModelConfig& model_config = model.GetModelConfig();
+  const ModelConfig& model_config = model.Config();
 
   // generate
   std::vector<int> prompt;

@@ -180,7 +180,7 @@ class GemmaContext {
       active_conversation->abs_pos = 0;
       // Replace the cache within the current ConversationData object
       active_conversation->kv_cache = std::make_unique<KVCache>(
-          model.GetModelConfig(), inference_args, ctx.allocator);
+          model.Config(), inference_args, ctx.allocator);
 
       LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
     } else {

@@ -198,7 +198,7 @@ class GemmaContext {
       LogDebug("Creating new conversation");
      // Create a new ConversationData object using make_shared
      conversation_cache[name] = std::make_shared<ConversationData>(
-          model.GetModelConfig(), inference_args, ctx.allocator);
+          model.Config(), inference_args, ctx.allocator);
      return true;
     }
 
@@ -236,8 +236,7 @@ class Gemma {
         ThreadingContext& ctx);
   ~Gemma();
 
-  // TODO: rename to Config()
-  const ModelConfig& GetModelConfig() const { return model_.Config(); }
+  const ModelConfig& Config() const { return model_.Config(); }
   const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
   const WeightsPtrs& Weights() const { return weights_; }
   const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }

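The hunk above is the core of the commit: the old TODO comment is resolved and the accessor is renamed while still forwarding to model_.Config(). Downstream code reads fields and helpers off the returned ModelConfig exactly as before. A short sketch assembled from the call sites in this diff (gemma, token, and prompt are placeholders, not code from the commit):

    const gcpp::ModelConfig& config = gemma.Config();
    const int vocab_size = config.vocab_size;   // as in ComputeCrossEntropy
    const bool at_end = config.IsEOS(token);    // as in the streaming callbacks
    std::vector<int> tokens = gcpp::WrapAndTokenize(
        gemma.Tokenizer(), gemma.ChatTemplate(), config.wrapping, 0, prompt);
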
@@ -97,7 +97,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
   size_t prompt_size = 0;
-  const ModelConfig& config = gemma.GetModelConfig();
+  const ModelConfig& config = gemma.Config();
 
   std::mt19937 gen;
   InitGenerator(inference, gen);

@@ -258,7 +258,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
   MatMulEnv env(ctx);
   if (inference.verbosity >= 2) env.print_best = true;
   const Gemma gemma(loader, inference, ctx);
-  KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
+  KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
 
   if (inference.verbosity >= 1) {
     std::string instructions =

@@ -285,7 +285,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
     if (inference.IsInteractive()) {
      std::cout << "\033[2J\033[1;1H"  // clear screen
                << kAsciiArtBanner << "\n\n";
-      ShowConfig(loader, threading, inference, gemma.GetModelConfig(), ctx);
+      ShowConfig(loader, threading, inference, gemma.Config(), ctx);
      std::cout << "\n" << instructions << "\n";
     }
   }

@@ -16,7 +16,7 @@ namespace gcpp {
 void PaliGemmaHelper::InitVit(const std::string& path) {
   HWY_ASSERT(env_->GetGemma() != nullptr);
   const Gemma& gemma = *(env_->GetGemma());
-  const ModelConfig& config = gemma.GetModelConfig();
+  const ModelConfig& config = gemma.Config();
   HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
 
   image_tokens_ = std::make_unique<ImageTokens>(

@@ -58,7 +58,7 @@ TEST_F(PaliGemmaTest, QueryObjects) {
   const char* question = "answer en What objects are in the image?";
   // 3B PT/Mix 224, 10B Mix 224
   const char* expected_substring = "Building, Tower";
-  const Model model = s_env->GetGemma()->GetModelConfig().model;
+  const Model model = s_env->GetGemma()->Config().model;
   if (model == Model::PALIGEMMA2_3B_448) {
     expected_substring = "Lake.";
   } else if (model == Model::PALIGEMMA2_10B_224) {

@@ -167,7 +167,7 @@ class GemmaModel {
   void SetImage(const py::array_t<float, py::array::c_style |
                                              py::array::forcecast>& image) {
     const gcpp::Gemma& gemma = *env_.GetGemma();
-    const gcpp::ModelConfig& config = gemma.GetModelConfig();
+    const gcpp::ModelConfig& config = gemma.Config();
     if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
         config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
       throw std::invalid_argument("Not a PaliGemma model.");