diff --git a/BUILD.bazel b/BUILD.bazel
index c37dd17..dbf64e3 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -551,6 +551,7 @@ cc_library(
         "//compression:compress",
         "@highway//:hwy",
         "@highway//:nanobenchmark",
+        "@highway//:profiler",
     ],
 )
 
diff --git a/evals/benchmark.cc b/evals/benchmark.cc
index aceee59..738070a 100644
--- a/evals/benchmark.cc
+++ b/evals/benchmark.cc
@@ -75,8 +75,9 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
     KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
-    float entropy = ComputeCrossEntropy(
-        *env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
+    float entropy =
+        ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
+                            env.MutableEnv(), env.Verbosity());
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
     std::string text_slice = env.StringFromTokens(prompt_slice);
diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 902bc87..98a4761 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -32,6 +32,7 @@
 #include "util/threading_context.h"
 #include "hwy/highway.h"
 #include "hwy/per_target.h"  // DispatchedTarget
+#include "hwy/profiler.h"  // PROFILER_ENABLED
 #include "hwy/timer.h"
 
 namespace gcpp {
@@ -50,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
     : env_(MakeMatMulEnv(threading, inference)),
-      gemma_(loader, inference, env_) {
+      gemma_(loader, inference, env_.ctx.pools) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference));
@@ -94,7 +95,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
   }
   gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
   runtime_config_.batch_stream_token = batch_stream_token;
-  gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
+  gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
                   timing_info);
   return result;
 }
@@ -104,7 +105,7 @@ void GemmaEnv::QueryModel(
   gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
   const StreamFunc previous_stream_token = runtime_config_.stream_token;
   runtime_config_.stream_token = stream_token;
-  gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
+  gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
                   timing_info);
   runtime_config_.stream_token = previous_stream_token;
 }
@@ -146,7 +147,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
 
   gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
   gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
-  gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
+  gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
 
   return res;
 }
@@ -176,7 +177,7 @@ float GemmaEnv::CrossEntropy(const std::string& input) {
   std::vector<int> prompt = Tokenize(input);
   prompt.insert(prompt.begin(), BOS_ID);
   return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
-                             MutableKVCache(),
+                             MutableKVCache(), env_,
                              /*verbosity=*/0) /
          static_cast(input.size());
 }
@@ -247,13 +248,13 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
             "CPU : %s, bind %d\n"
             "CPU topology : %s, %s, %s\n"
             "Instruction set : %s (%zu bits)\n"
-            "Compiled config : %s\n"
-            "Memory MiB : %4zu, %4zu free\n",
+            "Compiled config : %s, profiler %d\n"
+            "Memory MiB : %4zu\n",
            dt, cpu100, static_cast<int>(threading.bind),
            ctx.topology.TopologyString(), ctx.pools.PinString(),
            CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
-           ctx.allocator.VectorBytes() * 8, CompiledConfig(),
-           ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB());
+           ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED,
+           ctx.allocator.TotalMiB());
   }
 }
 
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index 870ad02..176267e 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -112,6 +112,7 @@ class GemmaEnv {
   RuntimeConfig& MutableConfig() { return runtime_config_; }
   std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return kv_caches_[0]; }
+  MatMulEnv& MutableEnv() { return env_; }
 
  private:
   MatMulEnv env_;
diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc
index e1a6ff4..f94ea11 100644
--- a/evals/cross_entropy.cc
+++ b/evals/cross_entropy.cc
@@ -99,7 +99,7 @@ HWY_EXPORT(CallSoftmax);
 
 float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          int verbosity) {
+                          MatMulEnv& env, int verbosity) {
   const StreamFunc stream_token = [](int, float) { return true; };
 
   const int vocab_size = gemma.GetModelConfig().vocab_size;
@@ -145,7 +145,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
   };
 
   TimingInfo timing_info;
-  gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
+  gemma.Generate(runtime, prompt0, 0, kv_cache, env, timing_info);
 
   const float scale = 1.0f / std::log(2.0f);
   return cross_entropy * scale;
diff --git a/evals/cross_entropy.h b/evals/cross_entropy.h
index 0b4479e..0a143cc 100644
--- a/evals/cross_entropy.h
+++ b/evals/cross_entropy.h
@@ -26,7 +26,7 @@ namespace gcpp {
 
 float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          int verbosity);
+                          MatMulEnv& env, int verbosity);
 
 }  // namespace gcpp
 
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index a22ba32..96ff08b 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -127,7 +127,7 @@ TEST_F(GemmaTest, Multiturn) {
                                        config.wrapping, abs_pos, mutable_prompt);
 
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
-                  timing_info);
+                  s_env->MutableEnv(), timing_info);
   // Note: we do not rewind any tokens here. If the model
   // produced one and WrapAndTokenize() inserts another one, it will just be
   // duplicated.
@@ -139,7 +139,7 @@ TEST_F(GemmaTest, Multiturn) {
   // access to the previous turn by asking to reproduce.
   response.clear();
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
-                  timing_info);
+                  s_env->MutableEnv(), timing_info);
   fprintf(stderr, "decoded: '%s'\n", response.c_str());
   bool remembered_turquoise =
       response.find("turquoise") != std::string::npos;  // NOLINT
diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc
index 9cffc41..b6537fe 100644
--- a/evals/run_mmlu.cc
+++ b/evals/run_mmlu.cc
@@ -131,7 +131,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
         .stream_token = stream_token,
     };
     env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
-                             env.MutableKVCache(), timing_info);
+                             env.MutableKVCache(), env.MutableEnv(),
+                             timing_info);
 
     std::string output_string = env.StringFromTokens(predicted_token_ids);
     fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index e5b57da..e4cfcd5 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -52,7 +52,7 @@ int main(int argc, char** argv) {
 
   // Instantiate model and KV Cache
   gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
-  gcpp::Gemma gemma(loader, inference, env);
+  gcpp::Gemma gemma(loader, inference, env.ctx.pools);
   gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
 
   size_t generated = 0;
@@ -93,5 +93,5 @@ int main(int argc, char** argv) {
       return !reject_tokens.contains(token);
     },
   };
-  gemma.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
+  gemma.Generate(runtime_config, tokens, 0, kv_cache, env, timing_info);
 }
diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp
index 2372591..b5eab41 100644
--- a/examples/simplified_gemma/gemma.hpp
+++ b/examples/simplified_gemma/gemma.hpp
@@ -36,7 +36,7 @@ class SimplifiedGemma {
                    const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                    const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
       : env_(MakeMatMulEnv(threading, inference)),
-        gemma_(loader, inference, env_),
+        gemma_(loader, inference, env_.ctx.pools),
         kv_cache_(gemma_.GetModelConfig(), inference) {
     // Initialize random number generator
     std::random_device rd;
@@ -83,7 +83,7 @@ class SimplifiedGemma {
           return !reject_tokens.contains(token);
         },
     };
-    gemma_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
+    gemma_.Generate(runtime_config, tokens, 0, kv_cache_, env_, timing_info);
   }
 
   ~SimplifiedGemma() = default;
diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc
index 89fb650..e540bb4 100644
--- a/gemma/bindings/context.cc
+++ b/gemma/bindings/context.cc
@@ -103,7 +103,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
       threading_args(threading_args),
       matmul_env(MakeMatMulEnv(threading_args, inference_args)),
       active_conversation_name("default"),
-      model(loader, inference_args, matmul_env) {
+      model(loader, inference_args, matmul_env.ctx.pools) {
   std::stringstream ss;
 
   LogDebug("Creating initial ConversationData");
@@ -207,7 +207,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
     // Pass the populated image object to GenerateImageTokens
     model.GenerateImageTokens(runtime_config,
                               active_conversation->kv_cache->SeqLen(), image,
-                              image_tokens);
+                              image_tokens, matmul_env);
 
     double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
     ss.str("");
@@ -244,7 +244,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
 
     // Pass the KVCache object by reference from the active conversation
     model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
-                   prefix_end, *(active_conversation->kv_cache), timing_info);
+                   prefix_end, *active_conversation->kv_cache, matmul_env,
+                   timing_info);
 
     // prepare for next turn
     if (!inference_args.multiturn ||
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 1f21ec0..79ba7cf 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -610,62 +610,62 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
 }
 
 Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-             MatMulEnv& env)
-    : env_(env),
-      reader_(loader.weights),
+             NestedPools& pools)
+    : reader_(loader.weights),
       model_(reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model),
       inference_(inference) {
   weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
-                         env.ctx.pools.Pool());
+                         pools.Pool());
   reader_.CloseFile();
 }
 
 Gemma::~Gemma() = default;
 
-void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
+void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
   BlobWriter writer;
   const std::vector serialized_mat_ptrs =
       weights_.AddTensorDataToWriter(writer);
   WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
-                  writer, env_.ctx.pools.Pool(), weights_path);
+                  writer, pools.Pool(), weights_path);
 }
 
 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, size_t prefix_end,
-                     KVCache& kv_cache, TimingInfo& timing_info) const {
-  env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
+                     KVCache& kv_cache, MatMulEnv& env,
+                     TimingInfo& timing_info) const {
+  env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
   HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
                                         model_.Config(), runtime_config,
-                                        weights_, kv_cache, env_, timing_info);
+                                        weights_, kv_cache, env, timing_info);
 
-  env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
+  env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 
 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
-                          AllQueries& all_queries,
+                          AllQueries& all_queries, MatMulEnv& env,
                           TimingInfo& timing_info) const {
-  env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
+  env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
   HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
-                                       weights_, all_queries, env_,
-                                       timing_info);
+                                       weights_, all_queries, env, timing_info);
 
-  env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
+  env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 
 void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
                                 size_t seq_len, const Image& image,
-                                ImageTokens& image_tokens) const {
-  env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
+                                ImageTokens& image_tokens,
+                                MatMulEnv& env) const {
+  env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
   HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
                                              seq_len, weights_, image,
-                                             image_tokens, env_);
+                                             image_tokens, env);
 
-  env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
+  env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 
 }  // namespace gcpp
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 423133d..27ce523 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -229,16 +229,16 @@ struct TimingInfo {
 MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
                         const InferenceArgs& inference_args);
 
+// After construction, all methods are const and thread-compatible if using
+// separate MatMulEnv for each thread.
 class Gemma {
  public:
   // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
-  // `env` must remain valid for the lifetime of this Gemma.
+  // `pools` are used to parallelize loading.
   Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-        MatMulEnv& env);
-
+        NestedPools& pools);
   ~Gemma();
 
-  MatMulEnv& Env() const { return env_; }
   // TODO: rename to Config()
   const ModelConfig& GetModelConfig() const { return model_.Config(); }
   const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
@@ -246,29 +246,31 @@ class Gemma {
   const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const InferenceArgs& Inference() const { return inference_; }
 
-  void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
+  void Save(const Path& weights_path, NestedPools& pools) const;
 
   // `pos` is the position in the KV cache. Users are responsible for
   // incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
   void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
-                size_t pos, KVCache& kv_cache, TimingInfo& timing_info) const {
-    Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
+                size_t pos, KVCache& kv_cache, MatMulEnv& env,
+                TimingInfo& timing_info) const {
+    Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache, env,
              timing_info);
   }
 
   // For prefix-LM style attention, we can pass the end of the prefix.
   void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
                 size_t pos, size_t prefix_end, KVCache& kv_cache,
-                TimingInfo& timing_info) const;
+                MatMulEnv& env, TimingInfo& timing_info) const;
   void GenerateBatch(const RuntimeConfig& runtime_config,
-                     AllQueries& all_queries, TimingInfo& timing_info) const;
+                     AllQueries& all_queries, MatMulEnv& env,
+                     TimingInfo& timing_info) const;
 
   // Generates the image tokens by running the image encoder ViT.
   void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
-                           const Image& image, ImageTokens& image_tokens) const;
+                           const Image& image, ImageTokens& image_tokens,
+                           MatMulEnv& env) const;
 
  private:
-  MatMulEnv& env_;
   BlobReader reader_;
   ModelStore model_;
   std::vector mat_owners_;
diff --git a/gemma/run.cc b/gemma/run.cc
index 071a5a5..5a2ba58 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -92,7 +92,7 @@ std::string GetPrompt(const InferenceArgs& inference) {
 
 // The main Read-Eval-Print Loop.
 void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
-               const Gemma& gemma, KVCache& kv_cache) {
+               const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
@@ -111,7 +111,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                           config.model_dim)
                             : Extents2D(0, 0),
                         MatPadding::kOdd);
-  image_tokens.AllocateAndAttachRowPtrs(gemma.Env().row_ptrs);
+  image_tokens.AllocateAndAttachRowPtrs(env.row_ptrs);
   if (have_image) {
     HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
                config.wrapping == PromptWrapping::GEMMA_VLM);
@@ -123,7 +123,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                      .use_spinning = threading.spin};
     double image_tokens_start = hwy::platform::Now();
     gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
-                              image_tokens);
+                              image_tokens, env);
     if (inference.verbosity >= 1) {
       double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
       fprintf(stderr,
@@ -224,7 +224,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     if (inference.verbosity >= 1) {
       std::cerr << "\n[ Reading prompt ] " << std::flush;
     }
-    gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
+    gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
                    timing_info);
 
     std::cout << "\n\n";
@@ -256,7 +256,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
   MatMulEnv env(MakeMatMulEnv(threading, inference));
   if (inference.verbosity >= 2) env.print_best = true;
 
-  const Gemma gemma(loader, inference, env);
+  const Gemma gemma(loader, inference, env.ctx.pools);
   KVCache kv_cache(gemma.GetModelConfig(), inference);
 
   if (inference.verbosity >= 1) {
@@ -289,7 +289,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
     }
   }
 
-  ReplGemma(threading, inference, gemma, kv_cache);
+  ReplGemma(threading, inference, gemma, kv_cache, env);
 }
 
 }  // namespace gcpp
diff --git a/io/migrate_weights.cc b/io/migrate_weights.cc
index 7588326..aa500bb 100644
--- a/io/migrate_weights.cc
+++ b/io/migrate_weights.cc
@@ -44,6 +44,6 @@ int main(int argc, char** argv) {
   }
 
   gcpp::GemmaEnv env(argc, argv);
-  env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools.Pool());
+  env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools);
   return 0;
 }
diff --git a/paligemma/paligemma_helper.cc b/paligemma/paligemma_helper.cc
index d5255c1..6f811e3 100644
--- a/paligemma/paligemma_helper.cc
+++ b/paligemma/paligemma_helper.cc
@@ -30,7 +30,7 @@ void PaliGemmaHelper::InitVit(const std::string& path) {
   RuntimeConfig runtime_config = {.gen = &env_->MutableGen(), .verbosity = 0};
 
   gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
-                            image, *image_tokens_);
+                            image, *image_tokens_, env_->MutableEnv());
 }
 
 std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
@@ -61,7 +61,7 @@ std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
   const size_t prefix_end = tokens.size();
   TimingInfo timing_info = {.verbosity = 0};
   model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
-                 env_->MutableKVCache(), timing_info);
+                 env_->MutableKVCache(), env_->MutableEnv(), timing_info);
 
   return response;
 }
diff --git a/python/gemma_py.cc b/python/gemma_py.cc
index c8f5192..3496858 100644
--- a/python/gemma_py.cc
+++ b/python/gemma_py.cc
@@ -48,16 +48,16 @@ class GemmaModel {
   GemmaModel(const gcpp::LoaderArgs& loader,
              const gcpp::ThreadingArgs& threading,
             const gcpp::InferenceArgs& inference)
-      : gemma_(loader, threading, inference), last_prob_(0.0f) {}
+      : env_(loader, threading, inference), last_prob_(0.0f) {}
 
   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
   void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
                   size_t max_generated_tokens, float temperature, float seed,
                   gcpp::AcceptFunc accept, bool skip_prompt) {
-    gemma_.MutableGen().seed(seed);
-    std::vector<int> prompt_tokens = gemma_.WrapAndTokenize(prompt);
-    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
+    env_.MutableGen().seed(seed);
+    std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
+    gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.verbosity = 0;
@@ -72,8 +72,7 @@ class GemmaModel {
       }
       return stream(token, score);
     };
-    gemma_.QueryModel(prompt_tokens,
-                      skip_prompt ? stream_with_skipping : stream);
+    env_.QueryModel(prompt_tokens, skip_prompt ? stream_with_skipping : stream);
   }
 
   // Generates a single example, given a prompt, and returns the result.
@@ -83,13 +82,13 @@ class GemmaModel {
                        const std::vector<std::string>& end) {
     std::set<int> end_token_set{};
     for (const std::string& end_token : end) {
-      std::vector<int> end_token_ids = gemma_.Tokenize(end_token);
+      std::vector<int> end_token_ids = env_.Tokenize(end_token);
       end_token_set.insert(end_token_ids.begin(), end_token_ids.end());
     }
 
    std::vector<int> predicted_token_ids;
    predicted_token_ids.reserve(max_generated_tokens);
-    std::vector<int> prompt_token_ids = gemma_.WrapAndTokenize(prompt);
+    std::vector<int> prompt_token_ids = env_.WrapAndTokenize(prompt);
    int generated = 0;
    auto stream_token = [&generated, &prompt_token_ids, &predicted_token_ids,
                         &end_token_set, this](int token, float proba) {
@@ -106,7 +105,7 @@ class GemmaModel {
 
     std::set<int> accept_token_set{};
     for (const std::string& accept_token : accept) {
-      std::vector<int> accept_token_ids = gemma_.Tokenize(accept_token);
+      std::vector<int> accept_token_ids = env_.Tokenize(accept_token);
       accept_token_set.insert(accept_token_ids.begin(),
                               accept_token_ids.end());
     }
@@ -125,17 +124,17 @@ class GemmaModel {
       }
     };
 
-    gemma_.MutableGen().seed(seed);
-    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
+    env_.MutableGen().seed(seed);
+    gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.verbosity = 0;
     config.accept_token = accept_token;
 
-    gemma_.QueryModel(prompt_token_ids, stream_token);
+    env_.QueryModel(prompt_token_ids, stream_token);
 
     if (!predicted_token_ids.empty()) {
-      return gemma_.StringFromTokens(predicted_token_ids);
+      return env_.StringFromTokens(predicted_token_ids);
     } else {
       return "";
     }
@@ -147,14 +146,14 @@ class GemmaModel {
                                          size_t max_generated_tokens,
                                          float temperature, float seed,
                                          size_t top_k) {
-    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
+    gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.top_k = top_k;
     config.verbosity = 0;
-    gemma_.MutableGen().seed(seed);
+    env_.MutableGen().seed(seed);
 
-    std::vector<gcpp::QueryResult> outputs = gemma_.BatchQueryModel(inputs);
+    std::vector<gcpp::QueryResult> outputs = env_.BatchQueryModel(inputs);
     std::vector result;
     result.reserve(outputs.size());
     for (const gcpp::QueryResult& output : outputs) {
@@ -167,7 +166,7 @@ class GemmaModel {
 
   // Generate* will use this image. Throws an error for other models.
   void SetImage(const py::array_t& image) {
-    const gcpp::Gemma& gemma = *gemma_.GetGemma();
+    const gcpp::Gemma& gemma = *env_.GetGemma();
     const gcpp::ModelConfig& config = gemma.GetModelConfig();
     if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
         config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
@@ -188,10 +187,10 @@ class GemmaModel {
         "image_tokens",
         gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
         gcpp::MatPadding::kOdd));
-    gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
+    gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(),
                                           .verbosity = 0};
-    gemma.GenerateImageTokens(runtime_config, gemma_.MutableKVCache().SeqLen(),
-                              c_image, *image_tokens_);
+    gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),
+                              c_image, *image_tokens_, env_.MutableEnv());
   }
 
   // Generates a response to the given prompt, using the last set image.
@@ -200,9 +199,9 @@ class GemmaModel {
       std::string prompt, size_t max_generated_tokens, float temperature,
      float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
     if (!image_tokens_) throw std::invalid_argument("No image set.");
-    const gcpp::Gemma& model = *gemma_.GetGemma();
-    gemma_.MutableGen().seed(seed);
-    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
+    const gcpp::Gemma& model = *env_.GetGemma();
+    env_.MutableGen().seed(seed);
+    gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.verbosity = 0;
@@ -217,7 +216,7 @@ class GemmaModel {
       tokens = prompt_tokens;
       RemoveTrailingZeros(tokens);  // Remove padding, if any.
     } else {
-      tokens = gemma_.WrapAndTokenize(prompt);
+      tokens = env_.WrapAndTokenize(prompt);
     }
     tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
     size_t num_tokens = tokens.size();
@@ -235,8 +234,8 @@ class GemmaModel {
     };
     config.stream_token = stream_token;
     gcpp::TimingInfo timing_info = {.verbosity = 0};
-    model.Generate(config, tokens, /*pos=*/0, prefix_end,
-                   gemma_.MutableKVCache(), timing_info);
+    model.Generate(config, tokens, /*pos=*/0, prefix_end, env_.MutableKVCache(),
+                   env_.MutableEnv(), timing_info);
     std::string response;
     model.Tokenizer().Decode(response_tokens, &response);
     return {response, response_tokens};
@@ -245,13 +244,13 @@ class GemmaModel {
   float GetLastProb() const { return last_prob_; }
 
   std::string Detokenize(const std::vector<int>& token_ids) const {
-    return gemma_.StringFromTokens(token_ids);
+    return env_.StringFromTokens(token_ids);
   }
 
-  bool ModelIsLoaded() const { return gemma_.GetGemma() != nullptr; }
+  bool ModelIsLoaded() const { return env_.GetGemma() != nullptr; }
 
  private:
-  gcpp::GemmaEnv gemma_;
+  gcpp::GemmaEnv env_;
   std::unique_ptr image_tokens_;
   float last_prob_;
 };
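
For reference, a minimal sketch of the calling convention after this change, modeled on the updated examples/hello_world/run.cc above. The argument-parsing constructors, headers, and the hard-coded prompt tokens are illustrative assumptions rather than part of this patch; the point it shows is that Gemma is now constructed from NestedPools and every Generate* call receives a MatMulEnv explicitly (one per thread for concurrent use):

// Sketch only; assumes the usual gemma.cpp headers and (argc, argv) Args ctors.
#include <random>
#include <vector>

#include "gemma/gemma.h"  // Gemma, MakeMatMulEnv, RuntimeConfig, TimingInfo

int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::ThreadingArgs threading(argc, argv);
  gcpp::InferenceArgs inference(argc, argv);

  // MatMulEnv is no longer stored inside Gemma: the constructor only needs
  // the thread pools, and each Generate* call takes an env explicitly.
  gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
  gcpp::Gemma gemma(loader, inference, env.ctx.pools);
  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);

  // Hypothetical prompt tokens; a real caller tokenizes via gemma.Tokenizer().
  std::vector<int> tokens = {2, 651, 6996, 603};

  std::mt19937 gen(std::random_device{}());
  gcpp::RuntimeConfig runtime_config;
  runtime_config.max_generated_tokens = 128;
  runtime_config.temperature = 1.0f;
  runtime_config.verbosity = 0;
  runtime_config.gen = &gen;
  runtime_config.stream_token = [](int /*token*/, float /*prob*/) {
    return true;  // keep generating
  };
  gcpp::TimingInfo timing_info = {.verbosity = 0};

  // Per the new comment in gemma/gemma.h, using a separate MatMulEnv (and
  // KVCache) per thread is what allows concurrent calls on one const Gemma.
  gemma.Generate(runtime_config, tokens, /*pos=*/0, kv_cache, env, timing_info);
  return 0;
}

Call sites that go through GemmaEnv (benchmark, run_mmlu, the Python bindings) obtain the same MatMulEnv via the new GemmaEnv::MutableEnv() accessor rather than constructing their own.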