mirror of https://github.com/google/gemma.cpp.git
Improved timing for image tokens
Move to TimingInfo, extra newline before profiler PiperOrigin-RevId: 881943820
This commit is contained in:
parent
70cb9cf1c2
commit
cab77f8dc7
|
|
@ -29,7 +29,6 @@
|
|||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <Windows.h>
|
||||
|
|
@ -195,17 +194,10 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
|
||||
// Use the existing runtime_config defined earlier in the function.
|
||||
// RuntimeConfig runtime_config = { ... }; // This was already defined
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
// Pass the populated image object to GenerateImageTokens
|
||||
model.GenerateImageTokens(runtime_config,
|
||||
active_conversation->kv_cache->SeqLen(), image,
|
||||
image_tokens, matmul_env);
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
|
||||
ss.str("");
|
||||
ss << "\n\n[ Timing info ] Image token generation took: ";
|
||||
ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n",
|
||||
LogDebug(ss.str().c_str());
|
||||
image_tokens, matmul_env, timing_info);
|
||||
|
||||
prompt = WrapAndTokenize(
|
||||
model.Tokenizer(), model.ChatTemplate(), model_config.wrapping,
|
||||
|
|
|
|||
|
|
@ -605,6 +605,9 @@ static void GenerateT(const ModelConfig& config,
|
|||
config, runtime_config, weights, activations, qbatch, env, timing_info);
|
||||
// No-op if the profiler is disabled, but useful to separate prefill and
|
||||
// generate phases for profiling.
|
||||
if constexpr (PROFILER_ENABLED) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
env.ctx.profiler.PrintResults();
|
||||
|
||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||
|
|
@ -725,25 +728,33 @@ void GenerateBatchT(const ModelConfig& config,
|
|||
void GenerateImageTokensT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, size_t seq_len,
|
||||
const WeightsPtrs& weights, const Image& image,
|
||||
ImageTokens& image_tokens, MatMulEnv& env) {
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenImageTokens);
|
||||
if (config.vit_config.layer_configs.empty()) {
|
||||
HWY_ABORT("Model does not support generating image tokens.");
|
||||
}
|
||||
RuntimeConfig prefill_runtime_config = runtime_config;
|
||||
ImageTokens& image_tokens, MatMulEnv& env,
|
||||
TimingInfo& timing_info) {
|
||||
const ModelConfig vit_config = GetVitConfig(config);
|
||||
const size_t num_tokens = vit_config.max_seq_len;
|
||||
prefill_runtime_config.prefill_tbatch_size =
|
||||
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
|
||||
Activations prefill_activations(runtime_config, vit_config, num_tokens,
|
||||
num_tokens, env.ctx, env.row_ptrs);
|
||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
||||
prefill_activations, env);
|
||||
|
||||
timing_info.NotifyImageTokenStart();
|
||||
|
||||
{
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenImageTokens);
|
||||
if (config.vit_config.layer_configs.empty()) {
|
||||
HWY_ABORT("Model does not support generating image tokens.");
|
||||
}
|
||||
RuntimeConfig prefill_runtime_config = runtime_config;
|
||||
prefill_runtime_config.prefill_tbatch_size =
|
||||
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
|
||||
Activations prefill_activations(runtime_config, vit_config, num_tokens,
|
||||
num_tokens, env.ctx, env.row_ptrs);
|
||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
||||
prefill_activations, env);
|
||||
} // end GCPP_ZONE before we print results.
|
||||
|
||||
// No-op if the profiler is disabled. Printing now ensures that the
|
||||
// `PrintResults` after prefill does not include the image token part.
|
||||
env.ctx.profiler.PrintResults();
|
||||
|
||||
timing_info.NotifyImageTokenDone(num_tokens);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
@ -814,13 +825,13 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
|||
|
||||
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
size_t seq_len, const Image& image,
|
||||
ImageTokens& image_tokens,
|
||||
MatMulEnv& env) const {
|
||||
ImageTokens& image_tokens, MatMulEnv& env,
|
||||
TimingInfo& timing_info) const {
|
||||
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
|
||||
seq_len, weights_, image,
|
||||
image_tokens, env);
|
||||
image_tokens, env, timing_info);
|
||||
|
||||
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -65,6 +65,21 @@ class ContinuousQBatch : public QBatch {
|
|||
};
|
||||
|
||||
struct TimingInfo {
|
||||
void NotifyImageTokenStart() { image_tokens_start = hwy::platform::Now(); }
|
||||
|
||||
void NotifyImageTokenDone(size_t tokens) {
|
||||
image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
image_tokens = tokens;
|
||||
|
||||
if (verbosity >= 1) {
|
||||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] Image token generation took: %d ms (%.1f "
|
||||
"tok/sec)\n",
|
||||
static_cast<int>(image_tokens_duration * 1E3),
|
||||
image_tokens / image_tokens_duration);
|
||||
}
|
||||
}
|
||||
|
||||
// be sure to populate prefill_start before calling NotifyPrefill.
|
||||
void NotifyPrefill(size_t tokens) {
|
||||
prefill_duration = hwy::platform::Now() - prefill_start;
|
||||
|
|
@ -87,8 +102,8 @@ struct TimingInfo {
|
|||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] Prefill: %d ms for %zu prompt tokens "
|
||||
"(%.2f tokens / sec); Time to first token: %d ms\n",
|
||||
static_cast<int>(prefill_duration * 1000), prefill_tokens,
|
||||
prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
|
||||
static_cast<int>(prefill_duration * 1E3), prefill_tokens,
|
||||
prefill_tok_sec, static_cast<int>(time_to_first_token * 1E3));
|
||||
}
|
||||
}
|
||||
if (HWY_UNLIKELY(verbosity >= 2 && tokens_generated % 1024 == 0)) {
|
||||
|
|
@ -110,20 +125,27 @@ struct TimingInfo {
|
|||
fprintf(stderr,
|
||||
"\n[ Timing info ] Generate: %d ms for %zu tokens (%.2f tokens / "
|
||||
"sec)\n",
|
||||
static_cast<int>(generate_duration * 1000), tokens_generated,
|
||||
static_cast<int>(generate_duration * 1E3), tokens_generated,
|
||||
gen_tok_sec);
|
||||
}
|
||||
}
|
||||
|
||||
int verbosity = 0;
|
||||
double prefill_start = 0;
|
||||
double generate_start = 0;
|
||||
double prefill_duration = 0;
|
||||
double image_tokens_start = 0.0;
|
||||
double image_tokens_duration = 0.0;
|
||||
size_t image_tokens = 0;
|
||||
|
||||
double prefill_start = 0.0;
|
||||
double prefill_duration = 0.0;
|
||||
size_t prefill_tokens = 0;
|
||||
double time_to_first_token = 0;
|
||||
double generate_duration = 0;
|
||||
|
||||
double generate_start = 0.0;
|
||||
double generate_duration = 0.0;
|
||||
size_t tokens_generated = 0;
|
||||
|
||||
double time_to_first_token = 0.0;
|
||||
size_t generation_steps = 0;
|
||||
|
||||
int verbosity = 0;
|
||||
};
|
||||
|
||||
// After construction, all methods are const and thread-compatible if using
|
||||
|
|
@ -173,7 +195,7 @@ class Gemma {
|
|||
// Generates the image tokens by running the image encoder ViT.
|
||||
void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
|
||||
const Image& image, ImageTokens& image_tokens,
|
||||
MatMulEnv& env) const;
|
||||
MatMulEnv& env, TimingInfo& timing_info) const;
|
||||
|
||||
private:
|
||||
BlobReader reader_;
|
||||
|
|
|
|||
12
gemma/run.cc
12
gemma/run.cc
|
|
@ -99,6 +99,8 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
|
|||
size_t prompt_size = 0;
|
||||
const ModelConfig& config = gemma.Config();
|
||||
|
||||
TimingInfo timing_info = {.verbosity = inference.verbosity};
|
||||
|
||||
const bool have_image = !inference.image_file.path.empty();
|
||||
Image image;
|
||||
const size_t pool_dim = config.vit_config.pool_dim;
|
||||
|
|
@ -117,15 +119,8 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
|
|||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {.verbosity = verbosity,
|
||||
.use_spinning = args.threading.spin};
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
|
||||
image_tokens, env);
|
||||
if (verbosity >= 1) {
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] Image token generation took: %d ms\n",
|
||||
static_cast<int>(image_tokens_duration * 1000));
|
||||
}
|
||||
image_tokens, env, timing_info);
|
||||
}
|
||||
|
||||
// callback function invoked for each generated token.
|
||||
|
|
@ -188,7 +183,6 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
|
|||
}
|
||||
|
||||
// Set up runtime config.
|
||||
TimingInfo timing_info = {.verbosity = inference.verbosity};
|
||||
RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
|
||||
.batch_stream_token = batch_stream_token,
|
||||
.use_spinning = args.threading.spin};
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ void PaliGemmaHelper::InitVit(const std::string& path) {
|
|||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {.verbosity = 0};
|
||||
gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
|
||||
image, *image_tokens_, env_->MutableEnv());
|
||||
image, *image_tokens_, env_->MutableEnv(),
|
||||
timing_info_);
|
||||
}
|
||||
|
||||
std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -18,6 +20,7 @@ class PaliGemmaHelper {
|
|||
private:
|
||||
std::unique_ptr<ImageTokens> image_tokens_;
|
||||
GemmaEnv* env_;
|
||||
TimingInfo timing_info_;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -183,7 +183,8 @@ class GemmaModel {
|
|||
env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd));
|
||||
gcpp::RuntimeConfig runtime_config = {.verbosity = 0};
|
||||
gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),
|
||||
c_image, *image_tokens_, env_.MutableEnv());
|
||||
c_image, *image_tokens_, env_.MutableEnv(),
|
||||
timing_info_);
|
||||
}
|
||||
|
||||
// Generates a response to the given prompt, using the last set image.
|
||||
|
|
@ -244,6 +245,7 @@ class GemmaModel {
|
|||
private:
|
||||
gcpp::GemmaEnv env_;
|
||||
std::unique_ptr<gcpp::ImageTokens> image_tokens_;
|
||||
gcpp::TimingInfo timing_info_;
|
||||
float last_prob_;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue