mirror of https://github.com/google/gemma.cpp.git
parent
4154f5a910
commit
27258b03e6
|
|
@ -114,9 +114,6 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
|||
size_t query_index, size_t pos, int token, float) {
|
||||
++total_tokens;
|
||||
res += StringFromTokens(std::vector<int>{token});
|
||||
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (app_.verbosity >= 2) {
|
||||
|
|
@ -125,13 +122,10 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
|||
<< inference_args_.max_generated_tokens
|
||||
<< "\ttemperature: " << inference_args_.temperature << "\n";
|
||||
}
|
||||
gcpp::TimingInfo timing_info;
|
||||
gcpp::TimingInfo timing_info { .verbosity = app_.verbosity };
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
timing_info);
|
||||
if (app_.verbosity >= 1) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
return {res, total_tokens};
|
||||
}
|
||||
|
||||
|
|
@ -153,9 +147,6 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
|
|||
res[query_index].first.append(token_text);
|
||||
res[query_index].second += 1;
|
||||
++total_tokens;
|
||||
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (app_.verbosity >= 2) {
|
||||
|
|
@ -177,14 +168,11 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
|
|||
}
|
||||
}
|
||||
|
||||
gcpp::TimingInfo timing_info;
|
||||
gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
inference_args_.CopyTo(runtime_config_);
|
||||
model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0,
|
||||
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
||||
if (app_.verbosity >= 1) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -900,7 +900,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
}
|
||||
Softmax(logits, kVocabSize);
|
||||
const int token = sample_token(logits, kVocabSize);
|
||||
timing_info.NotifyGenerated(prefill_start);
|
||||
timing_info.NotifyGenerated(prefill_start, gen_start);
|
||||
|
||||
const bool is_eos = token_streamer(query_idx_start + query_idx,
|
||||
prefill_per_query + 1 + gen_per_query,
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@
|
|||
#include "hwy/timer.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -82,37 +82,63 @@ struct RuntimeConfig {
|
|||
std::mt19937* gen;
|
||||
StreamFunc stream_token;
|
||||
BatchStreamFunc batch_stream_token;
|
||||
AcceptFunc accept_token; // if empty, accepts all tokens.
|
||||
SampleFunc sample_func; // if empty, uses SampleTopK.
|
||||
AcceptFunc accept_token; // if empty, accepts all tokens.
|
||||
SampleFunc sample_func; // if empty, uses SampleTopK.
|
||||
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
||||
int eos_id = EOS_ID;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
void NotifyPrefill(size_t tokens, double start) {
|
||||
prefill_tok_sec =
|
||||
static_cast<double>(tokens) / (hwy::platform::Now() - start);
|
||||
gen_tok_sec = 0.0;
|
||||
prefill_duration = hwy::platform::Now() - start;
|
||||
prefill_tokens = tokens;
|
||||
time_to_first_token = 0.0;
|
||||
tokens_generated = 0;
|
||||
}
|
||||
|
||||
void NotifyGenerated(double prefill_start) {
|
||||
void NotifyGenerated(double prefill_start, double gen_start) {
|
||||
++tokens_generated;
|
||||
if (HWY_UNLIKELY(tokens_generated == 1)) {
|
||||
time_to_first_token = hwy::platform::Now() - prefill_start;
|
||||
if (verbosity >= 1) {
|
||||
double prefill_tok_sec =
|
||||
static_cast<double>(prefill_tokens) / prefill_duration;
|
||||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] Prefill: %d ms for %zu prompt tokens "
|
||||
"(%.2f tokens / sec); Time to first token: %d ms\n",
|
||||
static_cast<int>(prefill_duration * 1000), prefill_tokens,
|
||||
prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
|
||||
}
|
||||
}
|
||||
if (verbosity >= 2 && tokens_generated % 128 == 0) {
|
||||
double gen_tok_sec = static_cast<double>(tokens_generated) /
|
||||
(hwy::platform::Now() - gen_start);
|
||||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] %zu tokens generated "
|
||||
"(avg speed %.2f tokens / sec)\n\n",
|
||||
tokens_generated, gen_tok_sec);
|
||||
}
|
||||
}
|
||||
|
||||
void NotifyGenerateDone(double gen_start) {
|
||||
gen_tok_sec = static_cast<double>(tokens_generated) /
|
||||
(hwy::platform::Now() - gen_start);
|
||||
generate_duration = hwy::platform::Now() - gen_start;
|
||||
if (verbosity >= 1) {
|
||||
double gen_tok_sec =
|
||||
static_cast<double>(tokens_generated) / generate_duration;
|
||||
fprintf(stderr,
|
||||
"\n[ Timing info ] Generate: %d ms for %zu tokens (%.2f tokens / "
|
||||
"sec)\n",
|
||||
static_cast<int>(generate_duration * 1000), tokens_generated,
|
||||
gen_tok_sec);
|
||||
}
|
||||
}
|
||||
|
||||
double prefill_tok_sec;
|
||||
double gen_tok_sec;
|
||||
double time_to_first_token;
|
||||
size_t tokens_generated;
|
||||
int verbosity = 0;
|
||||
double prefill_duration = 0;
|
||||
size_t prefill_tokens = 0;
|
||||
double time_to_first_token = 0;
|
||||
double generate_duration = 0;
|
||||
size_t tokens_generated = 0;
|
||||
};
|
||||
|
||||
using PromptTokens = hwy::Span<const int>;
|
||||
|
|
|
|||
11
gemma/run.cc
11
gemma/run.cc
|
|
@ -144,7 +144,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
}
|
||||
}
|
||||
|
||||
TimingInfo timing_info;
|
||||
TimingInfo timing_info = {.verbosity = verbosity};
|
||||
RuntimeConfig runtime_config = {
|
||||
.verbosity = verbosity,
|
||||
.gen = &gen,
|
||||
|
|
@ -153,15 +153,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
};
|
||||
args.CopyTo(runtime_config);
|
||||
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
|
||||
if (verbosity >= 2) {
|
||||
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
||||
<< "\n"
|
||||
<< timing_info.prefill_tok_sec << " prefill tokens / sec"
|
||||
<< "\n"
|
||||
<< timing_info.gen_tok_sec << " tokens / sec" << "\n"
|
||||
<< static_cast<int>(timing_info.time_to_first_token * 1000)
|
||||
<< " milliseconds time to first token" << "\n";
|
||||
}
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
std::cout
|
||||
|
|
|
|||
Loading…
Reference in New Issue