diff --git a/evals/benchmark.cc b/evals/benchmark.cc
index bbabc7e..4495c57 100644
--- a/evals/benchmark.cc
+++ b/evals/benchmark.cc
@@ -129,7 +129,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
     KVCache kv_cache = KVCache::Create(
-        env.Info().model, env.MutableInferenceArgs().prefill_tbatch_size);
+        env.GetModel()->Info().model, env.MutableConfig().prefill_tbatch_size);
     float entropy = ComputeCrossEntropy(
         *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;
@@ -186,7 +186,9 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.Empty()) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" +
-        gcpp::ModelString(env.Info().model, env.Info().training) + ".txt";
+        gcpp::ModelString(env.GetModel()->Info().model,
+                          env.GetModel()->Info().training) +
+        ".txt";
     return BenchmarkGoldens(env, golden_path);
   } else if (!benchmark_args.summarize_text.Empty()) {
     return BenchmarkSummary(env, benchmark_args.summarize_text);
diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 67e5c8a..d27b4eb 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -58,32 +58,27 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
                    const AppArgs& app)
-    : loader_(loader),
-      inference_args_(inference),
-      app_(app),
-      pools_(app_.max_clusters, app_.num_threads) {
-  AbortIfInvalidArgs(inference_args_);
-
-  if (const char* err = loader_.Validate()) {
-    loader_.Help();
+    : pools_(app.max_clusters, app.num_threads, app.pin) {
+  InferenceArgs mutable_inference = inference;
+  AbortIfInvalidArgs(mutable_inference);
+  LoaderArgs mutable_loader = loader;
+  if (const char* err = mutable_loader.Validate()) {
+    mutable_loader.Help();
     fprintf(stderr, "Skipping model load because: %s\n", err);
   } else {
     fprintf(stderr, "Loading model...\n");
-    model_ = AllocateGemma(loader_, pools_);
-
+    model_ = AllocateGemma(mutable_loader, pools_);
     // Only allocate one for starters because GenerateBatch might not be called.
     kv_caches_.resize(1);
     kv_caches_[0] = KVCache::Create(model_->Info().model,
                                     inference.prefill_tbatch_size);
   }
-
-  InitGenerator(inference_args_, gen_);
-
+  InitGenerator(inference, gen_);
   runtime_config_ = {
-      .max_tokens = inference_args_.max_tokens,
-      .max_generated_tokens = inference_args_.max_generated_tokens,
-      .temperature = inference_args_.temperature,
-      .verbosity = app_.verbosity,
+      .max_tokens = inference.max_tokens,
+      .max_generated_tokens = inference.max_generated_tokens,
+      .temperature = inference.temperature,
+      .verbosity = app.verbosity,
       .gen = &gen_,
   };
 }
@@ -115,20 +110,30 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
     res += StringFromTokens(std::vector<int>{token});
     return true;
   };
-  if (app_.verbosity >= 2) {
-    std::cout << "Max tokens: " << inference_args_.max_tokens
+  if (runtime_config_.verbosity >= 2) {
+    std::cout << "Max tokens: " << runtime_config_.max_tokens
               << "\tmax generated tokens: "
-              << inference_args_.max_generated_tokens
-              << "\ttemperature: " << inference_args_.temperature << "\n";
+              << runtime_config_.max_generated_tokens
+              << "\ttemperature: " << runtime_config_.temperature << "\n";
   }
-  gcpp::TimingInfo timing_info { .verbosity = app_.verbosity };
+  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
   runtime_config_.batch_stream_token = batch_stream_token;
   model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                    timing_info);
   return {res, total_tokens};
 }
 
-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
+void GemmaEnv::QueryModel(
+    const std::vector<int>& tokens, const StreamFunc& stream_token) {
+  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
+  const StreamFunc previous_stream_token = runtime_config_.stream_token;
+  runtime_config_.stream_token = stream_token;
+  model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
+                   timing_info);
+  runtime_config_.stream_token = previous_stream_token;
+}
+
+std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
     const QueriesPromptTokens& queries_prompt) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(num_queries != 0);
@@ -144,12 +149,12 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
     res[query_index].second += 1;
     return true;
   };
-  if (app_.verbosity >= 2) {
+  if (runtime_config_.verbosity >= 2) {
     fprintf(stderr,
             "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
-            inference_args_.max_tokens, inference_args_.max_generated_tokens,
-            inference_args_.temperature, inference_args_.prefill_tbatch_size,
-            inference_args_.decode_qbatch_size);
+            runtime_config_.max_tokens, runtime_config_.max_generated_tokens,
+            runtime_config_.temperature, runtime_config_.prefill_tbatch_size,
+            runtime_config_.decode_qbatch_size);
   }
 
   // Ensure we have one KVCache per query.
@@ -159,13 +164,12 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
   for (size_t i = 1; i < num_queries; ++i) {
     if (kv_caches_[i].seq_len == 0) {
       kv_caches_[i] = KVCache::Create(model_->Info().model,
-                                      inference_args_.prefill_tbatch_size);
+                                      runtime_config_.prefill_tbatch_size);
     }
   }
-  gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity};
+  gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
   runtime_config_.batch_stream_token = batch_stream_token;
-  inference_args_.CopyTo(runtime_config_);
   std::vector<size_t> queries_pos(num_queries, 0);
   model_->GenerateBatch(runtime_config_, queries_prompt,
                         QueriesPos(queries_pos.data(), num_queries),
@@ -174,8 +178,9 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
 }
 
 std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
-  const std::vector<int> prompt = WrapAndTokenize(model_->Tokenizer(), Info(),
-                                                  /*pos=*/0, input);
+  const std::vector<int> prompt =
+      WrapAndTokenize(model_->Tokenizer(), model_->Info(),
+                      /*pos=*/0, input);
   return QueryModel(prompt);
 }
 
@@ -194,7 +199,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
     prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
   }
   QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
-  return BatchQueryModel2(prompt_span);
+  return BatchQueryModel(prompt_span);
 }
 
 float GemmaEnv::CrossEntropy(const std::string& input) {
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index a50f2cd..03a9f07 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -41,10 +41,10 @@ class GemmaEnv {
   GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
            const AppArgs& app);
 
-  size_t MaxTokens() const { return inference_args_.max_tokens; }
+  size_t MaxTokens() const { return runtime_config_.max_tokens; }
   // Sets the maximum number of output tokens to generate.
-  void SetMaxGeneratedTokens(size_t max_tokens) {
-    inference_args_.max_generated_tokens = max_tokens;
+  void SetMaxGeneratedTokens(size_t max_generated_tokens) {
+    runtime_config_.max_generated_tokens = max_generated_tokens;
   }
 
   std::vector<int> Tokenize(const std::string& input) const {
@@ -68,13 +68,17 @@ class GemmaEnv {
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
   std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel2(
+  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
       const QueriesPromptTokens& queries_prompt);
   // Adds turn structure to input, tokenizes and calls the above overload.
   std::pair<std::string, size_t> QueryModel(std::string& input);
   std::vector<std::pair<std::string, size_t>> BatchQueryModel(
       const std::vector<std::string>& inputs);
 
+  // Runs inference on the given input and calls the callback for each token.
+  void QueryModel(const std::vector<int>& tokens,
+                  const StreamFunc& stream_token);
+
   // Runs inference on the given input and returns the cross entropy, a measure
   // of how well the model predicts the correct output. It is the average
   // number of bits per token.
@@ -83,20 +87,12 @@ class GemmaEnv {
   // Returns nullptr if the model failed to load.
   Gemma* GetModel() const { return model_.get(); }
 
-  int Verbosity() const { return app_.verbosity; }
+  int Verbosity() const { return runtime_config_.verbosity; }
   RuntimeConfig& MutableConfig() { return runtime_config_; }
-  const ModelInfo& Info() const { return loader_.Info(); }
-  InferenceArgs& MutableInferenceArgs() { return inference_args_; }
   std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return kv_caches_[0]; }
 
  private:
-  // Arguments to the model loader: file locations, etc.
-  LoaderArgs loader_;
-  // Arguments to the inference function: max tokens, etc.
-  InferenceArgs inference_args_;
-  // Controls overall behavior of the app.
-  AppArgs app_;
   // Thread pool for running inference.
   PerClusterPools pools_;
   // Random number generator.
@@ -105,6 +101,7 @@ class GemmaEnv {
   std::unique_ptr<Gemma> model_;
   // KV caches, same number as query batch.
   std::vector<KVCache> kv_caches_;
+  // Runtime config for inference.
   RuntimeConfig runtime_config_;
 };
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 8856e04..6b12f5c 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -83,7 +83,7 @@ class GemmaTest : public ::testing::Test {
       prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
     }
     QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (auto [response, n] : s_env->BatchQueryModel2(prompts)) {
+    for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
       replies.push_back(response);
     }
   }
@@ -116,7 +116,7 @@ class GemmaTest : public ::testing::Test {
 };
 
 TEST_F(GemmaTest, GeographyBatched) {
-  s_env->MutableInferenceArgs().decode_qbatch_size = 3;
+  s_env->MutableConfig().decode_qbatch_size = 3;
   // 6 are enough to test batching and the loop.
   static const char* kQA[][2] = {
       {"What is the capital of Australia?", "Canberra"},
diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc
index b061f68..729f7e9 100644
--- a/evals/run_mmlu.cc
+++ b/evals/run_mmlu.cc
@@ -104,7 +104,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
         "Do not include any justifications or explanations. Reply only with a "
         "letter.";
     const std::vector<int> prompt =
-        WrapAndTokenize(env.GetModel()->Tokenizer(), env.Info(),
+        WrapAndTokenize(env.GetModel()->Tokenizer(), env.GetModel()->Info(),
                         /*pos=*/0, prompt_string);
     const size_t prompt_size = prompt.size();
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 3af2a43..9197e3e 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -98,26 +98,26 @@ struct GenerateBatchT {
 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
                      TimingInfo& timing_info) {
-  pools_.StartSpinning();
+  if (runtime_config.use_spinning) pools_.StartSpinning();
 
   CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
                                          runtime_config, prompt, pos, kv_cache,
                                          pools_, timing_info);
 
-  pools_.StopSpinning();
+  if (runtime_config.use_spinning) pools_.StopSpinning();
 }
 
 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
                           const QueriesPromptTokens& queries_prompt,
                           const QueriesPos& queries_pos,
                           const KVCaches& kv_caches, TimingInfo& timing_info) {
-  pools_.StartSpinning();
+  if (runtime_config.use_spinning) pools_.StartSpinning();
 
   CallForModelAndWeight<GenerateBatchT>(
       info_.model, info_.weight, weights_u8_, runtime_config, queries_prompt,
       queries_pos, kv_caches, pools_, timing_info);
 
-  pools_.StopSpinning();
+  if (runtime_config.use_spinning) pools_.StopSpinning();
 }
 
 template
diff --git a/gemma/gemma.h b/gemma/gemma.h
index ae4d5eb..42f2034 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -27,6 +27,7 @@
 #include "gemma/common.h"
 #include "gemma/kv_cache.h"
 #include "gemma/tokenizer.h"
+#include "util/allocator.h"
 #include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
@@ -74,7 +75,10 @@ using LayersOutputFunc = std::function;
 
+// RuntimeConfig holds configuration for a single generation run.
 struct RuntimeConfig {
+  // If not empty, batch_stream_token is called for each token in the batch,
+  // instead of stream_token.
   bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
     if (batch_stream_token) {
       return batch_stream_token(query_idx, pos, token, prob);
@@ -82,6 +86,7 @@ struct RuntimeConfig {
     return stream_token(token, prob);
   }
 
+  // Limits on the number of tokens generated.
   size_t max_tokens;
   size_t max_generated_tokens;
@@ -91,15 +96,24 @@ struct RuntimeConfig {
   // Max queries per batch (one token from each) during decode.
   size_t decode_qbatch_size = 16;
 
-  float temperature;
-  int verbosity;
-  std::mt19937* gen;
+  float temperature;  // Temperature for sampling.
+  int verbosity;      // Controls verbosity of printed messages.
+  std::mt19937* gen;  // Random number generator used for sampling.
+
+  // Functions operating on the generated tokens.
   StreamFunc stream_token;
   BatchStreamFunc batch_stream_token;
   AcceptFunc accept_token;  // if empty, accepts all tokens.
   SampleFunc sample_func;   // if empty, uses SampleTopK.
+
+  // Observer callbacks for intermediate data.
   LayersOutputFunc layers_output;  // if not empty, called after each layer.
-  ActivationsObserverFunc activations_observer;  // if set, called per-layer
+  ActivationsObserverFunc activations_observer;  // if set, called per-layer.
+
+  // Whether to use thread spinning to reduce barrier synchronization latency.
+  bool use_spinning = true;
+
+  // End-of-sequence token.
   int eos_id = EOS_ID;
 };
diff --git a/util/app.h b/util/app.h
index bbec5d7..5a09406 100644
--- a/util/app.h
+++ b/util/app.h
@@ -53,6 +53,7 @@ static inline const char* CompiledConfig() {
 class AppArgs : public ArgsBase<AppArgs> {
  public:
   AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  AppArgs() { Init(); };
 
   int verbosity;
@@ -88,6 +89,13 @@ class AppArgs : public ArgsBase<AppArgs> {
 
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
+             const std::string& model) {
+    Init();  // Init sets to defaults, so assignments must come after Init().
+    tokenizer.path = tokenizer_path;
+    weights.path = weights_path;
+    model_type_str = model;
+  };
 
   // Returns error string or nullptr if OK.
   const char* Validate() {
@@ -168,6 +176,7 @@ static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
 
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs() { Init(); };
 
   size_t max_tokens;
   size_t max_generated_tokens;
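
Not part of the patch: a minimal, illustrative sketch of how the refactored API fits together after this change. The tokenizer/weights paths and the model string are placeholders, and the sketch assumes GemmaEnv's existing Tokenize() helper; it is not a drop-in test.

    // Sketch only: placeholder paths and model string; adjust before running.
    #include <iostream>
    #include <vector>

    #include "evals/benchmark_helper.h"
    #include "util/app.h"

    int main() {
      // New LoaderArgs(tokenizer, weights, model) and the default InferenceArgs/
      // AppArgs constructors avoid argc/argv parsing for programmatic use.
      gcpp::LoaderArgs loader("/tmp/tokenizer.spm", "/tmp/2b-it.sbs", "2b-it");
      gcpp::InferenceArgs inference;
      gcpp::AppArgs app;
      gcpp::GemmaEnv env(loader, inference, app);

      // Per-run knobs now live on RuntimeConfig instead of InferenceArgs.
      env.MutableConfig().max_generated_tokens = 64;
      env.MutableConfig().use_spinning = false;  // opt out of spin-waiting

      // New streaming overload: the callback is called for each token.
      const std::vector<int> tokens = env.Tokenize("Hello there.");
      env.QueryModel(tokens, [](int token, float /*prob*/) {
        std::cout << token << ' ';  // raw token ids; detokenize as needed
        return true;                // keep generating
      });
      std::cout << std::endl;
      return 0;
    }

Defaulting use_spinning to true preserves the previous Generate/GenerateBatch behavior; callers that prefer not to spin the thread pools can now opt out per run via RuntimeConfig.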