Move MatMulEnv out of Gemma to enable concurrent calls

Also update benchmark_helper config print: add profiler, remove free mem

PiperOrigin-RevId: 774662974
This commit is contained in:
Jan Wassenberg 2025-06-23 01:19:39 -07:00 committed by Copybara-Service
parent 0f70f285e0
commit a04cc287b2
17 changed files with 99 additions and 92 deletions

View File

@ -551,6 +551,7 @@ cc_library(
"//compression:compress",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
],
)

View File

@ -75,8 +75,9 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
float entropy = ComputeCrossEntropy(
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
float entropy =
ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
env.MutableEnv(), env.Verbosity());
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice = env.StringFromTokens(prompt_slice);

View File

@ -32,6 +32,7 @@
#include "util/threading_context.h"
#include "hwy/highway.h"
#include "hwy/per_target.h" // DispatchedTarget
#include "hwy/profiler.h" // PROFILER_ENABLED
#include "hwy/timer.h"
namespace gcpp {
@ -50,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading, inference)),
gemma_(loader, inference, env_) {
gemma_(loader, inference, env_.ctx.pools) {
const ModelConfig& config = gemma_.GetModelConfig();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference));
@ -94,7 +95,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
}
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
runtime_config_.batch_stream_token = batch_stream_token;
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
timing_info);
return result;
}
@ -104,7 +105,7 @@ void GemmaEnv::QueryModel(
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
const StreamFunc previous_stream_token = runtime_config_.stream_token;
runtime_config_.stream_token = stream_token;
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
timing_info);
runtime_config_.stream_token = previous_stream_token;
}
@ -146,7 +147,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
return res;
}
@ -176,7 +177,7 @@ float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
MutableKVCache(),
MutableKVCache(), env_,
/*verbosity=*/0) /
static_cast<int>(input.size());
}
@ -247,13 +248,13 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
"CPU : %s, bind %d\n"
"CPU topology : %s, %s, %s\n"
"Instruction set : %s (%zu bits)\n"
"Compiled config : %s\n"
"Memory MiB : %4zu, %4zu free\n",
"Compiled config : %s, profiler %d\n"
"Memory MiB : %4zu\n",
dt, cpu100, static_cast<int>(threading.bind),
ctx.topology.TopologyString(), ctx.pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
ctx.allocator.VectorBytes() * 8, CompiledConfig(),
ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB());
ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED,
ctx.allocator.TotalMiB());
}
}

View File

@ -112,6 +112,7 @@ class GemmaEnv {
RuntimeConfig& MutableConfig() { return runtime_config_; }
std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return kv_caches_[0]; }
MatMulEnv& MutableEnv() { return env_; }
private:
MatMulEnv env_;

View File

@ -99,7 +99,7 @@ HWY_EXPORT(CallSoftmax);
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity) {
MatMulEnv& env, int verbosity) {
const StreamFunc stream_token = [](int, float) { return true; };
const int vocab_size = gemma.GetModelConfig().vocab_size;
@ -145,7 +145,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
gemma.Generate(runtime, prompt0, 0, kv_cache, env, timing_info);
const float scale = 1.0f / std::log(2.0f);
return cross_entropy * scale;

View File

@ -26,7 +26,7 @@ namespace gcpp {
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity);
MatMulEnv& env, int verbosity);
} // namespace gcpp

View File

@ -127,7 +127,7 @@ TEST_F(GemmaTest, Multiturn) {
config.wrapping, abs_pos, mutable_prompt);
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
s_env->MutableEnv(), timing_info);
// Note: we do not rewind any <end_of_turn> tokens here. If the model
// produced one and WrapAndTokenize() inserts another one, it will just be
// duplicated.
@ -139,7 +139,7 @@ TEST_F(GemmaTest, Multiturn) {
// access to the previous turn by asking to reproduce.
response.clear();
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
s_env->MutableEnv(), timing_info);
fprintf(stderr, "decoded: '%s'\n", response.c_str());
bool remembered_turquoise =
response.find("turquoise") != std::string::npos; // NOLINT

View File

@ -131,7 +131,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
.stream_token = stream_token,
};
env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
env.MutableKVCache(), timing_info);
env.MutableKVCache(), env.MutableEnv(),
timing_info);
std::string output_string = env.StringFromTokens(predicted_token_ids);
fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),

View File

