mirror of https://github.com/google/gemma.cpp.git
parent 4154f5a910
commit 27258b03e6
@@ -114,9 +114,6 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
       size_t query_index, size_t pos, int token, float) {
     ++total_tokens;
     res += StringFromTokens(std::vector<int>{token});
-    if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
-      LogSpeedStats(time_start, total_tokens);
-    }
     return true;
   };
   if (app_.verbosity >= 2) {
@@ -125,13 +122,10 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
               << inference_args_.max_generated_tokens
               << "\ttemperature: " << inference_args_.temperature << "\n";
   }
-  gcpp::TimingInfo timing_info;
+  gcpp::TimingInfo timing_info { .verbosity = app_.verbosity };
   runtime_config_.batch_stream_token = batch_stream_token;
   model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                    timing_info);
-  if (app_.verbosity >= 1) {
-    LogSpeedStats(time_start, total_tokens);
-  }
   return {res, total_tokens};
 }
 
@@ -153,9 +147,6 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
     res[query_index].first.append(token_text);
     res[query_index].second += 1;
     ++total_tokens;
-    if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
-      LogSpeedStats(time_start, total_tokens);
-    }
     return true;
   };
   if (app_.verbosity >= 2) {
@@ -177,14 +168,11 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
     }
   }
 
-  gcpp::TimingInfo timing_info;
+  gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity};
   runtime_config_.batch_stream_token = batch_stream_token;
   inference_args_.CopyTo(runtime_config_);
   model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0,
                         KVCaches(&kv_caches_[0], num_queries), timing_info);
-  if (app_.verbosity >= 1) {
-    LogSpeedStats(time_start, total_tokens);
-  }
   return res;
 }
 
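Note on the initialization pattern above: `gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity};` relies on designated initializers over an aggregate whose remaining members all have default member initializers, which is what lets the per-call-site LogSpeedStats blocks go away. A minimal stand-alone sketch (the `Stats` struct is a hypothetical stand-in, not the real TimingInfo):

#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for TimingInfo, just to show the initialization.
struct Stats {
  int verbosity = 0;            // the one member call sites set explicitly
  double prefill_duration = 0;  // everything else keeps its default
  std::size_t tokens_generated = 0;
};

int main() {
  // Designated initializers (C++20; widely available earlier as an
  // extension): name one member, the rest fall back to their defaults.
  Stats s = {.verbosity = 2};
  std::printf("verbosity=%d duration=%.1f tokens=%zu\n", s.verbosity,
              s.prefill_duration, s.tokens_generated);
  // Prints: verbosity=2 duration=0.0 tokens=0
  return 0;
}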
@@ -900,7 +900,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
       }
       Softmax(logits, kVocabSize);
       const int token = sample_token(logits, kVocabSize);
-      timing_info.NotifyGenerated(prefill_start);
+      timing_info.NotifyGenerated(prefill_start, gen_start);
 
       const bool is_eos = token_streamer(query_idx_start + query_idx,
                                          prefill_per_query + 1 + gen_per_query,
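The new gen_start parameter exists because the two reported statistics are anchored at different instants: time-to-first-token counts from the start of prefill, while steady-state generation speed must exclude prefill and counts from the start of decoding. A small worked example with assumed timestamps (every value below is hypothetical):

#include <cstdio>

int main() {
  // Assumed timeline, in seconds: prefill begins at 0, decoding at 0.4,
  // the first token lands at 0.45, and 64 tokens exist by t = 1.4.
  const double prefill_start = 0.0;
  const double gen_start = 0.4;
  const double first_token_time = 0.45;
  const double now = 1.4;
  const int tokens_generated = 64;

  const double ttft = first_token_time - prefill_start;
  const double gen_tok_sec = tokens_generated / (now - gen_start);
  std::printf("TTFT: %d ms, decode speed: %.2f tokens / sec\n",
              static_cast<int>(ttft * 1000), gen_tok_sec);
  // Prints: TTFT: 450 ms, decode speed: 64.00 tokens / sec
  return 0;
}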
@@ -32,7 +32,7 @@
 #include "hwy/timer.h"
 // IWYU pragma: end_exports
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"  // hwy::bfloat16_t
 
 namespace gcpp {
@@ -82,37 +82,63 @@ struct RuntimeConfig {
   std::mt19937* gen;
   StreamFunc stream_token;
   BatchStreamFunc batch_stream_token;
   AcceptFunc accept_token;  // if empty, accepts all tokens.
   SampleFunc sample_func;   // if empty, uses SampleTopK.
   LayersOutputFunc layers_output;  // if not empty, called after each layer.
   int eos_id = EOS_ID;
 };
 
 struct TimingInfo {
   void NotifyPrefill(size_t tokens, double start) {
-    prefill_tok_sec =
-        static_cast<double>(tokens) / (hwy::platform::Now() - start);
-    gen_tok_sec = 0.0;
+    prefill_duration = hwy::platform::Now() - start;
+    prefill_tokens = tokens;
     time_to_first_token = 0.0;
     tokens_generated = 0;
   }
 
-  void NotifyGenerated(double prefill_start) {
+  void NotifyGenerated(double prefill_start, double gen_start) {
     ++tokens_generated;
     if (HWY_UNLIKELY(tokens_generated == 1)) {
       time_to_first_token = hwy::platform::Now() - prefill_start;
+      if (verbosity >= 1) {
+        double prefill_tok_sec =
+            static_cast<double>(prefill_tokens) / prefill_duration;
+        fprintf(stderr,
+                "\n\n[ Timing info ] Prefill: %d ms for %zu prompt tokens "
+                "(%.2f tokens / sec); Time to first token: %d ms\n",
+                static_cast<int>(prefill_duration * 1000), prefill_tokens,
+                prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
+      }
+    }
+    if (verbosity >= 2 && tokens_generated % 128 == 0) {
+      double gen_tok_sec = static_cast<double>(tokens_generated) /
+                           (hwy::platform::Now() - gen_start);
+      fprintf(stderr,
+              "\n\n[ Timing info ] %zu tokens generated "
+              "(avg speed %.2f tokens / sec)\n\n",
+              tokens_generated, gen_tok_sec);
     }
   }
 
   void NotifyGenerateDone(double gen_start) {
-    gen_tok_sec = static_cast<double>(tokens_generated) /
-                  (hwy::platform::Now() - gen_start);
+    generate_duration = hwy::platform::Now() - gen_start;
+    if (verbosity >= 1) {
+      double gen_tok_sec =
+          static_cast<double>(tokens_generated) / generate_duration;
+      fprintf(stderr,
+              "\n[ Timing info ] Generate: %d ms for %zu tokens (%.2f tokens / "
+              "sec)\n",
+              static_cast<int>(generate_duration * 1000), tokens_generated,
+              gen_tok_sec);
+    }
   }
 
-  double prefill_tok_sec;
-  double gen_tok_sec;
-  double time_to_first_token;
-  size_t tokens_generated;
+  int verbosity = 0;
+  double prefill_duration = 0;
+  size_t prefill_tokens = 0;
+  double time_to_first_token = 0;
+  double generate_duration = 0;
+  size_t tokens_generated = 0;
 };
 
 using PromptTokens = hwy::Span<const int>;
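With these changes TimingInfo owns its own reporting: callers set verbosity once at construction and the Notify* hooks write to stderr themselves, instead of every call site wrapping LogSpeedStats in a verbosity check. A self-contained sketch of that lifecycle, using std::chrono as a stand-in for hwy::platform::Now() (which, per the diff's arithmetic, returns seconds as a double); the MockTimingInfo type, sleep times, and token counts are assumptions for illustration:

#include <chrono>
#include <cstddef>
#include <cstdio>
#include <thread>

// Stand-in for hwy::platform::Now(): seconds since an arbitrary epoch.
static double NowSeconds() {
  using namespace std::chrono;
  return duration<double>(steady_clock::now().time_since_epoch()).count();
}

// Trimmed-down mock of the TimingInfo in this commit; format strings shortened.
struct MockTimingInfo {
  void NotifyPrefill(std::size_t tokens, double start) {
    prefill_duration = NowSeconds() - start;
    prefill_tokens = tokens;
  }
  void NotifyGenerated(double prefill_start, double gen_start) {
    ++tokens_generated;
    if (tokens_generated == 1 && verbosity >= 1) {
      std::fprintf(stderr, "[ Timing info ] Time to first token: %d ms\n",
                   static_cast<int>((NowSeconds() - prefill_start) * 1000));
    }
    if (verbosity >= 2 && tokens_generated % 128 == 0) {
      std::fprintf(stderr, "[ Timing info ] %zu tokens (%.2f tokens / sec)\n",
                   tokens_generated,
                   tokens_generated / (NowSeconds() - gen_start));
    }
  }
  void NotifyGenerateDone(double gen_start) {
    if (verbosity >= 1) {
      const double generate_duration = NowSeconds() - gen_start;
      std::fprintf(stderr,
                   "[ Timing info ] Generate: %d ms for %zu tokens "
                   "(%.2f tokens / sec)\n",
                   static_cast<int>(generate_duration * 1000), tokens_generated,
                   tokens_generated / generate_duration);
    }
  }
  int verbosity = 0;
  double prefill_duration = 0;
  std::size_t prefill_tokens = 0;
  std::size_t tokens_generated = 0;
};

int main() {
  // Same caller pattern as the diff: set verbosity once, then let the
  // Notify* hooks do the logging.
  MockTimingInfo timing_info = {.verbosity = 1};
  const double prefill_start = NowSeconds();
  std::this_thread::sleep_for(std::chrono::milliseconds(50));  // fake prefill
  timing_info.NotifyPrefill(/*tokens=*/8, prefill_start);
  const double gen_start = NowSeconds();
  for (int i = 0; i < 16; ++i) {  // fake decode loop
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    timing_info.NotifyGenerated(prefill_start, gen_start);
  }
  timing_info.NotifyGenerateDone(gen_start);
  return 0;
}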
gemma/run.cc
@@ -144,7 +144,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
     }
   }
 
-  TimingInfo timing_info;
+  TimingInfo timing_info = {.verbosity = verbosity};
   RuntimeConfig runtime_config = {
       .verbosity = verbosity,
       .gen = &gen,
@@ -153,15 +153,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
   };
   args.CopyTo(runtime_config);
   model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
-  if (verbosity >= 2) {
-    std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
-              << "\n"
-              << timing_info.prefill_tok_sec << " prefill tokens / sec"
-              << "\n"
-              << timing_info.gen_tok_sec << " tokens / sec" << "\n"
-              << static_cast<int>(timing_info.time_to_first_token * 1000)
-              << " milliseconds time to first token" << "\n";
-  }
   std::cout << "\n\n";
 }
 std::cout
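The std::cout block deleted here is not lost functionality: equivalent numbers now come out of TimingInfo's fprintf calls shown earlier. As a reference for the new output, the arithmetic behind one prefill line, with assumed example values:

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed example values; the format mirrors the string in the diff above.
  std::size_t prefill_tokens = 128;
  double prefill_duration = 0.4;  // seconds
  double tok_sec = static_cast<double>(prefill_tokens) / prefill_duration;
  std::printf(
      "[ Timing info ] Prefill: %d ms for %zu prompt tokens "
      "(%.2f tokens / sec)\n",
      static_cast<int>(prefill_duration * 1000), prefill_tokens, tok_sec);
  // Prints: [ Timing info ] Prefill: 400 ms for 128 prompt tokens
  // (320.00 tokens / sec)
  return 0;
}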