diff --git a/debug_prompt.cc b/debug_prompt.cc
index 62a427e..7b6be33 100644
--- a/debug_prompt.cc
+++ b/debug_prompt.cc
@@ -61,10 +61,17 @@ std::pair<std::string, int> QueryModel(
         << args.temperature;
   }
   gcpp::TimingInfo timing_info;
-  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                stream_token, accept_token, gen, app.verbosity, timing_info,
-                layers_output);
+  gcpp::RuntimeConfig runtime_config = {
+      .max_tokens = args.max_tokens,
+      .max_generated_tokens = args.max_generated_tokens,
+      .temperature = args.temperature,
+      .verbosity = app.verbosity,
+      .gen = &gen,
+      .stream_token = stream_token,
+      .accept_token = accept_token,
+  };
+  GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
+                timing_info, layers_output);
 
   return {res, total_tokens};
 }
diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc
index 60bbcdb..4d0d04f 100644
--- a/gemma/benchmark.cc
+++ b/gemma/benchmark.cc
@@ -2,6 +2,7 @@
 #include <fstream>
 #include <iostream>
 #include <memory>
+#include <random>
 #include <string>
 #include <vector>
 #include <utility>  // std::pair
@@ -88,9 +89,17 @@ std::pair<std::string, int> QueryModel(
         << args.temperature;
   }
   gcpp::TimingInfo timing_info;
-  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                stream_token, accept_token, gen, app.verbosity, timing_info);
+  gcpp::RuntimeConfig runtime_config = {
+      .max_tokens = args.max_tokens,
+      .max_generated_tokens = args.max_generated_tokens,
+      .temperature = args.temperature,
+      .verbosity = app.verbosity,
+      .gen = &gen,
+      .stream_token = stream_token,
+      .accept_token = accept_token,
+  };
+  GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
+                timing_info, /*layers_output=*/nullptr);
   if (app.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
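Both converted call sites above follow the same pattern: the former positional arguments move into a gcpp::RuntimeConfig built with designated initializers, the RNG is now passed by pointer (.gen = &gen), and the two callbacks bind to const-reference members. A minimal standalone caller might look like the sketch below; GenerateOnce and its parameter values are hypothetical, while the types, field names, and GenerateGemma signature come from the new gemma/gemma.h in this patch.

#include <random>
#include <vector>

#include "gemma/gemma.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical helper: drives one generation with the post-patch API.
void GenerateOnce(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
                  hwy::ThreadPool& pool, const std::vector<int>& prompt) {
  std::mt19937 gen(/*seed=*/42);
  // Named callbacks outlive runtime_config, which stores them by reference.
  gcpp::StreamFunc stream_token = [](int /*token*/, float /*prob*/) {
    return true;  // returning false stops generation
  };
  gcpp::AcceptFunc accept_token = [](int /*token*/) { return true; };
  gcpp::RuntimeConfig runtime_config = {
      .max_tokens = 3072,
      .max_generated_tokens = 2048,
      .temperature = 1.0f,
      .verbosity = 0,
      .gen = &gen,
      .stream_token = stream_token,
      .accept_token = accept_token,
  };
  gcpp::TimingInfo timing_info;
  gcpp::GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache,
                      pool, timing_info, /*layers_output=*/nullptr);
}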
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 51fbad8..916241f 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -453,12 +453,10 @@ struct GemmaInterface {
 
   virtual const GemmaTokenizer* Tokenizer() const = 0;
 
-  virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
-                        float temperature, const std::vector<int>& prompt,
-                        size_t start_pos, KVCache& kv_cache,
-                        hwy::ThreadPool& pool, const StreamFunc& stream_token,
-                        const AcceptFunc& accept_token, std::mt19937& gen,
-                        int verbosity, TimingInfo& timing_info,
+  virtual void Generate(const RuntimeConfig& runtime_config,
+                        const std::vector<int>& prompt, size_t start_pos,
+                        KVCache& kv_cache, hwy::ThreadPool& pool,
+                        TimingInfo& timing_info,
                         LayersOutputT* layers_output) = 0;
 
   virtual float ComputeCrossEntropy(size_t max_tokens,
@@ -547,12 +545,10 @@ struct GemmaImpl : public GemmaInterface {
 
   const GemmaTokenizer* Tokenizer() const override { return &tokenizer; }
 
-  void Generate(size_t max_tokens, size_t max_generated_tokens,
-                float temperature, const std::vector<int>& prompt,
-                size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937&, int verbosity, TimingInfo& timing_info,
-                LayersOutputT* layers_output) override;
+  void Generate(const RuntimeConfig& runtime_config,
+                const std::vector<int>& prompt, size_t start_pos,
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                TimingInfo& timing_info, LayersOutputT* layers_output) override;
 
   float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
                             KVCache& kv_cache, hwy::ThreadPool& pool,
@@ -1083,12 +1079,10 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
 }
 
 template <class TConfig>
-void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
-                  size_t max_generated_tokens, float temperature,
+void GenerateImpl(GemmaImpl<TConfig>& gemma,
+                  const RuntimeConfig& runtime_config,
                   const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
-                  hwy::ThreadPool& pool, const StreamFunc& stream_token,
-                  const AcceptFunc& accept_token, std::mt19937& gen,
-                  int verbosity, TimingInfo& timing_info,
+                  hwy::ThreadPool& pool, TimingInfo& timing_info,
                   LayersOutputT* layers_output) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
   Activations<TConfig>& activations = *gemma.state.get();
@@ -1099,6 +1093,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       *reinterpret_cast<CompressedWeights<TConfig>*>(gemma.weights_u8.get());
 
   size_t prompt_size = prompt.size();
+  size_t max_tokens = runtime_config.max_tokens;
+  size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(max_tokens, max_generated_tokens, prompt_size);
   if (pos >= max_tokens) {
     fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
@@ -1132,13 +1128,13 @@
       Prefill<TConfig>(batch_tokens, batch_size, pos, weights,
                        prefill_activations, kv_cache, pool);
       for (size_t idx = 0; idx < batch_size; ++idx) {
-        if (!stream_token(batch_tokens[idx], 0.0f)) return;
+        if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return;
       }
       pos += batch_size;
       pos_offset += batch_size;
     }
 
-  if (verbosity >= 2) {
+  if (runtime_config.verbosity >= 2) {
     const double prefill_end = hwy::platform::Now();
     timing_info.prefill_tok_sec =
         static_cast<double>(pos_offset) / (prefill_end - prefill_start);
@@ -1150,7 +1146,7 @@
 
   size_t pos_gen_start = pos_offset;
   int token = prompt.at(pos_offset);
-  stream_token(token, 0);
+  runtime_config.stream_token(token, 0);
   for (size_t generate_pos = 0;
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
@@ -1169,21 +1165,22 @@
                activations.even_odd.data(), activations.logits.data(), pool);
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
-      token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
-                                         gen, temperature, accept_token);
-      if (!stream_token(token, activations.logits[token])) {
+      token = SampleTopK<TConfig::kTopK>(
+          activations.logits.data(), kVocabSize, *runtime_config.gen,
+          runtime_config.temperature, runtime_config.accept_token);
+      if (!runtime_config.stream_token(token, activations.logits[token])) {
         token = EOS_ID;
       }
     } else {
       // We would take this branch if we were not doing Prefill but would
       // process the tokens of the prompt one at a time.
      token = prompt.at(pos_offset + 1);
-      if (!stream_token(token, 0)) {
+      if (!runtime_config.stream_token(token, 0)) {
        token = EOS_ID;
       }
     }
     if (token == EOS_ID) {
-      if (verbosity >= 2) {
+      if (runtime_config.verbosity >= 2) {
        const double gen_end = hwy::platform::Now();
        timing_info.gen_tok_sec =
            static_cast<double>(pos_offset - pos_gen_start) /
@@ -1259,41 +1256,31 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
 
 #undef TOKEN
 
-void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
-                size_t max_generated_tokens, float temperature,
+void Generate2B(GemmaImpl<ConfigGemma2B>& gemma,
+                const RuntimeConfig& runtime_config,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity, TimingInfo& timing_info,
-                LayersOutputT* layers_output) {
-  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, stream_token, accept_token, gen,
-               verbosity, timing_info, layers_output);
+                TimingInfo& timing_info, LayersOutputT* layers_output) {
+  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
+               timing_info, layers_output);
 }
 
-void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
-                size_t max_generated_tokens, float temperature,
+void Generate7B(GemmaImpl<ConfigGemma7B>& gemma,
+                const RuntimeConfig& runtime_config,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity, TimingInfo& timing_info,
-                LayersOutputT* layers_output) {
-  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, stream_token, accept_token, gen,
-               verbosity, timing_info, layers_output);
+                TimingInfo& timing_info, LayersOutputT* layers_output) {
+  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
+               timing_info, layers_output);
 }
 
-void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
-                       size_t max_generated_tokens, float temperature,
+void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
+                       const RuntimeConfig& runtime_config,
                        const std::vector<int>& prompt, size_t start_pos,
                        KVCache& kv_cache, hwy::ThreadPool& pool,
-                       const StreamFunc& stream_token,
-                       const AcceptFunc& accept_token, std::mt19937& gen,
-                       int verbosity, TimingInfo& timing_info,
-                       LayersOutputT* layers_output) {
-  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, stream_token, accept_token, gen,
-               verbosity, timing_info, layers_output);
+                       TimingInfo& timing_info, LayersOutputT* layers_output) {
+  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
+               timing_info, layers_output);
 }
 
 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
@@ -1553,41 +1540,38 @@ GemmaImpl<TConfig>::GemmaImpl(
     state(hwy::MakeUniqueAligned<Activations<TConfig>>()) {}
 
 template <>
-void GemmaImpl<ConfigGemma2B>::Generate(
-    size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
-    TimingInfo& timing_info, LayersOutputT* layers_output) {
+void GemmaImpl<ConfigGemma2B>::Generate(const RuntimeConfig& runtime_config,
+                                        const std::vector<int>& prompt,
+                                        size_t start_pos, KVCache& kv_cache,
+                                        hwy::ThreadPool& pool,
+                                        TimingInfo& timing_info,
+                                        LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
-  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
    layers_output);
 }
 
 template <>
-void GemmaImpl<ConfigGemma7B>::Generate(
-    size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
-    TimingInfo& timing_info, LayersOutputT* layers_output) {
+void GemmaImpl<ConfigGemma7B>::Generate(const RuntimeConfig& runtime_config,
+                                        const std::vector<int>& prompt,
+                                        size_t start_pos, KVCache& kv_cache,
+                                        hwy::ThreadPool& pool,
+                                        TimingInfo& timing_info,
+                                        LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
-  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
    layers_output);
 }
 
 template <>
-void GemmaImpl<ConfigGriffin2B>::Generate(
-    size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
-    TimingInfo& timing_info, LayersOutputT* layers_output) {
+void GemmaImpl<ConfigGriffin2B>::Generate(const RuntimeConfig& runtime_config,
+                                          const std::vector<int>& prompt,
+                                          size_t start_pos, KVCache& kv_cache,
+                                          hwy::ThreadPool& pool,
+                                          TimingInfo& timing_info,
+                                          LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
-  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
    layers_output);
 }
 
@@ -1653,30 +1637,15 @@ Gemma::~Gemma() = default;  // after GemmaInterface is defined
 
 const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
 
-void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
-                   float temperature, const std::vector<int>& prompt,
-                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token,
-                   const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity, TimingInfo& timing_info,
-                   LayersOutputT* layers_output) {
-  pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, kv_cache, pool, stream_token, accept_token,
-                        gen, verbosity, timing_info, layers_output);
-  pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
-}
-
-void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
+void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token, std::mt19937& gen,
-                   TimingInfo& timing_info) {
-  GenerateGemma(
-      gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
-      runtime_config.temperature, prompt, start_pos, kv_cache, pool,
-      stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
-      timing_info, /*layers_output=*/nullptr);
+                   TimingInfo& timing_info,
+                   LayersOutputT* layers_output) {
+  pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
+  gemma.impl_->Generate(runtime_config, prompt, start_pos, kv_cache, pool,
+                        timing_info, layers_output);
+  pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
 void CompressWeights(gcpp::Model model, const Path& weights,
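The two GenerateGemma overloads collapse into one, and the kSpin/kBlock bracketing from the deleted overload is preserved around impl_->Generate. One consequence worth noting: an early return or exception from Generate would leave the pool in spin mode. A scope guard is one conventional way to harden this; the class below is an illustration only, not part of the patch.

#include "hwy/contrib/thread_pool/thread_pool.h"

// Illustrative RAII guard (not in this patch): enter spin-wait mode now,
// restore blocking waits on scope exit, even on early return or exception.
class SpinWaitScope {
 public:
  explicit SpinWaitScope(hwy::ThreadPool& pool) : pool_(pool) {
    pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
  }
  ~SpinWaitScope() { pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); }
  SpinWaitScope(const SpinWaitScope&) = delete;
  SpinWaitScope& operator=(const SpinWaitScope&) = delete;

 private:
  hwy::ThreadPool& pool_;
};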
diff --git a/gemma/gemma.h b/gemma/gemma.h
index bd43046..c2eb929 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -60,11 +60,19 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 const char* ParseModelTypeAndTraining(const std::string& model_flag,
                                       Model& model, ModelTraining& training);
 
+// StreamFunc is called with (token, probability). For prompt tokens,
+// probability is 0.0f.
+using StreamFunc = std::function<bool(int, float)>;
+using AcceptFunc = std::function<bool(int)>;
+
 struct RuntimeConfig {
   size_t max_tokens;
   size_t max_generated_tokens;
   float temperature;
   int verbosity;
+  std::mt19937* gen;
+  const StreamFunc& stream_token;
+  const AcceptFunc& accept_token;
 };
 
 struct GemmaInterface;
@@ -97,29 +105,14 @@ KVCache CreateKVCache(Model type);  // convenient workaround for now
 KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
                       size_t conv1d_cache_size, size_t rglru_cache_size);
 
-// StreamFunc is called with (token, probability). For prompt tokens,
-// probability is 0.0f.
-using StreamFunc = std::function<bool(int, float)>;
-using AcceptFunc = std::function<bool(int)>;
-
+// Bundle runtime parameters as RuntimeConfig.
 // layers_output is optional; if set - it will be called with the activations
 // output after applying each layer.
-void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
-                   float temperature, const std::vector<int>& prompt,
-                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token,
-                   const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity, TimingInfo& timing_info,
-                   LayersOutputT* layers_output = nullptr);
-
-// Convenience function for the common case:
-// - Bundle runtime parameters as RuntimeConfig
-// - All tokens accepted
-void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
+void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                   const StreamFunc& stream_token, std::mt19937& gen,
-                   int verbosity, TimingInfo& timing_info);
+                   TimingInfo& timing_info,
+                   LayersOutputT* layers_output = nullptr);
 
 void CompressWeights(gcpp::Model model, const Path& weights,
                      const Path& compressed_weights, hwy::ThreadPool& pool);
diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc
index 258352b..9ede883 100644
--- a/gemma/gemma_test.cc
+++ b/gemma/gemma_test.cc
@@ -56,11 +56,18 @@ class GemmaTest : public ::testing::Test {
       response.push_back(token);
       return true;
     };
-    gcpp::GenerateGemma(
-        model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
-        /*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
-        stream_token, /*accept=*/[](int) { return true; }, gen,
-        /*verbosity=*/0);
+    gcpp::RuntimeConfig runtime_config = {
+        .max_tokens = 3072,
+        .max_generated_tokens = 2048,
+        .temperature = 1.0,
+        .verbosity = 0,
+        .gen = &gen,
+        .stream_token = stream_token,
+        .accept_token = [](int) { return true; },
+    };
+    gcpp::TimingInfo timing_info;
+    gcpp::GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0,
+                        kv_cache, pool, timing_info, /*layers_output=*/nullptr);
     std::string response_text;
     HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
     return response_text;
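The test initializes .accept_token directly from a lambda, which materializes a temporary AcceptFunc and binds it to RuntimeConfig's const-reference member. That is safe here: in aggregate initialization, a temporary bound to a reference member lives as long as the aggregate itself. The extension does not survive copying or returning the aggregate, though, so a factory-style helper would dangle. A small illustration (generic C++, not from the patch):

#include <functional>

using AcceptFunc = std::function<bool(int)>;

struct Holder {
  const AcceptFunc& accept_token;  // mirrors RuntimeConfig's reference member
};

int main() {
  // OK: the temporary AcceptFunc is lifetime-extended to match `h`.
  Holder h = {.accept_token = [](int) { return true; }};
  // Dangling (don't do this): the temporary dies when the factory returns,
  // so the copied-out Holder would hold a dead reference:
  //   Holder Make() { return Holder{.accept_token = [](int) { return true; }}; }
  return h.accept_token(0) ? 0 : 1;
}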
diff --git a/gemma/run.cc b/gemma/run.cc
index f8a8cc8..8377bca 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -208,9 +208,17 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
   }
 
   TimingInfo timing_info;
-  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                args.temperature, prompt, abs_pos, kv_cache, pool,
-                stream_token, accept_token, gen, verbosity, timing_info);
+  gcpp::RuntimeConfig runtime_config = {
+      .max_tokens = args.max_tokens,
+      .max_generated_tokens = args.max_generated_tokens,
+      .temperature = args.temperature,
+      .verbosity = verbosity,
+      .gen = &gen,
+      .stream_token = stream_token,
+      .accept_token = accept_token,
+  };
+  GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
+                timing_info);
   if (verbosity >= 2) {
     std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
               << "\n"