Store tokens/sec in auxiliary struct TimingInfo.

PiperOrigin-RevId: 633108908
This commit is contained in:
Apoorv Reddy 2024-05-13 00:03:36 -07:00 committed by Copybara-Service
parent 22fe9809ac
commit f1eab987d8
6 changed files with 56 additions and 43 deletions

View File

@ -64,7 +64,6 @@ cc_library(
], ],
deps = [ deps = [
":ops", ":ops",
# "//base",
"//compression:compress", "//compression:compress",
"//compression:io", "//compression:io",
"@hwy//:hwy", "@hwy//:hwy",
@ -153,11 +152,8 @@ cc_binary(
":app", ":app",
":args", ":args",
":gemma_lib", ":gemma_lib",
# "//base",
"//compression:compress",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:nanobenchmark", "@hwy//:nanobenchmark",
"@hwy//:profiler",
"@hwy//:thread_pool", "@hwy//:thread_pool",
"@nlohmann_json//:json", "@nlohmann_json//:json",
], ],
@ -172,11 +168,8 @@ cc_binary(
":app", ":app",
":args", ":args",
":gemma_lib", ":gemma_lib",
# "//base", "//compression:io",
"//compression:compress",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:profiler",
"@hwy//:thread_pool", "@hwy//:thread_pool",
"@nlohmann_json//:json", "@nlohmann_json//:json",
], ],

View File

@ -1,12 +1,18 @@
#include <cstdlib>
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <random>
#include <string> #include <string>
#include <utility>
#include <vector>
#include "compression/io.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "util/app.h" #include "util/app.h"
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
using json = nlohmann::json; using json = nlohmann::json;
@ -54,9 +60,11 @@ std::pair<std::string, int> QueryModel(
std::cout << args.max_tokens << " " << args.max_generated_tokens << " " std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
<< args.temperature; << args.temperature;
} }
gcpp::TimingInfo timing_info;
GenerateGemma(model, args.max_tokens, args.max_generated_tokens, GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool, args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
stream_token, accept_token, gen, app.verbosity, layers_output); stream_token, accept_token, gen, app.verbosity, timing_info,
layers_output);
return {res, total_tokens}; return {res, total_tokens};
} }
@ -65,7 +73,8 @@ class OutputJsonLogger {
json json_output; json json_output;
gcpp::LayersOutputT layers_output_log_f = gcpp::LayersOutputT layers_output_log_f =
[this](int pos, const std::string& key, const float* values, size_t values_len) { [this](int pos, const std::string& key, const float* values,
size_t values_len) {
std::vector<float> v{values, values + values_len}; std::vector<float> v{values, values + values_len};
json_output[std::to_string(pos)][key] = v; json_output[std::to_string(pos)][key] = v;
}; };

View File

@ -87,9 +87,10 @@ std::pair<std::string, int> QueryModel(
std::cout << args.max_tokens << " " << args.max_generated_tokens << " " std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
<< args.temperature; << args.temperature;
} }
gcpp::TimingInfo timing_info;
GenerateGemma(model, args.max_tokens, args.max_generated_tokens, GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool, args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
stream_token, accept_token, gen, app.verbosity); stream_token, accept_token, gen, app.verbosity, timing_info);
if (app.verbosity >= 1) { if (app.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens); LogSpeedStats(time_start, total_tokens);
} }

View File

