Pass most runtime parameters using const RuntimeConfig&

PiperOrigin-RevId: 633572507
Apoorv Reddy 2024-05-14 07:04:14 -07:00 committed by Copybara-Service
parent f1eab987d8
commit eb0b96e0a8
6 changed files with 123 additions and 130 deletions
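
In short: the sampling and streaming parameters that GenerateGemma previously took as individual arguments (max_tokens, max_generated_tokens, temperature, the std::mt19937 RNG, the stream/accept callbacks, and verbosity) are now bundled into a gcpp::RuntimeConfig passed by const reference. A minimal sketch of the new calling convention, with illustrative values (model, prompt, kv_cache, and pool are assumed to be set up as at the call sites below):

    std::mt19937 gen(42);  // illustrative seed; any seeded RNG works
    gcpp::RuntimeConfig runtime_config = {
        .max_tokens = 2048,
        .max_generated_tokens = 1024,
        .temperature = 1.0f,
        .verbosity = 0,
        .gen = &gen,  // RuntimeConfig holds a pointer to the caller's RNG
        .stream_token = [](int /*token*/, float /*prob*/) { return true; },
        .accept_token = [](int /*token*/) { return true; },
    };
    gcpp::TimingInfo timing_info;
    GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache,
                  pool, timing_info, /*layers_output=*/nullptr);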


@@ -61,10 +61,17 @@ std::pair<std::string, int> QueryModel(
               << args.temperature;
   }
   gcpp::TimingInfo timing_info;
-  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                stream_token, accept_token, gen, app.verbosity, timing_info,
-                layers_output);
+  gcpp::RuntimeConfig runtime_config = {
+      .max_tokens = args.max_tokens,
+      .max_generated_tokens = args.max_generated_tokens,
+      .temperature = args.temperature,
+      .verbosity = app.verbosity,
+      .gen = &gen,
+      .stream_token = stream_token,
+      .accept_token = accept_token,
+  };
+  GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
+                timing_info, layers_output);
   return {res, total_tokens};
 }


@@ -2,6 +2,7 @@
 #include <fstream>
 #include <iostream>
 #include <ostream>
+#include <random>
 #include <sstream>
 #include <string>
 #include <utility>  // std::pair
@@ -88,9 +89,17 @@ std::pair<std::string, int> QueryModel(
               << args.temperature;
   }
   gcpp::TimingInfo timing_info;
-  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                stream_token, accept_token, gen, app.verbosity, timing_info);
+  gcpp::RuntimeConfig runtime_config = {
+      .max_tokens = args.max_tokens,
+      .max_generated_tokens = args.max_generated_tokens,
+      .temperature = args.temperature,
+      .verbosity = app.verbosity,
+      .gen = &gen,
+      .stream_token = stream_token,
+      .accept_token = accept_token,
+  };
+  GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
+                timing_info, /*layers_output=*/nullptr);
   if (app.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }


@@ -453,12 +453,10 @@ struct GemmaInterface {
   virtual const GemmaTokenizer* Tokenizer() const = 0;

-  virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
-                        float temperature, const std::vector<int>& prompt,
-                        size_t start_pos, KVCache& kv_cache,
-                        hwy::ThreadPool& pool, const StreamFunc& stream_token,
-                        const AcceptFunc& accept_token, std::mt19937& gen,
-                        int verbosity, TimingInfo& timing_info,
-                        LayersOutputT* layers_output) = 0;
+  virtual void Generate(const RuntimeConfig& runtime_config,
+                        const std::vector<int>& prompt, size_t start_pos,
+                        KVCache& kv_cache, hwy::ThreadPool& pool,
+                        TimingInfo& timing_info,
+                        LayersOutputT* layers_output) = 0;

   virtual float ComputeCrossEntropy(size_t max_tokens,
@@ -547,12 +545,10 @@ struct GemmaImpl : public GemmaInterface {
   const GemmaTokenizer* Tokenizer() const override { return &tokenizer; }

-  void Generate(size_t max_tokens, size_t max_generated_tokens,
-                float temperature, const std::vector<int>& prompt,
-                size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937&, int verbosity, TimingInfo& timing_info,
-                LayersOutputT* layers_output) override;
+  void Generate(const RuntimeConfig& runtime_config,
+                const std::vector<int>& prompt, size_t start_pos,
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                TimingInfo& timing_info, LayersOutputT* layers_output) override;

   float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
                             KVCache& kv_cache, hwy::ThreadPool& pool,
@@ -1083,12 +1079,10 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
 }

 template <class TConfig>
-void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
-                  size_t max_generated_tokens, float temperature,
-                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
-                  hwy::ThreadPool& pool, const StreamFunc& stream_token,
-                  const AcceptFunc& accept_token, std::mt19937& gen,
-                  int verbosity, TimingInfo& timing_info,
-                  LayersOutputT* layers_output) {
+void GenerateImpl(GemmaImpl<TConfig>& gemma,
+                  const RuntimeConfig& runtime_config,
+                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
+                  hwy::ThreadPool& pool, TimingInfo& timing_info,
+                  LayersOutputT* layers_output) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
   Activations<TConfig, 1>& activations = *gemma.state.get();
@@ -1099,6 +1093,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       *reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());

   size_t prompt_size = prompt.size();
+  size_t max_tokens = runtime_config.max_tokens;
+  size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
   if (pos >= max_tokens) {
     fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
@@ -1132,13 +1128,13 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
     Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
                                prefill_activations, kv_cache, pool);
     for (size_t idx = 0; idx < batch_size; ++idx) {
-      if (!stream_token(batch_tokens[idx], 0.0f)) return;
+      if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return;
     }
     pos += batch_size;
     pos_offset += batch_size;
   }
-  if (verbosity >= 2) {
+  if (runtime_config.verbosity >= 2) {
     const double prefill_end = hwy::platform::Now();
     timing_info.prefill_tok_sec =
         static_cast<double>(pos_offset) / (prefill_end - prefill_start);
@@ -1150,7 +1146,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   size_t pos_gen_start = pos_offset;
   int token = prompt.at(pos_offset);
-  stream_token(token, 0);
+  runtime_config.stream_token(token, 0);
   for (size_t generate_pos = 0;
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
@@ -1169,21 +1165,22 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
           activations.even_odd.data(), activations.logits.data(), pool);
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
-      token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
-                                         gen, temperature, accept_token);
-      if (!stream_token(token, activations.logits[token])) {
+      token = SampleTopK<TConfig::kTopK>(
+          activations.logits.data(), kVocabSize, *runtime_config.gen,
+          runtime_config.temperature, runtime_config.accept_token);
+      if (!runtime_config.stream_token(token, activations.logits[token])) {
         token = EOS_ID;
       }
     } else {
       // We would take this branch if we were not doing Prefill but would
       // process the tokens of the prompt one at a time.
       token = prompt.at(pos_offset + 1);
-      if (!stream_token(token, 0)) {
+      if (!runtime_config.stream_token(token, 0)) {
         token = EOS_ID;
       }
     }
     if (token == EOS_ID) {
-      if (verbosity >= 2) {
+      if (runtime_config.verbosity >= 2) {
         const double gen_end = hwy::platform::Now();
         timing_info.gen_tok_sec =
             static_cast<double>(pos_offset - pos_gen_start) /
@@ -1259,41 +1256,31 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
 #undef TOKEN

-void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
-                size_t max_generated_tokens, float temperature,
+void Generate2B(GemmaImpl<ConfigGemma2B>& gemma,
+                const RuntimeConfig& runtime_config,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity, TimingInfo& timing_info,
-                LayersOutputT* layers_output) {
-  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, stream_token, accept_token, gen,
-               verbosity, timing_info, layers_output);
+                TimingInfo& timing_info, LayersOutputT* layers_output) {
+  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
+               timing_info, layers_output);
 }

-void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
-                size_t max_generated_tokens, float temperature,
+void Generate7B(GemmaImpl<ConfigGemma7B>& gemma,
+                const RuntimeConfig& runtime_config,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity, TimingInfo& timing_info,
-                LayersOutputT* layers_output) {
-  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, stream_token, accept_token, gen,
-               verbosity, timing_info, layers_output);
+                TimingInfo& timing_info, LayersOutputT* layers_output) {
+  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
+               timing_info, layers_output);
 }

-void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
-                       size_t max_generated_tokens, float temperature,
+void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
+                       const RuntimeConfig& runtime_config,
                        const std::vector<int>& prompt, size_t start_pos,
                        KVCache& kv_cache, hwy::ThreadPool& pool,
-                       const StreamFunc& stream_token,
-                       const AcceptFunc& accept_token, std::mt19937& gen,
-                       int verbosity, TimingInfo& timing_info,
-                       LayersOutputT* layers_output) {
-  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, stream_token, accept_token, gen,
-               verbosity, timing_info, layers_output);
+                       TimingInfo& timing_info, LayersOutputT* layers_output) {
+  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
+               timing_info, layers_output);
 }

 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
@@ -1553,41 +1540,38 @@ GemmaImpl<Config>::GemmaImpl(
     state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}

 template <>
-void GemmaImpl<ConfigGemma2B>::Generate(
-    size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
-    TimingInfo& timing_info, LayersOutputT* layers_output) {
+void GemmaImpl<ConfigGemma2B>::Generate(const RuntimeConfig& runtime_config,
+                                        const std::vector<int>& prompt,
+                                        size_t start_pos, KVCache& kv_cache,
+                                        hwy::ThreadPool& pool,
+                                        TimingInfo& timing_info,
+                                        LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
-  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
    layers_output);
 }

 template <>
-void GemmaImpl<ConfigGemma7B>::Generate(
-    size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
-    TimingInfo& timing_info, LayersOutputT* layers_output) {
+void GemmaImpl<ConfigGemma7B>::Generate(const RuntimeConfig& runtime_config,
+                                        const std::vector<int>& prompt,
+                                        size_t start_pos, KVCache& kv_cache,
+                                        hwy::ThreadPool& pool,
+                                        TimingInfo& timing_info,
+                                        LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
-  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
    layers_output);
 }

 template <>
-void GemmaImpl<ConfigGriffin2B>::Generate(
-    size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
-    TimingInfo& timing_info, LayersOutputT* layers_output) {
+void GemmaImpl<ConfigGriffin2B>::Generate(const RuntimeConfig& runtime_config,
+                                          const std::vector<int>& prompt,
+                                          size_t start_pos, KVCache& kv_cache,
+                                          hwy::ThreadPool& pool,
+                                          TimingInfo& timing_info,
+                                          LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
-  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
    layers_output);
 }
@@ -1653,30 +1637,15 @@ Gemma::~Gemma() = default;  // after GemmaInterface is defined

 const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }

-void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
-                   float temperature, const std::vector<int>& prompt,
-                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token,
-                   const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity, TimingInfo& timing_info,
-                   LayersOutputT* layers_output) {
-  pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, kv_cache, pool, stream_token, accept_token,
-                        gen, verbosity, timing_info, layers_output);
-  pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
-}
-
-void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
+void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token, std::mt19937& gen,
-                   TimingInfo& timing_info) {
-  GenerateGemma(
-      gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
-      runtime_config.temperature, prompt, start_pos, kv_cache, pool,
-      stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
-      timing_info, /*layers_output=*/nullptr);
+                   TimingInfo& timing_info,
+                   LayersOutputT* layers_output) {
+  pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
+  gemma.impl_->Generate(runtime_config, prompt, start_pos, kv_cache, pool,
+                        timing_info, layers_output);
+  pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }

 void CompressWeights(gcpp::Model model, const Path& weights,


@@ -60,11 +60,19 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 const char* ParseModelTypeAndTraining(const std::string& model_flag,
                                       Model& model, ModelTraining& training);

+// StreamFunc is called with (token, probability). For prompt tokens,
+// probability is 0.0f.
+using StreamFunc = std::function<bool(int, float)>;
+using AcceptFunc = std::function<bool(int)>;
+
 struct RuntimeConfig {
   size_t max_tokens;
   size_t max_generated_tokens;
   float temperature;
   int verbosity;
+  std::mt19937* gen;
+  const StreamFunc& stream_token;
+  const AcceptFunc& accept_token;
 };

 struct GemmaInterface;
@@ -97,29 +105,14 @@ KVCache CreateKVCache(Model type);  // convenient workaround for now
 KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
                       size_t conv1d_cache_size, size_t rglru_cache_size);

-// StreamFunc is called with (token, probability). For prompt tokens,
-// probability is 0.0f.
-using StreamFunc = std::function<bool(int, float)>;
-using AcceptFunc = std::function<bool(int)>;
-
+// Bundle runtime parameters as RuntimeConfig
 // layers_output is optional; if set - it will be called with the activations
 // output after applying each layer.
-void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
-                   float temperature, const std::vector<int>& prompt,
-                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token,
-                   const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity, TimingInfo& timing_info,
-                   LayersOutputT* layers_output = nullptr);
-
-// Convenience function for the common case:
-// - Bundle runtime parameters as RuntimeConfig
-// - All tokens accepted
-void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
+void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token, std::mt19937& gen,
-                   int verbosity, TimingInfo& timing_info);
+                   TimingInfo& timing_info,
+                   LayersOutputT* layers_output = nullptr);

 void CompressWeights(gcpp::Model model, const Path& weights,
                      const Path& compressed_weights, hwy::ThreadPool& pool);
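
A note on the callback types declared above: StreamFunc receives (token, probability) and is invoked for every token, with probability 0.0f for prompt tokens; returning false stops generation, as the GenerateImpl changes above show. A sketch of a collector callback in the style of the test below:

    std::vector<int> response;
    gcpp::StreamFunc stream_token = [&response](int token, float /*prob*/) {
      response.push_back(token);  // prob is 0.0f while the prompt streams
      return true;                // returning false would end generation early
    };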


@@ -56,11 +56,18 @@ class GemmaTest : public ::testing::Test {
       response.push_back(token);
       return true;
     };
-    gcpp::GenerateGemma(
-        model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
-        /*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
-        stream_token, /*accept=*/[](int) { return true; }, gen,
-        /*verbosity=*/0);
+    gcpp::RuntimeConfig runtime_config = {
+        .max_tokens = 3072,
+        .max_generated_tokens = 2048,
+        .temperature = 1.0,
+        .verbosity = 0,
+        .gen = &gen,
+        .stream_token = stream_token,
+        .accept_token = [](int) { return true; },
+    };
+    gcpp::TimingInfo timing_info;
+    gcpp::GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0,
+                        kv_cache, pool, timing_info, /*layers_output=*/nullptr);
     std::string response_text;
     HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
     return response_text;


@@ -208,9 +208,17 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
     }

     TimingInfo timing_info;
-    GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                  args.temperature, prompt, abs_pos, kv_cache, pool,
-                  stream_token, accept_token, gen, verbosity, timing_info);
+    gcpp::RuntimeConfig runtime_config = {
+        .max_tokens = args.max_tokens,
+        .max_generated_tokens = args.max_generated_tokens,
+        .temperature = args.temperature,
+        .verbosity = verbosity,
+        .gen = &gen,
+        .stream_token = stream_token,
+        .accept_token = accept_token,
+    };
+    GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
+                  timing_info);
     if (verbosity >= 2) {
       std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
                 << "\n"