diff --git a/gemma/benchmark_helper.cc b/gemma/benchmark_helper.cc
index 34a2d26..d83ab23 100644
--- a/gemma/benchmark_helper.cc
+++ b/gemma/benchmark_helper.cc
@@ -53,15 +53,12 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
   }
 }
 
-GemmaEnv::GemmaEnv(int argc, char** argv)
-    : loader_(argc, argv),
-      inference_args_(argc, argv),
-      app_(argc, argv),
+GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
+                   const AppArgs& app)
+    : loader_(loader),
+      inference_args_(inference),
+      app_(app),
       pool_(app_.num_threads) {
-  {
-    // Placeholder for internal init, do not modify.
-  }
-
   // For many-core, pinning workers to cores helps.
   if (app_.num_threads > 10) {
     gcpp::PinWorkersToCores(pool_);
@@ -89,6 +86,15 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
   };
 }
 
+// Note: the delegating ctor above is called before any other initializers here.
+GemmaEnv::GemmaEnv(int argc, char** argv)
+    : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
+               AppArgs(argc, argv)) {
+  {  // So that indentation matches expectations.
+    // Placeholder for internal init, do not modify.
+  }
+}
+
 std::pair<std::string, size_t> GemmaEnv::QueryModel(
     const std::vector<int>& tokens) {
   std::string res;
@@ -98,10 +104,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
   const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
                                       int token, float) {
     ++total_tokens;
-    std::string token_text;
-    HWY_ASSERT(
-        model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    res += token_text;
+    res += StringFromTokens(std::vector<int>{token});
     if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
       LogSpeedStats(time_start, total_tokens);
     }
diff --git a/gemma/benchmark_helper.h b/gemma/benchmark_helper.h
index ac6eef6..3c5b63a 100644
--- a/gemma/benchmark_helper.h
+++ b/gemma/benchmark_helper.h
@@ -37,7 +37,10 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
+  // Calls the other constructor with *Args arguments initialized from argv.
   GemmaEnv(int argc, char** argv);
+  GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
+           const AppArgs& app);
 
   size_t MaxTokens() const { return inference_args_.max_tokens; }
   // Sets the maximum number of output tokens to generate.
@@ -81,7 +84,8 @@ class GemmaEnv {
     return loader_.ModelTrainingType();
   }
   int Verbosity() const { return app_.verbosity; }
-  gcpp::RuntimeConfig& MutableConfig() { return runtime_config_; }
+  RuntimeConfig& MutableConfig() { return runtime_config_; }
+  InferenceArgs& MutableInferenceArgs() { return inference_args_; }
   std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return kv_cache_; }
 
@@ -100,15 +104,14 @@ class GemmaEnv {
   std::unique_ptr<Gemma> model_;
   // The KV cache to use for inference.
   KVCache kv_cache_;
-  gcpp::RuntimeConfig runtime_config_;
+  RuntimeConfig runtime_config_;
 };
 
 // Logs the inference speed in tokens/sec.
 void LogSpeedStats(double time_start, size_t total_tokens);
 
 void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
-void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
-              gcpp::AppArgs& app);
+void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
 
 }  // namespace gcpp
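
For reference, a minimal usage sketch of the new three-argument constructor (not part
of the commit). It assumes only what this diff shows: the *Args types are constructible
from (argc, argv), InferenceArgs has a public max_tokens field (per MaxTokens() above),
and QueryModel() takes token IDs and returns the decoded text plus the token count. The
token IDs are placeholders, not real vocabulary entries:

  #include <cstdio>
  #include <vector>

  #include "gemma/benchmark_helper.h"

  int main(int argc, char** argv) {
    // Build the args explicitly instead of handing raw argv to GemmaEnv.
    // This is the point of the new overload: callers can adjust fields
    // programmatically before the model is loaded.
    gcpp::LoaderArgs loader(argc, argv);
    gcpp::InferenceArgs inference(argc, argv);
    gcpp::AppArgs app(argc, argv);
    inference.max_tokens = 512;  // Override a flag before construction.

    gcpp::GemmaEnv env(loader, inference, app);

    // The new MutableInferenceArgs() accessor exposes the stored copy,
    // so the args can also be tweaked after construction.
    env.MutableInferenceArgs().max_tokens = 256;

    // Placeholder token IDs; a real caller would tokenize a prompt first.
    const auto [text, num_tokens] = env.QueryModel(std::vector<int>{1, 2, 3});
    std::printf("%zu tokens: %s\n", num_tokens, text.c_str());
    return 0;
  }

The delegating ctor keeps the old argv-based entry point behaving exactly as before,
while tests and embedding code can now construct the *Args structs directly.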