From f1eab987d8595648d47580a8e75ccde355f170dd Mon Sep 17 00:00:00 2001 From: Apoorv Reddy Date: Mon, 13 May 2024 00:03:36 -0700 Subject: [PATCH] Store tokens/sec in auxiliary struct TimingInfo. PiperOrigin-RevId: 633108908 --- BUILD.bazel | 9 +------- debug_prompt.cc | 13 ++++++++++-- gemma/benchmark.cc | 3 ++- gemma/gemma.cc | 52 ++++++++++++++++++++++++---------------------- gemma/gemma.h | 11 ++++++++-- gemma/run.cc | 11 +++++----- 6 files changed, 56 insertions(+), 43 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index b3c1090..aefd1dc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -64,7 +64,6 @@ cc_library( ], deps = [ ":ops", - # "//base", "//compression:compress", "//compression:io", "@hwy//:hwy", @@ -153,11 +152,8 @@ cc_binary( ":app", ":args", ":gemma_lib", - # "//base", - "//compression:compress", "@hwy//:hwy", "@hwy//:nanobenchmark", - "@hwy//:profiler", "@hwy//:thread_pool", "@nlohmann_json//:json", ], @@ -172,11 +168,8 @@ cc_binary( ":app", ":args", ":gemma_lib", - # "//base", - "//compression:compress", + "//compression:io", "@hwy//:hwy", - "@hwy//:nanobenchmark", - "@hwy//:profiler", "@hwy//:thread_pool", "@nlohmann_json//:json", ], diff --git a/debug_prompt.cc b/debug_prompt.cc index e844311..62a427e 100644 --- a/debug_prompt.cc +++ b/debug_prompt.cc @@ -1,12 +1,18 @@ +#include #include #include +#include #include +#include +#include +#include "compression/io.h" #include "gemma/gemma.h" #include "nlohmann/json.hpp" #include "util/app.h" #include "util/args.h" +#include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" using json = nlohmann::json; @@ -54,9 +60,11 @@ std::pair QueryModel( std::cout << args.max_tokens << " " << args.max_generated_tokens << " " << args.temperature; } + gcpp::TimingInfo timing_info; GenerateGemma(model, args.max_tokens, args.max_generated_tokens, args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool, - stream_token, accept_token, gen, app.verbosity, layers_output); + stream_token, accept_token, gen, 
app.verbosity, timing_info, + layers_output); return {res, total_tokens}; } @@ -65,7 +73,8 @@ class OutputJsonLogger { json json_output; gcpp::LayersOutputT layers_output_log_f = - [this](int pos, const std::string& key, const float* values, size_t values_len) { + [this](int pos, const std::string& key, const float* values, + size_t values_len) { std::vector v{values, values + values_len}; json_output[std::to_string(pos)][key] = v; }; diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc index fdcf1b2..60bbcdb 100644 --- a/gemma/benchmark.cc +++ b/gemma/benchmark.cc @@ -87,9 +87,10 @@ std::pair QueryModel( std::cout << args.max_tokens << " " << args.max_generated_tokens << " " << args.temperature; } + gcpp::TimingInfo timing_info; GenerateGemma(model, args.max_tokens, args.max_generated_tokens, args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool, - stream_token, accept_token, gen, app.verbosity); + stream_token, accept_token, gen, app.verbosity, timing_info); if (app.verbosity >= 1) { LogSpeedStats(time_start, total_tokens); } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 41e7534..51fbad8 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -458,7 +458,8 @@ struct GemmaInterface { size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, - int verbosity, LayersOutputT* layers_output) = 0; + int verbosity, TimingInfo& timing_info, + LayersOutputT* layers_output) = 0; virtual float ComputeCrossEntropy(size_t max_tokens, const std::vector& prompt, @@ -550,7 +551,7 @@ struct GemmaImpl : public GemmaInterface { float temperature, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, - std::mt19937&, int verbosity, + std::mt19937&, int verbosity, TimingInfo& timing_info, LayersOutputT* layers_output) override; float ComputeCrossEntropy(size_t max_tokens, const std::vector& prompt, @@ 
-1087,7 +1088,8 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens, const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, - int verbosity, LayersOutputT* layers_output) { + int verbosity, TimingInfo& timing_info, + LayersOutputT* layers_output) { static constexpr size_t kVocabSize = TConfig::kVocabSize; Activations& activations = *gemma.state.get(); Activations& prefill_activations = @@ -1137,12 +1139,9 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens, } if (verbosity >= 2) { - // in the future this output should not occur in GenerateImpl but instead - // should be available as observable state for frontend code to handle I/O. const double prefill_end = hwy::platform::Now(); - const double prefill_tok_sec = + timing_info.prefill_tok_sec = static_cast(pos_offset) / (prefill_end - prefill_start); - std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]"; } const double gen_start = hwy::platform::Now(); @@ -1186,10 +1185,9 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens, if (token == EOS_ID) { if (verbosity >= 2) { const double gen_end = hwy::platform::Now(); - const double gen_tok_sec = + timing_info.gen_tok_sec = static_cast(pos_offset - pos_gen_start) / (gen_end - gen_start); - std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n"; } break; } @@ -1266,11 +1264,11 @@ void Generate2B(GemmaImpl& gemma, size_t max_tokens, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity, + std::mt19937& gen, int verbosity, TimingInfo& timing_info, LayersOutputT* layers_output) { GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, start_pos, kv_cache, pool, stream_token, accept_token, gen, - verbosity, layers_output); + verbosity, timing_info, layers_output); } void 
Generate7B(GemmaImpl& gemma, size_t max_tokens, @@ -1278,11 +1276,11 @@ void Generate7B(GemmaImpl& gemma, size_t max_tokens, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity, + std::mt19937& gen, int verbosity, TimingInfo& timing_info, LayersOutputT* layers_output) { GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, start_pos, kv_cache, pool, stream_token, accept_token, gen, - verbosity, layers_output); + verbosity, timing_info, layers_output); } void GenerateGriffin2B(GemmaImpl& gemma, size_t max_tokens, @@ -1291,10 +1289,11 @@ void GenerateGriffin2B(GemmaImpl& gemma, size_t max_tokens, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, - int verbosity, LayersOutputT* layers_output) { + int verbosity, TimingInfo& timing_info, + LayersOutputT* layers_output) { GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, start_pos, kv_cache, pool, stream_token, accept_token, gen, - verbosity, layers_output); + verbosity, timing_info, layers_output); } float ComputeCrossEntropy2B(GemmaImpl& gemma, size_t max_tokens, @@ -1559,10 +1558,10 @@ void GemmaImpl::Generate( const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity, - LayersOutputT* layers_output) { + TimingInfo& timing_info, LayersOutputT* layers_output) { HWY_DYNAMIC_DISPATCH(Generate2B) (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, - kv_cache, pool, stream_token, accept_token, gen, verbosity, + kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info, layers_output); } @@ -1572,10 +1571,11 @@ void GemmaImpl::Generate( const std::vector& prompt, size_t start_pos, KVCache& kv_cache, 
hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity, - LayersOutputT* layers_output) { + TimingInfo& timing_info, LayersOutputT* layers_output) { HWY_DYNAMIC_DISPATCH(Generate7B) (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, - kv_cache, pool, stream_token, accept_token, gen, verbosity, layers_output); + kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info, + layers_output); } template <> @@ -1584,10 +1584,10 @@ void GemmaImpl::Generate( const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity, - LayersOutputT* layers_output) { + TimingInfo& timing_info, LayersOutputT* layers_output) { HWY_DYNAMIC_DISPATCH(GenerateGriffin2B) (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, - kv_cache, pool, stream_token, accept_token, gen, verbosity, + kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info, layers_output); } @@ -1658,23 +1658,25 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, - int verbosity, LayersOutputT* layers_output) { + int verbosity, TimingInfo& timing_info, + LayersOutputT* layers_output) { pool.SetWaitMode(hwy::PoolWaitMode::kSpin); gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt, start_pos, kv_cache, pool, stream_token, accept_token, - gen, verbosity, layers_output); + gen, verbosity, timing_info, layers_output); pool.SetWaitMode(hwy::PoolWaitMode::kBlock); } void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, - const StreamFunc& stream_token, std::mt19937& gen) { + const 
StreamFunc& stream_token, std::mt19937& gen, + TimingInfo& timing_info) { GenerateGemma( gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens, runtime_config.temperature, prompt, start_pos, kv_cache, pool, stream_token, [](int) { return true; }, gen, runtime_config.verbosity, - /*layers_output=*/nullptr); + timing_info, /*layers_output=*/nullptr); } void CompressWeights(gcpp::Model model, const Path& weights, diff --git a/gemma/gemma.h b/gemma/gemma.h index 822f75b..bd43046 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -88,6 +88,11 @@ struct Gemma { std::unique_ptr impl_; }; +struct TimingInfo { + double prefill_tok_sec = 0.0; + double gen_tok_sec = 0.0; +}; + KVCache CreateKVCache(Model type); // convenient workaround for now KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, size_t conv1d_cache_size, size_t rglru_cache_size); @@ -104,7 +109,8 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, - int verbosity, LayersOutputT* layers_output = nullptr); + int verbosity, TimingInfo& timing_info, + LayersOutputT* layers_output = nullptr); // Convenience function for the common case: // - Bundle runtime parameters as RuntimeConfig @@ -112,7 +118,8 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, - const StreamFunc& stream_token, std::mt19937& gen); + const StreamFunc& stream_token, std::mt19937& gen, + TimingInfo& timing_info); void CompressWeights(gcpp::Model model, const Path& weights, const Path& compressed_weights, hwy::ThreadPool& pool); diff --git a/gemma/run.cc b/gemma/run.cc index 942bff8..f8a8cc8 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -19,6 +19,7 @@ 
#include #include #include +#include #include // NOLINT #include @@ -206,16 +207,16 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training, } } - const double time_start = hwy::platform::Now(); + TimingInfo timing_info; GenerateGemma(model, args.max_tokens, args.max_generated_tokens, args.temperature, prompt, abs_pos, kv_cache, pool, - stream_token, accept_token, gen, verbosity); - const double time_end = hwy::platform::Now(); - const double tok_sec = current_pos / (time_end - time_start); + stream_token, accept_token, gen, verbosity, timing_info); if (verbosity >= 2) { std::cout << current_pos << " tokens (" << abs_pos << " total tokens)" << "\n" - << tok_sec << " tokens / sec" << "\n"; + << timing_info.prefill_tok_sec << " prefill tokens / sec" + << "\n" + << timing_info.gen_tok_sec << " tokens / sec" << "\n"; } std::cout << "\n\n"; }