@ -52,7 +52,7 @@ int main(int argc, char** argv) {
// Instantiate model and KV Cache
gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
gcpp::Gemma gemma(loader, inference, env);
gcpp::Gemma gemma(loader, inference, env.ctx.pools);
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
size_t generated = 0;
@ -93,5 +93,5 @@ int main(int argc, char** argv) {
return !reject_tokens.contains(token);
},
};
gemma.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
gemma.Generate(runtime_config, tokens, 0, kv_cache, env, timing_info);
}

View File

@ -36,7 +36,7 @@ class SimplifiedGemma {
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
: env_(MakeMatMulEnv(threading, inference)),
gemma_(loader, inference, env_),
gemma_(loader, inference, env_.ctx.pools),
kv_cache_(gemma_.GetModelConfig(), inference) {
// Initialize random number generator
std::random_device rd;
@ -83,7 +83,7 @@ class SimplifiedGemma {
return !reject_tokens.contains(token);
},
};
gemma_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
gemma_.Generate(runtime_config, tokens, 0, kv_cache_, env_, timing_info);
}
~SimplifiedGemma() = default;

View File

@ -103,7 +103,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
threading_args(threading_args),
matmul_env(MakeMatMulEnv(threading_args, inference_args)),
active_conversation_name("default"),
model(loader, inference_args, matmul_env) {
model(loader, inference_args, matmul_env.ctx.pools) {
std::stringstream ss;
LogDebug("Creating initial ConversationData");
@ -207,7 +207,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// Pass the populated image object to GenerateImageTokens
model.GenerateImageTokens(runtime_config,
active_conversation->kv_cache->SeqLen(), image,
image_tokens);
image_tokens, matmul_env);
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
ss.str("");
@ -244,7 +244,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// Pass the KVCache object by reference from the active conversation
model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
prefix_end, *(active_conversation->kv_cache), timing_info);
prefix_end, *active_conversation->kv_cache, matmul_env,
timing_info);
// prepare for next turn
if (!inference_args.multiturn ||

View File

@ -610,62 +610,62 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
}
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
MatMulEnv& env)
: env_(env),
reader_(loader.weights),
NestedPools& pools)
: reader_(loader.weights),
model_(reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(inference) {
weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
env.ctx.pools.Pool());
pools.Pool());
reader_.CloseFile();
}
Gemma::~Gemma() = default;
void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
BlobWriter writer;
const std::vector<uint32_t> serialized_mat_ptrs =
weights_.AddTensorDataToWriter(writer);
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
writer, env_.ctx.pools.Pool(), weights_path);
writer, pools.Pool(), weights_path);
}
void Gemma::Generate(const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, TimingInfo& timing_info) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) const {
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
model_.Config(), runtime_config,
weights_, kv_cache, env_, timing_info);
weights_, kv_cache, env, timing_info);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
AllQueries& all_queries,
AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
weights_, all_queries, env_,
timing_info);
weights_, all_queries, env, timing_info);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
size_t seq_len, const Image& image,
ImageTokens& image_tokens) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
ImageTokens& image_tokens,
MatMulEnv& env) const {
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
seq_len, weights_, image,
image_tokens, env_);
image_tokens, env);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
} // namespace gcpp

View File

@ -229,16 +229,16 @@ struct TimingInfo {
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
const InferenceArgs& inference_args);
// After construction, all methods are const and thread-compatible if using
// separate MatMulEnv for each thread.
class Gemma {
public:
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
// `env` must remain valid for the lifetime of this Gemma.
// `pools` are used to parallelize loading.
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
MatMulEnv& env);
NestedPools& pools);
~Gemma();
MatMulEnv& Env() const { return env_; }
// TODO: rename to Config()
const ModelConfig& GetModelConfig() const { return model_.Config(); }
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
@ -246,29 +246,31 @@ class Gemma {
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
const InferenceArgs& Inference() const { return inference_; }
void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
void Save(const Path& weights_path, NestedPools& pools) const;
// `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, KVCache& kv_cache, TimingInfo& timing_info) const {
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
size_t pos, KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) const {
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache, env,
timing_info);
}
// For prefix-LM style attention, we can pass the end of the prefix.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, size_t prefix_end, KVCache& kv_cache,
TimingInfo& timing_info) const;
MatMulEnv& env, TimingInfo& timing_info) const;
void GenerateBatch(const RuntimeConfig& runtime_config,
AllQueries& all_queries, TimingInfo& timing_info) const;
AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) const;
// Generates the image tokens by running the image encoder ViT.
void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
const Image& image, ImageTokens& image_tokens) const;
const Image& image, ImageTokens& image_tokens,
MatMulEnv& env) const;
private:
MatMulEnv& env_;
BlobReader reader_;
ModelStore model_;
std::vector<MatOwner> mat_owners_;