@ -458,7 +458,8 @@ struct GemmaInterface {
size_t start_pos, KVCache& kv_cache, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, const StreamFunc& stream_token, hwy::ThreadPool& pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity, LayersOutputT* layers_output) = 0; int verbosity, TimingInfo& timing_info,
LayersOutputT* layers_output) = 0;
virtual float ComputeCrossEntropy(size_t max_tokens, virtual float ComputeCrossEntropy(size_t max_tokens,
const std::vector<int>& prompt, const std::vector<int>& prompt,
@ -550,7 +551,7 @@ struct GemmaImpl : public GemmaInterface {
float temperature, const std::vector<int>& prompt, float temperature, const std::vector<int>& prompt,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token, const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937&, int verbosity, std::mt19937&, int verbosity, TimingInfo& timing_info,
LayersOutputT* layers_output) override; LayersOutputT* layers_output) override;
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt, float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
@ -1087,7 +1088,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache, const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, const StreamFunc& stream_token, hwy::ThreadPool& pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity, LayersOutputT* layers_output) { int verbosity, TimingInfo& timing_info,
LayersOutputT* layers_output) {
static constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kVocabSize = TConfig::kVocabSize;
Activations<TConfig, 1>& activations = *gemma.state.get(); Activations<TConfig, 1>& activations = *gemma.state.get();
Activations<TConfig, kPrefillBatchSize>& prefill_activations = Activations<TConfig, kPrefillBatchSize>& prefill_activations =
@ -1137,12 +1139,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
} }
if (verbosity >= 2) { if (verbosity >= 2) {
// in the future this output should not occur in GenerateImpl but instead
// should be available as observable state for frontend code to handle I/O.
const double prefill_end = hwy::platform::Now(); const double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = timing_info.prefill_tok_sec =
static_cast<double>(pos_offset) / (prefill_end - prefill_start); static_cast<double>(pos_offset) / (prefill_end - prefill_start);
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
} }
const double gen_start = hwy::platform::Now(); const double gen_start = hwy::platform::Now();
@ -1186,10 +1185,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
if (token == EOS_ID) { if (token == EOS_ID) {
if (verbosity >= 2) { if (verbosity >= 2) {
const double gen_end = hwy::platform::Now(); const double gen_end = hwy::platform::Now();
const double gen_tok_sec = timing_info.gen_tok_sec =
static_cast<double>(pos_offset - pos_gen_start) / static_cast<double>(pos_offset - pos_gen_start) /
(gen_end - gen_start); (gen_end - gen_start);
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
} }
break; break;
} }
@ -1266,11 +1264,11 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token, const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity, std::mt19937& gen, int verbosity, TimingInfo& timing_info,
LayersOutputT* layers_output) { LayersOutputT* layers_output) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, kv_cache, pool, stream_token, accept_token, gen, start_pos, kv_cache, pool, stream_token, accept_token, gen,
verbosity, layers_output); verbosity, timing_info, layers_output);
} }
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens, void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
@ -1278,11 +1276,11 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token, const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity, std::mt19937& gen, int verbosity, TimingInfo& timing_info,
LayersOutputT* layers_output) { LayersOutputT* layers_output) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, kv_cache, pool, stream_token, accept_token, gen, start_pos, kv_cache, pool, stream_token, accept_token, gen,
verbosity, layers_output); verbosity, timing_info, layers_output);
} }
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens, void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
@ -1291,10 +1289,11 @@ void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
KVCache& kv_cache, hwy::ThreadPool& pool, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity, LayersOutputT* layers_output) { int verbosity, TimingInfo& timing_info,
LayersOutputT* layers_output) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, kv_cache, pool, stream_token, accept_token, gen, start_pos, kv_cache, pool, stream_token, accept_token, gen,
verbosity, layers_output); verbosity, timing_info, layers_output);
} }
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens, float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
@ -1559,10 +1558,10 @@ void GemmaImpl<ConfigGemma2B>::Generate(
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache, const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, const StreamFunc& stream_token, hwy::ThreadPool& pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
LayersOutputT* layers_output) { TimingInfo& timing_info, LayersOutputT* layers_output) {
HWY_DYNAMIC_DISPATCH(Generate2B) HWY_DYNAMIC_DISPATCH(Generate2B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
kv_cache, pool, stream_token, accept_token, gen, verbosity, kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
layers_output); layers_output);
} }
@ -1572,10 +1571,11 @@ void GemmaImpl<ConfigGemma7B>::Generate(
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache, const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, const StreamFunc& stream_token, hwy::ThreadPool& pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
LayersOutputT* layers_output) { TimingInfo& timing_info, LayersOutputT* layers_output) {
HWY_DYNAMIC_DISPATCH(Generate7B) HWY_DYNAMIC_DISPATCH(Generate7B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
kv_cache, pool, stream_token, accept_token, gen, verbosity, layers_output); kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
layers_output);
} }
template <> template <>
@ -1584,10 +1584,10 @@ void GemmaImpl<ConfigGriffin2B>::Generate(
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache, const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, const StreamFunc& stream_token, hwy::ThreadPool& pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
LayersOutputT* layers_output) { TimingInfo& timing_info, LayersOutputT* layers_output) {
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B) HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
kv_cache, pool, stream_token, accept_token, gen, verbosity, kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
layers_output); layers_output);
} }
@ -1658,23 +1658,25 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity, LayersOutputT* layers_output) { int verbosity, TimingInfo& timing_info,
LayersOutputT* layers_output) {
pool.SetWaitMode(hwy::PoolWaitMode::kSpin); pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt, gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
start_pos, kv_cache, pool, stream_token, accept_token, start_pos, kv_cache, pool, stream_token, accept_token,
gen, verbosity, layers_output); gen, verbosity, timing_info, layers_output);
pool.SetWaitMode(hwy::PoolWaitMode::kBlock); pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
} }
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, std::mt19937& gen) { const StreamFunc& stream_token, std::mt19937& gen,
TimingInfo& timing_info) {
GenerateGemma( GenerateGemma(
gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens, gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
runtime_config.temperature, prompt, start_pos, kv_cache, pool, runtime_config.temperature, prompt, start_pos, kv_cache, pool,
stream_token, [](int) { return true; }, gen, runtime_config.verbosity, stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
/*layers_output=*/nullptr); timing_info, /*layers_output=*/nullptr);
} }
void CompressWeights(gcpp::Model model, const Path& weights, void CompressWeights(gcpp::Model model, const Path& weights,

View File

@ -88,6 +88,11 @@ struct Gemma {
std::unique_ptr<GemmaInterface> impl_; std::unique_ptr<GemmaInterface> impl_;
}; };
// Throughput figures reported back by a generation call.
//
// Both rates start at 0.0. In the GenerateImpl code visible alongside this
// struct they are only written when verbosity >= 2, so a value of 0.0 can
// also mean "not measured".
struct TimingInfo {
  // Prefill rate: token positions advanced during prefill divided by the
  // elapsed prefill time, in tokens per second.
  double prefill_tok_sec{0.0};
  // Generation rate: token positions advanced since generation started
  // divided by the elapsed generation time, in tokens per second.
  // NOTE(review): the visible write site is the end-of-sequence branch of
  // GenerateImpl - confirm whether runs that stop for another reason (e.g.
  // hitting the token limit) leave this at 0.0.
  double gen_tok_sec{0.0};
};
KVCache CreateKVCache(Model type); // convenient workaround for now KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv1d_cache_size, size_t rglru_cache_size); size_t conv1d_cache_size, size_t rglru_cache_size);
@ -104,7 +109,8 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity, LayersOutputT* layers_output = nullptr); int verbosity, TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr);
// Convenience function for the common case: // Convenience function for the common case:
// - Bundle runtime parameters as RuntimeConfig // - Bundle runtime parameters as RuntimeConfig
@ -112,7 +118,8 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, std::mt19937& gen); const StreamFunc& stream_token, std::mt19937& gen,
int verbosity, TimingInfo& timing_info);
void CompressWeights(gcpp::Model model, const Path& weights, void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights, hwy::ThreadPool& pool); const Path& compressed_weights, hwy::ThreadPool& pool);

View File

@ -19,6 +19,7 @@
#include <iostream> #include <iostream>
#include <random> #include <random>
#include <string> #include <string>
#include <string_view>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
@ -206,16 +207,16 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
} }
} }
const double time_start = hwy::platform::Now(); TimingInfo timing_info;
GenerateGemma(model, args.max_tokens, args.max_generated_tokens, GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, abs_pos, kv_cache, pool, args.temperature, prompt, abs_pos, kv_cache, pool,
stream_token, accept_token, gen, verbosity); stream_token, accept_token, gen, verbosity, timing_info);
const double time_end = hwy::platform::Now();
const double tok_sec = current_pos / (time_end - time_start);
if (verbosity >= 2) { if (verbosity >= 2) {
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)" std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
<< "\n" << "\n"
<< tok_sec << " tokens / sec" << "\n"; << timing_info.prefill_tok_sec << " prefill tokens / sec"
<< "\n"
<< timing_info.gen_tok_sec << " tokens / sec" << "\n";
} }
std::cout << "\n\n"; std::cout << "\n\n";
} }