diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 1a4beed..67953dc 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -66,7 +66,7 @@ int main(int argc, char** argv) {
   inference.multiturn = false;
 
   GenerateGemma(
-      model, inference, tokens, 0, pool, inner_pool, stream_token,
+      model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024, /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
       [](int) {return true;}, gen, 0);
 
   std::cout << std::endl;
diff --git a/gemma.cc b/gemma.cc
index dcd87ed..d49d0df 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -233,9 +233,10 @@ struct GemmaInterface {
 
   virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
 
-  virtual void Generate(const InferenceArgs& args,
-                        const std::vector<int>& prompt, size_t start_pos,
-                        hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+  virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
+                        float temperature, const std::vector<int>& prompt,
+                        size_t start_pos, hwy::ThreadPool& pool,
+                        hwy::ThreadPool& inner_pool,
                         const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity) = 0;
@@ -258,7 +259,8 @@ struct GemmaImpl : public GemmaInterface {
     return tokenizer.get();
   }
 
-  void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
+  void Generate(size_t max_tokens, size_t max_generated_tokens,
+                float temperature, const std::vector<int>& prompt,
                 size_t start_pos, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                 const AcceptFunc& accept_token, std::mt19937&, int verbosity);
@@ -295,7 +297,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   static constexpr size_t kModelDim = gcpp::Activations::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
 
-  static const float kQueryScale = static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
+  static const float kQueryScale =
+      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
   pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     // linear projections to QKV
@@ -418,7 +421,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           hwy::ThreadPool& inner_pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
   static constexpr size_t kModelDim = TConfig::kModelDim;
-  static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
+  static const float kEmbScaling =
+      static_cast<float>(sqrt(static_cast<double>(kModelDim)));
 
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
@@ -473,7 +477,8 @@ void Transformer(int token, size_t pos,
   static constexpr size_t kLayers = TConfig::kLayers;
   static constexpr size_t kModelDim = TConfig::kModelDim;
 
-  static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
+  static const float kEmbScaling =
+      static_cast<float>(sqrt(static_cast<double>(kModelDim)));
 
   Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);
@@ -604,24 +609,26 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   }
 }
 
-void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
+void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
+                size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
-               args.temperature, prompt, start_pos, pool, inner_pool,
-               stream_token, accept_token, gen, verbosity);
+  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
+               start_pos, pool, inner_pool, stream_token, accept_token, gen,
+               verbosity);
 }
 
-void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
+void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
+                size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
-               args.temperature, prompt, start_pos, pool, inner_pool,
-               stream_token, accept_token, gen, verbosity);
+  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
+               start_pos, pool, inner_pool, stream_token, accept_token, gen,
+               verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -755,28 +762,24 @@ GemmaImpl::GemmaImpl(
 }
 
 template <>
-void GemmaImpl<ConfigGemma2B>::Generate(const InferenceArgs& args,
-                                        const std::vector<int>& prompt,
-                                        size_t start_pos, hwy::ThreadPool& pool,
-                                        hwy::ThreadPool& inner_pool,
-                                        const StreamFunc& stream_token,
-                                        const AcceptFunc& accept_token,
-                                        std::mt19937& gen, int verbosity) {
+void GemmaImpl<ConfigGemma2B>::Generate(
+    size_t max_tokens, size_t max_generated_tokens, float temperature,
+    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
+    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
-  (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
-   gen, verbosity);
+  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
+   pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 template <>
-void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
-                                        const std::vector<int>& prompt,
-                                        size_t start_pos, hwy::ThreadPool& pool,
-                                        hwy::ThreadPool& inner_pool,
-                                        const StreamFunc& stream_token,
-                                        const AcceptFunc& accept_token,
-                                        std::mt19937& gen, int verbosity) {
+void GemmaImpl<ConfigGemma7B>::Generate(
+    size_t max_tokens, size_t max_generated_tokens, float temperature,
+    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
+    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
-  (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
-   gen, verbosity);
+  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
+   pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
@@ -807,15 +810,16 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
   return impl_->Tokenizer();
 }
 
-void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
-                   const std::vector<int>& prompt, size_t start_pos,
-                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                   const StreamFunc& stream_token,
+void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
+                   float temperature, const std::vector<int>& prompt,
+                   size_t start_pos, hwy::ThreadPool& pool,
+                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token,
-                        accept_token, gen, verbosity);
+  gemma.impl_->Generate(max_tokens, max_generated_tokens,
+                        temperature, prompt, start_pos, pool, inner_pool,
+                        stream_token, accept_token, gen, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
diff --git a/gemma.h b/gemma.h
index 3de9f0e..5a2f2b0 100644
--- a/gemma.h
+++ b/gemma.h
@@ -156,9 +156,6 @@ struct Gemma {
   Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
 
-  // TODO: cleanup
-  // const sentencepiece::SentencePieceProcessor& Tokenizer() const;
-  // const std::unique_ptr<sentencepiece::SentencePieceProcessor> Tokenizer() const;
   const sentencepiece::SentencePieceProcessor* Tokenizer() const;
 
   std::unique_ptr<GemmaInterface> impl_;
@@ -205,15 +202,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
             "Multiturn mode\n    0 = clear KV cache after every "
-            "interaction\n    1 = continue KV cache after every interaction\n    Default : 0 (conversation "
+            "interaction\n    1 = continue KV cache after every interaction\n "
+            "   Default : 0 (conversation "
             "resets every turn)");
   }
 };
 
-void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
-                   const std::vector<int>& prompt, size_t start_pos,
-                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                   const StreamFunc& stream_token,
+void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
+                   float temperature, const std::vector<int>& prompt,
+                   size_t start_pos, hwy::ThreadPool& pool,
+                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& g,
                    int verbosity);
diff --git a/run.cc b/run.cc
index 71481f9..eac5f9e 100644
--- a/run.cc
+++ b/run.cc
@@ -204,8 +204,9 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
 
     std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
     const double time_start = hwy::platform::Now();
-    GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token,
-                  accept_token, gen, verbosity);
+    GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
+                  args.temperature, prompt, abs_pos, pool, inner_pool,
+                  stream_token, accept_token, gen, verbosity);
     const double time_end = hwy::platform::Now();
     const double tok_sec = current_pos / (time_end - time_start);
     if (verbosity >= 2) {
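For reference, a minimal sketch of a caller after this change, mirroring the examples/hello_world/run.cc hunk above. This is not part of the patch: the helper name Sample, the sampling values, and the assumption that StreamFunc is exported from gemma.h in namespace gcpp are illustrative only.

#include <random>
#include <vector>

#include "gemma.h"  // gcpp::Gemma, GenerateGemma (found via ADL)
#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical helper: the three sampling limits are now passed directly,
// where previously they travelled inside an InferenceArgs struct.
void Sample(gcpp::Gemma& model, const std::vector<int>& tokens,
            hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
            const gcpp::StreamFunc& stream_token, std::mt19937& gen) {
  GenerateGemma(model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
                /*temperature=*/1.0, tokens, /*start_pos=*/0, pool, inner_pool,
                stream_token, /*accept_token=*/[](int) { return true; }, gen,
                /*verbosity=*/0);
}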