View File

@ -92,7 +92,7 @@ std::string GetPrompt(const InferenceArgs& inference) {
// The main Read-Eval-Print Loop.
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
const Gemma& gemma, KVCache& kv_cache) {
const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
PROFILER_ZONE("Gen.misc");
size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
@ -111,7 +111,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
config.model_dim)
: Extents2D(0, 0),
MatPadding::kOdd);
image_tokens.AllocateAndAttachRowPtrs(gemma.Env().row_ptrs);
image_tokens.AllocateAndAttachRowPtrs(env.row_ptrs);
if (have_image) {
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
config.wrapping == PromptWrapping::GEMMA_VLM);
@ -123,7 +123,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
.use_spinning = threading.spin};
double image_tokens_start = hwy::platform::Now();
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
image_tokens);
image_tokens, env);
if (inference.verbosity >= 1) {
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
fprintf(stderr,
@ -224,7 +224,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
if (inference.verbosity >= 1) {
std::cerr << "\n[ Reading prompt ] " << std::flush;
}
gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
timing_info);
std::cout << "\n\n";
@ -256,7 +256,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
MatMulEnv env(MakeMatMulEnv(threading, inference));
if (inference.verbosity >= 2) env.print_best = true;
const Gemma gemma(loader, inference, env);
const Gemma gemma(loader, inference, env.ctx.pools);
KVCache kv_cache(gemma.GetModelConfig(), inference);
if (inference.verbosity >= 1) {
@ -289,7 +289,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
}
}
ReplGemma(threading, inference, gemma, kv_cache);
ReplGemma(threading, inference, gemma, kv_cache, env);
}
} // namespace gcpp

View File

@ -44,6 +44,6 @@ int main(int argc, char** argv) {
}
gcpp::GemmaEnv env(argc, argv);
env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools.Pool());
env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools);
return 0;
}

View File

@ -30,7 +30,7 @@ void PaliGemmaHelper::InitVit(const std::string& path) {
RuntimeConfig runtime_config = {.gen = &env_->MutableGen(),
.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
image, *image_tokens_);
image, *image_tokens_, env_->MutableEnv());
}
std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
@ -61,7 +61,7 @@ std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
const size_t prefix_end = tokens.size();
TimingInfo timing_info = {.verbosity = 0};
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
env_->MutableKVCache(), timing_info);
env_->MutableKVCache(), env_->MutableEnv(), timing_info);
return response;
}

View File

