mirror of https://github.com/google/gemma.cpp.git
Store tokens/sec in auxiliary struct TimingInfo.
PiperOrigin-RevId: 633108908
This commit is contained in:
parent
22fe9809ac
commit
f1eab987d8
|
|
@ -64,7 +64,6 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":ops",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
|
|
@ -153,11 +152,8 @@ cc_binary(
|
|||
":app",
|
||||
":args",
|
||||
":gemma_lib",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
"@nlohmann_json//:json",
|
||||
],
|
||||
|
|
@ -172,11 +168,8 @@ cc_binary(
|
|||
":app",
|
||||
":args",
|
||||
":gemma_lib",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
"@nlohmann_json//:json",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
|
@ -54,9 +60,11 @@ std::pair<std::string, int> QueryModel(
|
|||
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
|
||||
<< args.temperature;
|
||||
}
|
||||
gcpp::TimingInfo timing_info;
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||
stream_token, accept_token, gen, app.verbosity, layers_output);
|
||||
stream_token, accept_token, gen, app.verbosity, timing_info,
|
||||
layers_output);
|
||||
return {res, total_tokens};
|
||||
}
|
||||
|
||||
|
|
@ -65,7 +73,8 @@ class OutputJsonLogger {
|
|||
json json_output;
|
||||
|
||||
gcpp::LayersOutputT layers_output_log_f =
|
||||
[this](int pos, const std::string& key, const float* values, size_t values_len) {
|
||||
[this](int pos, const std::string& key, const float* values,
|
||||
size_t values_len) {
|
||||
std::vector<float> v{values, values + values_len};
|
||||
json_output[std::to_string(pos)][key] = v;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -87,9 +87,10 @@ std::pair<std::string, int> QueryModel(
|
|||
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
|
||||
<< args.temperature;
|
||||
}
|
||||
gcpp::TimingInfo timing_info;
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||
stream_token, accept_token, gen, app.verbosity);
|
||||
stream_token, accept_token, gen, app.verbosity, timing_info);
|
||||
if (app.verbosity >= 1) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -458,7 +458,8 @@ struct GemmaInterface {
|
|||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) = 0;
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) = 0;
|
||||
|
||||
virtual float ComputeCrossEntropy(size_t max_tokens,
|
||||
const std::vector<int>& prompt,
|
||||
|
|
@ -550,7 +551,7 @@ struct GemmaImpl : public GemmaInterface {
|
|||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937&, int verbosity,
|
||||
std::mt19937&, int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) override;
|
||||
|
||||
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
||||
|
|
@ -1087,7 +1088,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||
Activations<TConfig, kPrefillBatchSize>& prefill_activations =
|
||||
|
|
@ -1137,12 +1139,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
}
|
||||
|
||||
if (verbosity >= 2) {
|
||||
// in the future this output should not occur in GenerateImpl but instead
|
||||
// should be available as observable state for frontend code to handle I/O.
|
||||
const double prefill_end = hwy::platform::Now();
|
||||
const double prefill_tok_sec =
|
||||
timing_info.prefill_tok_sec =
|
||||
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
||||
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
|
||||
}
|
||||
|
||||
const double gen_start = hwy::platform::Now();
|
||||
|
|
@ -1186,10 +1185,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
if (token == EOS_ID) {
|
||||
if (verbosity >= 2) {
|
||||
const double gen_end = hwy::platform::Now();
|
||||
const double gen_tok_sec =
|
||||
timing_info.gen_tok_sec =
|
||||
static_cast<double>(pos_offset - pos_gen_start) /
|
||||
(gen_end - gen_start);
|
||||
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -1266,11 +1264,11 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
|||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity,
|
||||
std::mt19937& gen, int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
||||
verbosity, layers_output);
|
||||
verbosity, timing_info, layers_output);
|
||||
}
|
||||
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
|
|
@ -1278,11 +1276,11 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
|||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity,
|
||||
std::mt19937& gen, int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
||||
verbosity, layers_output);
|
||||
verbosity, timing_info, layers_output);
|
||||
}
|
||||
|
||||
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
|
||||
|
|
@ -1291,10 +1289,11 @@ void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
|
|||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
||||
verbosity, layers_output);
|
||||
verbosity, timing_info, layers_output);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
|
|
@ -1559,10 +1558,10 @@ void GemmaImpl<ConfigGemma2B>::Generate(
|
|||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
||||
LayersOutputT* layers_output) {
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
|
|
@ -1572,10 +1571,11 @@ void GemmaImpl<ConfigGemma7B>::Generate(
|
|||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
||||
LayersOutputT* layers_output) {
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, layers_output);
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
@ -1584,10 +1584,10 @@ void GemmaImpl<ConfigGriffin2B>::Generate(
|
|||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
||||
LayersOutputT* layers_output) {
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
|
|
@ -1658,23 +1658,25 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
|||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, stream_token, accept_token,
|
||||
gen, verbosity, layers_output);
|
||||
gen, verbosity, timing_info, layers_output);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, std::mt19937& gen) {
|
||||
const StreamFunc& stream_token, std::mt19937& gen,
|
||||
TimingInfo& timing_info) {
|
||||
GenerateGemma(
|
||||
gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
|
||||
runtime_config.temperature, prompt, start_pos, kv_cache, pool,
|
||||
stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
|
||||
/*layers_output=*/nullptr);
|
||||
timing_info, /*layers_output=*/nullptr);
|
||||
}
|
||||
|
||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
|
|
|
|||
|
|
@ -88,6 +88,11 @@ struct Gemma {
|
|||
std::unique_ptr<GemmaInterface> impl_;
|
||||
};
|
||||
|
||||
// Throughput measurements that GenerateGemma hands back to its caller through
// a TimingInfo& parameter. Both fields keep their 0.0 default unless the
// generation code assigns them; in the GenerateImpl code shown in this change
// that only happens when verbosity >= 2, and gen_tok_sec is additionally only
// assigned once the EOS_ID token is produced.
struct TimingInfo {
|
||||
// Prompt (prefill) tokens processed per second.
double prefill_tok_sec = 0.0;
|
||||
// Generated tokens per second, measured from the start of generation.
double gen_tok_sec = 0.0;
|
||||
};
|
||||
|
||||
KVCache CreateKVCache(Model type); // convenient workaround for now
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
||||
size_t conv1d_cache_size, size_t rglru_cache_size);
|
||||
|
|
@ -104,7 +109,8 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
|||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output = nullptr);
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output = nullptr);
|
||||
|
||||
// Convenience function for the common case:
|
||||
// - Bundle runtime parameters as RuntimeConfig
|
||||
|
|
@ -112,7 +118,8 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
|||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, std::mt19937& gen);
|
||||
const StreamFunc& stream_token, std::mt19937& gen,
|
||||
// Tokens/sec results are written to timing_info. There is deliberately no
// separate verbosity parameter: the definition takes it from
// runtime_config.verbosity, and this declaration must match that signature
// or callers fail to link.
TimingInfo& timing_info);
|
||||
|
||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
const Path& compressed_weights, hwy::ThreadPool& pool);
|
||||
|
|
|
|||
11
gemma/run.cc
11
gemma/run.cc
|
|
@ -19,6 +19,7 @@
|
|||
#include <iostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <thread> // NOLINT
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -206,16 +207,16 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
}
|
||||
}
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
TimingInfo timing_info;
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, abs_pos, kv_cache, pool,
|
||||
stream_token, accept_token, gen, verbosity);
|
||||
const double time_end = hwy::platform::Now();
|
||||
const double tok_sec = current_pos / (time_end - time_start);
|
||||
stream_token, accept_token, gen, verbosity, timing_info);
|
||||
if (verbosity >= 2) {
|
||||
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
||||
<< "\n"
|
||||
<< tok_sec << " tokens / sec" << "\n";
|
||||
<< timing_info.prefill_tok_sec << " prefill tokens / sec"
|
||||
<< "\n"
|
||||
<< timing_info.gen_tok_sec << " tokens / sec" << "\n";
|
||||
}
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue