Pass most runtime parameters using const RuntimeConfig&

PiperOrigin-RevId: 633572507
Authored by Apoorv Reddy on 2024-05-14 07:04:14 -07:00; committed by Copybara-Service
parent f1eab987d8
commit eb0b96e0a8
6 changed files with 123 additions and 130 deletions
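
At call sites, the change replaces a long positional argument list with a
single aggregate passed by const reference. A minimal before/after sketch of
the calling convention (assuming model, prompt, kv_cache, pool, gen, and the
two callbacks are set up as in the files below):

  gcpp::TimingInfo timing_info;

  // Before: every runtime knob is a separate positional parameter.
  GenerateGemma(model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
                /*temperature=*/1.0f, prompt, /*start_pos=*/0, kv_cache, pool,
                stream_token, accept_token, gen, /*verbosity=*/0, timing_info,
                /*layers_output=*/nullptr);

  // After: bundle the knobs once, then pass the bundle.
  gcpp::RuntimeConfig runtime_config = {
      .max_tokens = 3072,
      .max_generated_tokens = 2048,
      .temperature = 1.0f,
      .verbosity = 0,
      .gen = &gen,                   // non-owning pointer to the caller's RNG
      .stream_token = stream_token,  // stored as const&; the callbacks must
      .accept_token = accept_token,  // outlive the config and the call
  };
  GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
                timing_info, /*layers_output=*/nullptr);

Because stream_token and accept_token are reference members and gen is a
pointer, RuntimeConfig is a non-owning parameter bundle: cheap to create per
call, but it must not outlive the callables and RNG it was initialized from.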


@@ -61,10 +61,17 @@ std::pair<std::string, int> QueryModel(
               << args.temperature;
   }
   gcpp::TimingInfo timing_info;
-  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                stream_token, accept_token, gen, app.verbosity, timing_info,
-                layers_output);
+  gcpp::RuntimeConfig runtime_config = {
+      .max_tokens = args.max_tokens,
+      .max_generated_tokens = args.max_generated_tokens,
+      .temperature = args.temperature,
+      .verbosity = app.verbosity,
+      .gen = &gen,
+      .stream_token = stream_token,
+      .accept_token = accept_token,
+  };
+  GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
+                timing_info, layers_output);
   return {res, total_tokens};
 }


@@ -2,6 +2,7 @@
 #include <fstream>
 #include <iostream>
 #include <ostream>
+#include <random>
 #include <sstream>
 #include <string>
 #include <utility> // std::pair
@@ -88,9 +89,17 @@ std::pair<std::string, int> QueryModel(
               << args.temperature;
   }
   gcpp::TimingInfo timing_info;
-  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                stream_token, accept_token, gen, app.verbosity, timing_info);
+  gcpp::RuntimeConfig runtime_config = {
+      .max_tokens = args.max_tokens,
+      .max_generated_tokens = args.max_generated_tokens,
+      .temperature = args.temperature,
+      .verbosity = app.verbosity,
+      .gen = &gen,
+      .stream_token = stream_token,
+      .accept_token = accept_token,
+  };
+  GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
+                timing_info, /*layers_output=*/nullptr);
   if (app.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }


@@ -453,12 +453,10 @@ struct GemmaInterface {
   virtual const GemmaTokenizer* Tokenizer() const = 0;
-  virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
-                        float temperature, const std::vector<int>& prompt,
-                        size_t start_pos, KVCache& kv_cache,
-                        hwy::ThreadPool& pool, const StreamFunc& stream_token,
-                        const AcceptFunc& accept_token, std::mt19937& gen,
-                        int verbosity, TimingInfo& timing_info,
+  virtual void Generate(const RuntimeConfig& runtime_config,
+                        const std::vector<int>& prompt, size_t start_pos,
+                        KVCache& kv_cache, hwy::ThreadPool& pool,
+                        TimingInfo& timing_info,
                         LayersOutputT* layers_output) = 0;
   virtual float ComputeCrossEntropy(size_t max_tokens,
@@ -547,12 +545,10 @@ struct GemmaImpl : public GemmaInterface {
   const GemmaTokenizer* Tokenizer() const override { return &tokenizer; }
-  void Generate(size_t max_tokens, size_t max_generated_tokens,
-                float temperature, const std::vector<int>& prompt,
-                size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937&, int verbosity, TimingInfo& timing_info,
-                LayersOutputT* layers_output) override;
+  void Generate(const RuntimeConfig& runtime_config,
+                const std::vector<int>& prompt, size_t start_pos,
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                TimingInfo& timing_info, LayersOutputT* layers_output) override;
   float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
                             KVCache& kv_cache, hwy::ThreadPool& pool,
@@ -1083,12 +1079,10 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
 }
 template <class TConfig>
-void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
-                  size_t max_generated_tokens, float temperature,
+void GenerateImpl(GemmaImpl<TConfig>& gemma,
+                  const RuntimeConfig& runtime_config,
                   const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
-                  hwy::ThreadPool& pool, const StreamFunc& stream_token,
-                  const AcceptFunc& accept_token, std::mt19937& gen,
-                  int verbosity, TimingInfo& timing_info,
+                  hwy::ThreadPool& pool, TimingInfo& timing_info,
                   LayersOutputT* layers_output) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
   Activations<TConfig, 1>& activations = *gemma.state.get();
@@ -1099,6 +1093,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       *reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
   size_t prompt_size = prompt.size();
+  size_t max_tokens = runtime_config.max_tokens;
+  size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
   if (pos >= max_tokens) {
     fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
@@ -1132,13 +1128,13 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
                                  prefill_activations, kv_cache, pool);
       for (size_t idx = 0; idx < batch_size; ++idx) {
-        if (!stream_token(batch_tokens[idx], 0.0f)) return;
+        if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return;
       }
       pos += batch_size;
       pos_offset += batch_size;
     }
-  if (verbosity >= 2) {
+  if (runtime_config.verbosity >= 2) {
     const double prefill_end = hwy::platform::Now();
     timing_info.prefill_tok_sec =
         static_cast<double>(pos_offset) / (prefill_end - prefill_start);
@@ -1150,7 +1146,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   size_t pos_gen_start = pos_offset;
   int token = prompt.at(pos_offset);
-  stream_token(token, 0);
+  runtime_config.stream_token(token, 0);
   for (size_t generate_pos = 0;
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
@@ -1169,21 +1165,22 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
           activations.even_odd.data(), activations.logits.data(), pool);
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
-      token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
-                                         gen, temperature, accept_token);
-      if (!stream_token(token, activations.logits[token])) {
+      token = SampleTopK<TConfig::kTopK>(
+          activations.logits.data(), kVocabSize, *runtime_config.gen,
+          runtime_config.temperature, runtime_config.accept_token);
+      if (!runtime_config.stream_token(token, activations.logits[token])) {
         token = EOS_ID;
       }
     } else {
       // We would take this branch if we were not doing Prefill but would
       // process the tokens of the prompt one at a time.
       token = prompt.at(pos_offset + 1);
-      if (!stream_token(token, 0)) {
+      if (!runtime_config.stream_token(token, 0)) {
         token = EOS_ID;
       }
     }
     if (token == EOS_ID) {
-      if (verbosity >= 2) {
+      if (runtime_config.verbosity >= 2) {
         const double gen_end = hwy::platform::Now();
         timing_info.gen_tok_sec =
             static_cast<double>(pos_offset - pos_gen_start) /
@@ -1259,41 +1256,31 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
 #undef TOKEN
-void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
-                size_t max_generated_tokens, float temperature,
+void Generate2B(GemmaImpl<ConfigGemma2B>& gemma,
+                const RuntimeConfig& runtime_config,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity, TimingInfo& timing_info,
-                LayersOutputT* layers_output) {
-  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, stream_token, accept_token, gen,
-               verbosity, timing_info, layers_output);
+                TimingInfo& timing_info, LayersOutputT* layers_output) {
+  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
+               timing_info, layers_output);
 }
-void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
-                size_t max_generated_tokens, float temperature,
+void Generate7B(GemmaImpl<ConfigGemma7B>& gemma,
+                const RuntimeConfig& runtime_config,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity, TimingInfo& timing_info,
-                LayersOutputT* layers_output) {
-  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, stream_token, accept_token, gen,
-               verbosity, timing_info, layers_output);
+                TimingInfo& timing_info, LayersOutputT* layers_output) {
+  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
+               timing_info, layers_output);
 }
-void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
-                       size_t max_generated_tokens, float temperature,
+void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
+                       const RuntimeConfig& runtime_config,
                        const std::vector<int>& prompt, size_t start_pos,
                        KVCache& kv_cache, hwy::ThreadPool& pool,
-                       const StreamFunc& stream_token,
-                       const AcceptFunc& accept_token, std::mt19937& gen,
-                       int verbosity, TimingInfo& timing_info,
-                       LayersOutputT* layers_output) {
-  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, stream_token, accept_token, gen,
-               verbosity, timing_info, layers_output);
+                       TimingInfo& timing_info, LayersOutputT* layers_output) {
+  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
+               timing_info, layers_output);
 }
 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
@@ -1553,41 +1540,38 @@ GemmaImpl<Config>::GemmaImpl(
       state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
 template <>
-void GemmaImpl<ConfigGemma2B>::Generate(
-    size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
-    TimingInfo& timing_info, LayersOutputT* layers_output) {
+void GemmaImpl<ConfigGemma2B>::Generate(const RuntimeConfig& runtime_config,
+                                        const std::vector<int>& prompt,
+                                        size_t start_pos, KVCache& kv_cache,
+                                        hwy::ThreadPool& pool,
+                                        TimingInfo& timing_info,
+                                        LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
-  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
    layers_output);
 }
 template <>
-void GemmaImpl<ConfigGemma7B>::Generate(
-    size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
-    TimingInfo& timing_info, LayersOutputT* layers_output) {
+void GemmaImpl<ConfigGemma7B>::Generate(const RuntimeConfig& runtime_config,
+                                        const std::vector<int>& prompt,
+                                        size_t start_pos, KVCache& kv_cache,
+                                        hwy::ThreadPool& pool,
+                                        TimingInfo& timing_info,
+                                        LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
-  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
    layers_output);
 }
 template <>
-void GemmaImpl<ConfigGriffin2B>::Generate(
-    size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
-    TimingInfo& timing_info, LayersOutputT* layers_output) {
+void GemmaImpl<ConfigGriffin2B>::Generate(const RuntimeConfig& runtime_config,
+                                          const std::vector<int>& prompt,
+                                          size_t start_pos, KVCache& kv_cache,
+                                          hwy::ThreadPool& pool,
+                                          TimingInfo& timing_info,
+                                          LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
-  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
    layers_output);
 }
@@ -1653,30 +1637,15 @@ Gemma::~Gemma() = default; // after GemmaInterface is defined
 const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
-void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
-                   float temperature, const std::vector<int>& prompt,
-                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token,
-                   const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity, TimingInfo& timing_info,
-                   LayersOutputT* layers_output) {
-  pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, kv_cache, pool, stream_token, accept_token,
-                        gen, verbosity, timing_info, layers_output);
-  pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
-}
-void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
+void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token, std::mt19937& gen,
-                   TimingInfo& timing_info) {
-  GenerateGemma(
-      gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
-      runtime_config.temperature, prompt, start_pos, kv_cache, pool,
-      stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
-      timing_info, /*layers_output=*/nullptr);
+                   TimingInfo& timing_info,
+                   LayersOutputT* layers_output) {
+  pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
+  gemma.impl_->Generate(runtime_config, prompt, start_pos, kv_cache, pool,
+                        timing_info, layers_output);
+  pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 void CompressWeights(gcpp::Model model, const Path& weights,


@@ -60,11 +60,19 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 const char* ParseModelTypeAndTraining(const std::string& model_flag,
                                       Model& model, ModelTraining& training);
+// StreamFunc is called with (token, probability). For prompt tokens,
+// probability is 0.0f.
+using StreamFunc = std::function<bool(int, float)>;
+using AcceptFunc = std::function<bool(int)>;
+
+struct RuntimeConfig {
+  size_t max_tokens;
+  size_t max_generated_tokens;
+  float temperature;
+  int verbosity;
+  std::mt19937* gen;
+  const StreamFunc& stream_token;
+  const AcceptFunc& accept_token;
+};
 struct GemmaInterface;
@@ -97,29 +105,14 @@ KVCache CreateKVCache(Model type); // convenient workaround for now
 KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
                       size_t conv1d_cache_size, size_t rglru_cache_size);
-// StreamFunc is called with (token, probability). For prompt tokens,
-// probability is 0.0f.
-using StreamFunc = std::function<bool(int, float)>;
-using AcceptFunc = std::function<bool(int)>;
+// Bundle runtime parameters as RuntimeConfig
 // layers_output is optional; if set - it will be called with the activations
 // output after applying each layer.
-void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
-                   float temperature, const std::vector<int>& prompt,
-                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token,
-                   const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity, TimingInfo& timing_info,
-                   LayersOutputT* layers_output = nullptr);
-// Convenience function for the common case:
-// - Bundle runtime parameters as RuntimeConfig
-// - All tokens accepted
-void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
+void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token, std::mt19937& gen,
-                   int verbosity, TimingInfo& timing_info);
+                   TimingInfo& timing_info,
+                   LayersOutputT* layers_output = nullptr);
 void CompressWeights(gcpp::Model model, const Path& weights,
                      const Path& compressed_weights, hwy::ThreadPool& pool);


@@ -56,11 +56,18 @@ class GemmaTest : public ::testing::Test {
       response.push_back(token);
       return true;
     };
-    gcpp::GenerateGemma(
-        model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
-        /*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
-        stream_token, /*accept=*/[](int) { return true; }, gen,
-        /*verbosity=*/0);
+    gcpp::RuntimeConfig runtime_config = {
+        .max_tokens = 3072,
+        .max_generated_tokens = 2048,
+        .temperature = 1.0,
+        .verbosity = 0,
+        .gen = &gen,
+        .stream_token = stream_token,
+        .accept_token = [](int) { return true; },
+    };
+    gcpp::TimingInfo timing_info;
+    gcpp::GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0,
+                        kv_cache, pool, timing_info, /*layers_output=*/nullptr);
     std::string response_text;
     HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
     return response_text;


@@ -208,9 +208,17 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
     }
     TimingInfo timing_info;
-    GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                  args.temperature, prompt, abs_pos, kv_cache, pool,
-                  stream_token, accept_token, gen, verbosity, timing_info);
+    gcpp::RuntimeConfig runtime_config = {
+        .max_tokens = args.max_tokens,
+        .max_generated_tokens = args.max_generated_tokens,
+        .temperature = args.temperature,
+        .verbosity = verbosity,
+        .gen = &gen,
+        .stream_token = stream_token,
+        .accept_token = accept_token,
+    };
+    GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
+                  timing_info);
     if (verbosity >= 2) {
       std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
                 << "\n"