@ -48,16 +48,16 @@ class GemmaModel {
GemmaModel(const gcpp::LoaderArgs& loader,
const gcpp::ThreadingArgs& threading,
const gcpp::InferenceArgs& inference)
: gemma_(loader, threading, inference), last_prob_(0.0f) {}
: env_(loader, threading, inference), last_prob_(0.0f) {}
// Generates a single example, given a prompt and a callback to stream the
// generated tokens.
void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
size_t max_generated_tokens, float temperature, float seed,
gcpp::AcceptFunc accept, bool skip_prompt) {
gemma_.MutableGen().seed(seed);
std::vector<int> prompt_tokens = gemma_.WrapAndTokenize(prompt);
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
env_.MutableGen().seed(seed);
std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
gcpp::RuntimeConfig& config = env_.MutableConfig();
config.max_generated_tokens = max_generated_tokens;
config.temperature = temperature;
config.verbosity = 0;
@ -72,8 +72,7 @@ class GemmaModel {
}
return stream(token, score);
};
gemma_.QueryModel(prompt_tokens,
skip_prompt ? stream_with_skipping : stream);
env_.QueryModel(prompt_tokens, skip_prompt ? stream_with_skipping : stream);
}
// Generates a single example, given a prompt, and returns the result.
@ -83,13 +82,13 @@ class GemmaModel {
const std::vector<std::string>& end) {
std::set<int> end_token_set{};
for (const std::string& end_token : end) {
std::vector<int> end_token_ids = gemma_.Tokenize(end_token);
std::vector<int> end_token_ids = env_.Tokenize(end_token);
end_token_set.insert(end_token_ids.begin(), end_token_ids.end());
}
std::vector<int> predicted_token_ids;
predicted_token_ids.reserve(max_generated_tokens);
std::vector<int> prompt_token_ids = gemma_.WrapAndTokenize(prompt);
std::vector<int> prompt_token_ids = env_.WrapAndTokenize(prompt);
int generated = 0;
auto stream_token = [&generated, &prompt_token_ids, &predicted_token_ids,
&end_token_set, this](int token, float proba) {
@ -106,7 +105,7 @@ class GemmaModel {
std::set<int> accept_token_set{};
for (const std::string& accept_token : accept) {
std::vector<int> accept_token_ids = gemma_.Tokenize(accept_token);
std::vector<int> accept_token_ids = env_.Tokenize(accept_token);
accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
}
@ -125,17 +124,17 @@ class GemmaModel {
}
};
gemma_.MutableGen().seed(seed);
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
env_.MutableGen().seed(seed);
gcpp::RuntimeConfig& config = env_.MutableConfig();
config.max_generated_tokens = max_generated_tokens;
config.temperature = temperature;
config.verbosity = 0;
config.accept_token = accept_token;
gemma_.QueryModel(prompt_token_ids, stream_token);
env_.QueryModel(prompt_token_ids, stream_token);
if (!predicted_token_ids.empty()) {
return gemma_.StringFromTokens(predicted_token_ids);
return env_.StringFromTokens(predicted_token_ids);
} else {
return "";
}
@ -147,14 +146,14 @@ class GemmaModel {
size_t max_generated_tokens,
float temperature, float seed,
size_t top_k) {
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
gcpp::RuntimeConfig& config = env_.MutableConfig();
config.max_generated_tokens = max_generated_tokens;
config.temperature = temperature;
config.top_k = top_k;
config.verbosity = 0;
gemma_.MutableGen().seed(seed);
env_.MutableGen().seed(seed);
std::vector<gcpp::QueryResult> outputs = gemma_.BatchQueryModel(inputs);
std::vector<gcpp::QueryResult> outputs = env_.BatchQueryModel(inputs);
std::vector<std::string> result;
result.reserve(outputs.size());
for (const gcpp::QueryResult& output : outputs) {
@ -167,7 +166,7 @@ class GemmaModel {
// Generate* will use this image. Throws an error for other models.
void SetImage(const py::array_t<float, py::array::c_style |
py::array::forcecast>& image) {
const gcpp::Gemma& gemma = *gemma_.GetGemma();
const gcpp::Gemma& gemma = *env_.GetGemma();
const gcpp::ModelConfig& config = gemma.GetModelConfig();
if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
@ -188,10 +187,10 @@ class GemmaModel {
"image_tokens",
gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
gcpp::MatPadding::kOdd));
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(),
.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, gemma_.MutableKVCache().SeqLen(),
c_image, *image_tokens_);
gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),
c_image, *image_tokens_, env_.MutableEnv());
}
// Generates a response to the given prompt, using the last set image.
@ -200,9 +199,9 @@ class GemmaModel {
std::string prompt, size_t max_generated_tokens, float temperature,
float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
if (!image_tokens_) throw std::invalid_argument("No image set.");
const gcpp::Gemma& model = *gemma_.GetGemma();
gemma_.MutableGen().seed(seed);
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
const gcpp::Gemma& model = *env_.GetGemma();
env_.MutableGen().seed(seed);
gcpp::RuntimeConfig& config = env_.MutableConfig();
config.max_generated_tokens = max_generated_tokens;
config.temperature = temperature;
config.verbosity = 0;
@ -217,7 +216,7 @@ class GemmaModel {
tokens = prompt_tokens;
RemoveTrailingZeros(tokens); // Remove padding, if any.
} else {
tokens = gemma_.WrapAndTokenize(prompt);
tokens = env_.WrapAndTokenize(prompt);
}
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
size_t num_tokens = tokens.size();
@ -235,8 +234,8 @@ class GemmaModel {
};
config.stream_token = stream_token;
gcpp::TimingInfo timing_info = {.verbosity = 0};
model.Generate(config, tokens, /*pos=*/0, prefix_end,
gemma_.MutableKVCache(), timing_info);
model.Generate(config, tokens, /*pos=*/0, prefix_end, env_.MutableKVCache(),
env_.MutableEnv(), timing_info);
std::string response;
model.Tokenizer().Decode(response_tokens, &response);
return {response, response_tokens};
@ -245,13 +244,13 @@ class GemmaModel {
float GetLastProb() const { return last_prob_; }
std::string Detokenize(const std::vector<int>& token_ids) const {
return gemma_.StringFromTokens(token_ids);
return env_.StringFromTokens(token_ids);
}
bool ModelIsLoaded() const { return gemma_.GetGemma() != nullptr; }
bool ModelIsLoaded() const { return env_.GetGemma() != nullptr; }
private:
gcpp::GemmaEnv gemma_;
gcpp::GemmaEnv env_;
std::unique_ptr<gcpp::ImageTokens> image_tokens_;
float last_prob_;
};