mirror of https://github.com/google/gemma.cpp.git
Move MatMulEnv out of Gemma to enable concurrent calls
Also update benchmark_helper config print: add profiler, remove free mem.

PiperOrigin-RevId: 774662974

This commit is contained in:
parent 0f70f285e0
commit a04cc287b2
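Not part of the commit, but to illustrate what it enables: because `Gemma` no longer owns a `MatMulEnv`, several threads can share one const `Gemma` and call `Generate` concurrently, each with its own `MatMulEnv` and `KVCache`. A minimal sketch against the post-commit API shown in the hunks below; the helper name, seeds, and the choice to reuse one `ThreadingArgs` for every env are illustrative assumptions, not taken from this commit:

#include <random>
#include <thread>
#include <vector>

#include "gemma/gemma.h"  // Gemma, MakeMatMulEnv, KVCache, TimingInfo

// Sketch: one shared const model; a separate MatMulEnv + KVCache per thread.
void ConcurrentGenerate(const gcpp::LoaderArgs& loader,
                        const gcpp::ThreadingArgs& threading,
                        const gcpp::InferenceArgs& inference,
                        const gcpp::RuntimeConfig& base_config,
                        const std::vector<int>& prompt1,
                        const std::vector<int>& prompt2) {
  // Loading only borrows pools now; Gemma keeps no MatMulEnv reference.
  gcpp::MatMulEnv load_env(gcpp::MakeMatMulEnv(threading, inference));
  const gcpp::Gemma gemma(loader, inference, load_env.ctx.pools);

  auto worker = [&](const std::vector<int>& prompt, std::mt19937 gen) {
    gcpp::RuntimeConfig config = base_config;  // per-thread copy
    config.gen = &gen;                         // do not share one RNG
    gcpp::MatMulEnv env(gcpp::MakeMatMulEnv(threading, inference));
    gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
    gcpp::TimingInfo timing_info;
    gemma.Generate(config, prompt, /*pos=*/0, kv_cache, env, timing_info);
  };

  std::thread t1(worker, prompt1, std::mt19937(1));
  std::thread t2(worker, prompt2, std::mt19937(2));
  t1.join();
  t2.join();
}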
@@ -551,6 +551,7 @@ cc_library(
         "//compression:compress",
         "@highway//:hwy",
         "@highway//:nanobenchmark",
+        "@highway//:profiler",
     ],
 )
@@ -75,8 +75,9 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
     KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
-    float entropy = ComputeCrossEntropy(
-        *env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
+    float entropy =
+        ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
+                            env.MutableEnv(), env.Verbosity());
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
     std::string text_slice = env.StringFromTokens(prompt_slice);
@@ -32,6 +32,7 @@
 #include "util/threading_context.h"
 #include "hwy/highway.h"
 #include "hwy/per_target.h"  // DispatchedTarget
+#include "hwy/profiler.h"  // PROFILER_ENABLED
 #include "hwy/timer.h"

 namespace gcpp {
@@ -50,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
     : env_(MakeMatMulEnv(threading, inference)),
-      gemma_(loader, inference, env_) {
+      gemma_(loader, inference, env_.ctx.pools) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference));
@@ -94,7 +95,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
   }
   gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
   runtime_config_.batch_stream_token = batch_stream_token;
-  gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
+  gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
                   timing_info);
   return result;
 }
@@ -104,7 +105,7 @@ void GemmaEnv::QueryModel(
   gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
   const StreamFunc previous_stream_token = runtime_config_.stream_token;
   runtime_config_.stream_token = stream_token;
-  gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
+  gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
                   timing_info);
   runtime_config_.stream_token = previous_stream_token;
 }
@@ -146,7 +147,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(

   gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
   gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
-  gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
+  gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
   return res;
 }
@@ -176,7 +177,7 @@ float GemmaEnv::CrossEntropy(const std::string& input) {
   std::vector<int> prompt = Tokenize(input);
   prompt.insert(prompt.begin(), BOS_ID);
   return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
-                             MutableKVCache(),
+                             MutableKVCache(), env_,
                              /*verbosity=*/0) /
          static_cast<int>(input.size());
 }
@@ -247,13 +248,13 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
           "CPU : %s, bind %d\n"
           "CPU topology : %s, %s, %s\n"
           "Instruction set : %s (%zu bits)\n"
-          "Compiled config : %s\n"
-          "Memory MiB : %4zu, %4zu free\n",
+          "Compiled config : %s, profiler %d\n"
+          "Memory MiB : %4zu\n",
           dt, cpu100, static_cast<int>(threading.bind),
           ctx.topology.TopologyString(), ctx.pools.PinString(),
           CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
-          ctx.allocator.VectorBytes() * 8, CompiledConfig(),
-          ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB());
+          ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED,
+          ctx.allocator.TotalMiB());
 }
 }
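(For reference: `PROFILER_ENABLED` is the compile-time 0/1 macro from the newly included `hwy/profiler.h`, so the "Compiled config" line now reports whether profiler support was compiled in, and the free-memory figure is dropped from the "Memory MiB" line, matching the commit description.)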
@@ -112,6 +112,7 @@ class GemmaEnv {
   RuntimeConfig& MutableConfig() { return runtime_config_; }
   std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return kv_caches_[0]; }
+  MatMulEnv& MutableEnv() { return env_; }

  private:
   MatMulEnv env_;
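With `MutableEnv()` in place, code that holds only a `GemmaEnv` can forward its single `MatMulEnv` to the new `Generate` overloads, which is what the test and benchmark hunks in this commit do. A hypothetical minimal caller (the helper name and zero verbosity are illustrative, not from the commit):

// Hypothetical helper mirroring the call pattern used throughout this commit.
void GenerateWithEnv(gcpp::GemmaEnv& env, const std::vector<int>& tokens) {
  gcpp::TimingInfo timing_info = {.verbosity = 0};
  env.GetGemma()->Generate(env.MutableConfig(), tokens, /*pos=*/0,
                           env.MutableKVCache(), env.MutableEnv(),
                           timing_info);
}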
@@ -99,7 +99,7 @@ HWY_EXPORT(CallSoftmax);

 float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          int verbosity) {
+                          MatMulEnv& env, int verbosity) {
   const StreamFunc stream_token = [](int, float) { return true; };

   const int vocab_size = gemma.GetModelConfig().vocab_size;
@@ -145,7 +145,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
   };
   TimingInfo timing_info;

-  gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
+  gemma.Generate(runtime, prompt0, 0, kv_cache, env, timing_info);

   const float scale = 1.0f / std::log(2.0f);
   return cross_entropy * scale;
@@ -26,7 +26,7 @@ namespace gcpp {

 float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          int verbosity);
+                          MatMulEnv& env, int verbosity);

 }  // namespace gcpp
@@ -127,7 +127,7 @@ TEST_F(GemmaTest, Multiturn) {
                                config.wrapping, abs_pos, mutable_prompt);

   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
-                  timing_info);
+                  s_env->MutableEnv(), timing_info);
   // Note: we do not rewind any <end_of_turn> tokens here. If the model
   // produced one and WrapAndTokenize() inserts another one, it will just be
   // duplicated.
@@ -139,7 +139,7 @@ TEST_F(GemmaTest, Multiturn) {
   // access to the previous turn by asking to reproduce.
   response.clear();
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
-                  timing_info);
+                  s_env->MutableEnv(), timing_info);
   fprintf(stderr, "decoded: '%s'\n", response.c_str());
   bool remembered_turquoise =
       response.find("turquoise") != std::string::npos;  // NOLINT
@@ -131,7 +131,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
       .stream_token = stream_token,
   };
   env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
-                           env.MutableKVCache(), timing_info);
+                           env.MutableKVCache(), env.MutableEnv(),
+                           timing_info);

   std::string output_string = env.StringFromTokens(predicted_token_ids);
   fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
@@ -52,7 +52,7 @@ int main(int argc, char** argv) {

   // Instantiate model and KV Cache
   gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
-  gcpp::Gemma gemma(loader, inference, env);
+  gcpp::Gemma gemma(loader, inference, env.ctx.pools);
   gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
   size_t generated = 0;
@@ -93,5 +93,5 @@ int main(int argc, char** argv) {
         return !reject_tokens.contains(token);
       },
   };
-  gemma.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
+  gemma.Generate(runtime_config, tokens, 0, kv_cache, env, timing_info);
 }
@@ -36,7 +36,7 @@ class SimplifiedGemma {
                    const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                    const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
       : env_(MakeMatMulEnv(threading, inference)),
-        gemma_(loader, inference, env_),
+        gemma_(loader, inference, env_.ctx.pools),
         kv_cache_(gemma_.GetModelConfig(), inference) {
     // Initialize random number generator
     std::random_device rd;
@@ -83,7 +83,7 @@ class SimplifiedGemma {
           return !reject_tokens.contains(token);
         },
     };
-    gemma_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
+    gemma_.Generate(runtime_config, tokens, 0, kv_cache_, env_, timing_info);
   }
   ~SimplifiedGemma() = default;
@@ -103,7 +103,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
       threading_args(threading_args),
       matmul_env(MakeMatMulEnv(threading_args, inference_args)),
       active_conversation_name("default"),
-      model(loader, inference_args, matmul_env) {
+      model(loader, inference_args, matmul_env.ctx.pools) {
   std::stringstream ss;

   LogDebug("Creating initial ConversationData");
@@ -207,7 +207,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
       // Pass the populated image object to GenerateImageTokens
       model.GenerateImageTokens(runtime_config,
                                 active_conversation->kv_cache->SeqLen(), image,
-                                image_tokens);
+                                image_tokens, matmul_env);
       double image_tokens_duration = hwy::platform::Now() - image_tokens_start;

       ss.str("");
@@ -244,7 +244,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,

     // Pass the KVCache object by reference from the active conversation
     model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
-                   prefix_end, *(active_conversation->kv_cache), timing_info);
+                   prefix_end, *active_conversation->kv_cache, matmul_env,
+                   timing_info);

     // prepare for next turn
     if (!inference_args.multiturn ||
@@ -610,62 +610,62 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
 }

 Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-             MatMulEnv& env)
-    : env_(env),
-      reader_(loader.weights),
+             NestedPools& pools)
+    : reader_(loader.weights),
       model_(reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model),
       inference_(inference) {
   weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
-                         env.ctx.pools.Pool());
+                         pools.Pool());
   reader_.CloseFile();
 }

 Gemma::~Gemma() = default;

-void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
+void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
   BlobWriter writer;
   const std::vector<uint32_t> serialized_mat_ptrs =
       weights_.AddTensorDataToWriter(writer);
   WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
-                  writer, env_.ctx.pools.Pool(), weights_path);
+                  writer, pools.Pool(), weights_path);
 }

 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, size_t prefix_end,
-                     KVCache& kv_cache, TimingInfo& timing_info) const {
-  env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
+                     KVCache& kv_cache, MatMulEnv& env,
+                     TimingInfo& timing_info) const {
+  env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);

   HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
                                         model_.Config(), runtime_config,
-                                        weights_, kv_cache, env_, timing_info);
+                                        weights_, kv_cache, env, timing_info);

-  env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
+  env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }

 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
-                          AllQueries& all_queries,
+                          AllQueries& all_queries, MatMulEnv& env,
                           TimingInfo& timing_info) const {
-  env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
+  env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);

   HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
-                                       weights_, all_queries, env_,
-                                       timing_info);
+                                       weights_, all_queries, env, timing_info);

-  env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
+  env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }

 void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
                                 size_t seq_len, const Image& image,
-                                ImageTokens& image_tokens) const {
-  env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
+                                ImageTokens& image_tokens,
+                                MatMulEnv& env) const {
+  env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);

   HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
                                              seq_len, weights_, image,
-                                             image_tokens, env_);
+                                             image_tokens, env);

-  env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
+  env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }

 }  // namespace gcpp
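Design note: the spin-wait bracketing (`MaybeStartSpinning`/`MaybeStopSpinning`) now runs on the caller-supplied env's pools instead of a stored member, so each concurrent caller spins up and winds down only its own worker threads.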
@@ -229,16 +229,16 @@ struct TimingInfo {
 MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
                         const InferenceArgs& inference_args);

+// After construction, all methods are const and thread-compatible if using
+// separate MatMulEnv for each thread.
 class Gemma {
  public:
   // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
-  // `env` must remain valid for the lifetime of this Gemma.
+  // `pools` are used to parallelize loading.
   Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-        MatMulEnv& env);
+        NestedPools& pools);

   ~Gemma();

-  MatMulEnv& Env() const { return env_; }
   // TODO: rename to Config()
   const ModelConfig& GetModelConfig() const { return model_.Config(); }
   const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
@@ -246,29 +246,31 @@ class Gemma {
   const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const InferenceArgs& Inference() const { return inference_; }

-  void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
+  void Save(const Path& weights_path, NestedPools& pools) const;

   // `pos` is the position in the KV cache. Users are responsible for
   // incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
   void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
-                size_t pos, KVCache& kv_cache, TimingInfo& timing_info) const {
-    Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
+                size_t pos, KVCache& kv_cache, MatMulEnv& env,
+                TimingInfo& timing_info) const {
+    Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache, env,
              timing_info);
   }
   // For prefix-LM style attention, we can pass the end of the prefix.
   void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
                 size_t pos, size_t prefix_end, KVCache& kv_cache,
-                TimingInfo& timing_info) const;
+                MatMulEnv& env, TimingInfo& timing_info) const;

   void GenerateBatch(const RuntimeConfig& runtime_config,
-                     AllQueries& all_queries, TimingInfo& timing_info) const;
+                     AllQueries& all_queries, MatMulEnv& env,
+                     TimingInfo& timing_info) const;

   // Generates the image tokens by running the image encoder ViT.
   void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
-                           const Image& image, ImageTokens& image_tokens) const;
+                           const Image& image, ImageTokens& image_tokens,
+                           MatMulEnv& env) const;

  private:
-  MatMulEnv& env_;
   BlobReader reader_;
   ModelStore model_;
   std::vector<MatOwner> mat_owners_;
gemma/run.cc
@@ -92,7 +92,7 @@ std::string GetPrompt(const InferenceArgs& inference) {

 // The main Read-Eval-Print Loop.
 void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
-               const Gemma& gemma, KVCache& kv_cache) {
+               const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
@@ -111,7 +111,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                             config.model_dim)
                : Extents2D(0, 0),
       MatPadding::kOdd);
-  image_tokens.AllocateAndAttachRowPtrs(gemma.Env().row_ptrs);
+  image_tokens.AllocateAndAttachRowPtrs(env.row_ptrs);
   if (have_image) {
     HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
                config.wrapping == PromptWrapping::GEMMA_VLM);
@@ -123,7 +123,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
         .use_spinning = threading.spin};
     double image_tokens_start = hwy::platform::Now();
     gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
-                              image_tokens);
+                              image_tokens, env);
     if (inference.verbosity >= 1) {
       double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
       fprintf(stderr,
@@ -224,7 +224,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     if (inference.verbosity >= 1) {
       std::cerr << "\n[ Reading prompt ] " << std::flush;
     }
-    gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
+    gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
                    timing_info);
     std::cout << "\n\n";
@@ -256,7 +256,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,

   MatMulEnv env(MakeMatMulEnv(threading, inference));
   if (inference.verbosity >= 2) env.print_best = true;
-  const Gemma gemma(loader, inference, env);
+  const Gemma gemma(loader, inference, env.ctx.pools);
   KVCache kv_cache(gemma.GetModelConfig(), inference);

   if (inference.verbosity >= 1) {
@@ -289,7 +289,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
     }
   }

-  ReplGemma(threading, inference, gemma, kv_cache);
+  ReplGemma(threading, inference, gemma, kv_cache, env);
 }

 }  // namespace gcpp
@@ -44,6 +44,6 @@ int main(int argc, char** argv) {
   }

   gcpp::GemmaEnv env(argc, argv);
-  env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools.Pool());
+  env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools);
   return 0;
 }
@@ -30,7 +30,7 @@ void PaliGemmaHelper::InitVit(const std::string& path) {
     RuntimeConfig runtime_config = {.gen = &env_->MutableGen(),
                                     .verbosity = 0};
     gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
-                              image, *image_tokens_);
+                              image, *image_tokens_, env_->MutableEnv());
   }

 std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
@@ -61,7 +61,7 @@ std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
   const size_t prefix_end = tokens.size();
   TimingInfo timing_info = {.verbosity = 0};
   model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
-                 env_->MutableKVCache(), timing_info);
+                 env_->MutableKVCache(), env_->MutableEnv(), timing_info);
   return response;
 }
@@ -48,16 +48,16 @@ class GemmaModel {
   GemmaModel(const gcpp::LoaderArgs& loader,
              const gcpp::ThreadingArgs& threading,
              const gcpp::InferenceArgs& inference)
-      : gemma_(loader, threading, inference), last_prob_(0.0f) {}
+      : env_(loader, threading, inference), last_prob_(0.0f) {}

   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
   void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
                   size_t max_generated_tokens, float temperature, float seed,
                   gcpp::AcceptFunc accept, bool skip_prompt) {
-    gemma_.MutableGen().seed(seed);
-    std::vector<int> prompt_tokens = gemma_.WrapAndTokenize(prompt);
-    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
+    env_.MutableGen().seed(seed);
+    std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
+    gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.verbosity = 0;
@@ -72,8 +72,7 @@ class GemmaModel {
       }
       return stream(token, score);
     };
-    gemma_.QueryModel(prompt_tokens,
-                      skip_prompt ? stream_with_skipping : stream);
+    env_.QueryModel(prompt_tokens, skip_prompt ? stream_with_skipping : stream);
   }

   // Generates a single example, given a prompt, and returns the result.
@@ -83,13 +82,13 @@ class GemmaModel {
                            const std::vector<std::string>& end) {
     std::set<int> end_token_set{};
     for (const std::string& end_token : end) {
-      std::vector<int> end_token_ids = gemma_.Tokenize(end_token);
+      std::vector<int> end_token_ids = env_.Tokenize(end_token);
       end_token_set.insert(end_token_ids.begin(), end_token_ids.end());
     }

     std::vector<int> predicted_token_ids;
     predicted_token_ids.reserve(max_generated_tokens);
-    std::vector<int> prompt_token_ids = gemma_.WrapAndTokenize(prompt);
+    std::vector<int> prompt_token_ids = env_.WrapAndTokenize(prompt);
     int generated = 0;
     auto stream_token = [&generated, &prompt_token_ids, &predicted_token_ids,
                          &end_token_set, this](int token, float proba) {
@@ -106,7 +105,7 @@ class GemmaModel {

     std::set<int> accept_token_set{};
     for (const std::string& accept_token : accept) {
-      std::vector<int> accept_token_ids = gemma_.Tokenize(accept_token);
+      std::vector<int> accept_token_ids = env_.Tokenize(accept_token);
       accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
     }
@@ -125,17 +124,17 @@ class GemmaModel {
       }
     };

-    gemma_.MutableGen().seed(seed);
-    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
+    env_.MutableGen().seed(seed);
+    gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.verbosity = 0;
     config.accept_token = accept_token;

-    gemma_.QueryModel(prompt_token_ids, stream_token);
+    env_.QueryModel(prompt_token_ids, stream_token);

     if (!predicted_token_ids.empty()) {
-      return gemma_.StringFromTokens(predicted_token_ids);
+      return env_.StringFromTokens(predicted_token_ids);
     } else {
       return "";
     }
@@ -147,14 +146,14 @@ class GemmaModel {
                                          size_t max_generated_tokens,
                                          float temperature, float seed,
                                          size_t top_k) {
-    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
+    gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.top_k = top_k;
     config.verbosity = 0;
-    gemma_.MutableGen().seed(seed);
+    env_.MutableGen().seed(seed);

-    std::vector<gcpp::QueryResult> outputs = gemma_.BatchQueryModel(inputs);
+    std::vector<gcpp::QueryResult> outputs = env_.BatchQueryModel(inputs);
     std::vector<std::string> result;
     result.reserve(outputs.size());
     for (const gcpp::QueryResult& output : outputs) {
@@ -167,7 +166,7 @@ class GemmaModel {
   // Generate* will use this image. Throws an error for other models.
   void SetImage(const py::array_t<float, py::array::c_style |
                                              py::array::forcecast>& image) {
-    const gcpp::Gemma& gemma = *gemma_.GetGemma();
+    const gcpp::Gemma& gemma = *env_.GetGemma();
     const gcpp::ModelConfig& config = gemma.GetModelConfig();
     if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
         config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
@@ -188,10 +187,10 @@ class GemmaModel {
         "image_tokens",
         gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
         gcpp::MatPadding::kOdd));
-    gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
+    gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(),
                                           .verbosity = 0};
-    gemma.GenerateImageTokens(runtime_config, gemma_.MutableKVCache().SeqLen(),
-                              c_image, *image_tokens_);
+    gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),
+                              c_image, *image_tokens_, env_.MutableEnv());
   }

   // Generates a response to the given prompt, using the last set image.
@@ -200,9 +199,9 @@ class GemmaModel {
       std::string prompt, size_t max_generated_tokens, float temperature,
       float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
     if (!image_tokens_) throw std::invalid_argument("No image set.");
-    const gcpp::Gemma& model = *gemma_.GetGemma();
-    gemma_.MutableGen().seed(seed);
-    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
+    const gcpp::Gemma& model = *env_.GetGemma();
+    env_.MutableGen().seed(seed);
+    gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.verbosity = 0;
@@ -217,7 +216,7 @@ class GemmaModel {
       tokens = prompt_tokens;
       RemoveTrailingZeros(tokens);  // Remove padding, if any.
     } else {
-      tokens = gemma_.WrapAndTokenize(prompt);
+      tokens = env_.WrapAndTokenize(prompt);
     }
     tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
     size_t num_tokens = tokens.size();
@@ -235,8 +234,8 @@ class GemmaModel {
     };
     config.stream_token = stream_token;
     gcpp::TimingInfo timing_info = {.verbosity = 0};
-    model.Generate(config, tokens, /*pos=*/0, prefix_end,
-                   gemma_.MutableKVCache(), timing_info);
+    model.Generate(config, tokens, /*pos=*/0, prefix_end, env_.MutableKVCache(),
+                   env_.MutableEnv(), timing_info);
     std::string response;
     model.Tokenizer().Decode(response_tokens, &response);
     return {response, response_tokens};
@@ -245,13 +244,13 @@ class GemmaModel {
   float GetLastProb() const { return last_prob_; }

   std::string Detokenize(const std::vector<int>& token_ids) const {
-    return gemma_.StringFromTokens(token_ids);
+    return env_.StringFromTokens(token_ids);
   }

-  bool ModelIsLoaded() const { return gemma_.GetGemma() != nullptr; }
+  bool ModelIsLoaded() const { return env_.GetGemma() != nullptr; }

  private:
-  gcpp::GemmaEnv gemma_;
+  gcpp::GemmaEnv env_;
   std::unique_ptr<gcpp::ImageTokens> image_tokens_;
   float last_prob_;
 };