From 0c64987a961eed5214708ea5a981b763c50b6a4e Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 15 Dec 2025 03:18:11 -0800 Subject: [PATCH] Abort if args are unrecognized, refactor argument passing This catches typos/incorrect usage. Refactor: group Loader/Threading/Inference into GemmaArgs. All *Args ctors now have an extra ConsumedArgs& argument. PiperOrigin-RevId: 844690553 --- BUILD.bazel | 11 +- CMakeLists.txt | 1 + compression/BUILD.bazel | 2 - evals/benchmark.cc | 15 +- evals/benchmark_helper.cc | 70 ++---- evals/benchmark_helper.h | 13 +- evals/benchmarks.cc | 6 +- evals/debug_prompt.cc | 12 +- evals/gemma_batch_bench.cc | 6 +- evals/gemma_test.cc | 7 +- evals/run_mmlu.cc | 16 +- examples/hello_world/run.cc | 16 +- examples/simplified_gemma/gemma.hpp | 17 +- examples/simplified_gemma/run.cc | 26 +-- gemma/api_client.cc | 309 ++++++++++++++++----------- gemma/api_server.cc | 318 ++++++++++++++-------------- gemma/bindings/context.cc | 43 ++-- gemma/bindings/context.h | 22 +- gemma/gemma.cc | 15 +- gemma/gemma.h | 9 +- gemma/gemma_args.h | 63 +++--- gemma/gemma_args_test.cc | 74 +++++++ gemma/run.cc | 44 ++-- io/migrate_weights.cc | 15 +- paligemma/paligemma_test.cc | 9 +- python/gemma_py.cc | 13 +- util/args.h | 70 +++++- util/threading_context.h | 4 +- 28 files changed, 713 insertions(+), 513 deletions(-) create mode 100644 gemma/gemma_args_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index 8b96b16..1dfaa8f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -556,12 +556,22 @@ cc_library( ":basics", ":configs", ":mat", + ":threading_context", "//io", "@highway//:hwy", "@highway//:profiler", ], ) +cc_test( + name = "gemma_args_test", + srcs = ["gemma/gemma_args_test.cc"], + deps = [ + ":gemma_args", + "@googletest//:gtest_main", # buildcleaner: keep + ], +) + cc_library( name = "gemma_lib", srcs = [ @@ -666,7 +676,6 @@ cc_library( ":gemma_args", ":gemma_lib", ":matmul_env", - ":ops", ":threading_context", ":tokenizer", "@google_benchmark//:benchmark", diff --git a/CMakeLists.txt b/CMakeLists.txt index bab0b90..47d7c4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -222,6 +222,7 @@ set(GEMMA_TEST_FILES compression/nuq_test.cc compression/sfp_test.cc evals/gemma_test.cc + gemma/gemma_args_test.cc gemma/flash_attention_test.cc gemma/tensor_info_test.cc io/blob_store_test.cc diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index 7d042da..4221a8d 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -101,13 +101,11 @@ cc_test( # for test_suite. 
tags = ["hwy_ops_test"], deps = [ - ":distortion", ":int", "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", "@highway//:hwy", "@highway//:hwy_test_util", - "@highway//:nanobenchmark", ], ) diff --git a/evals/benchmark.cc b/evals/benchmark.cc index 4dec9ee..69cd644 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -23,7 +23,9 @@ using json = nlohmann::json; class BenchmarkArgs : public ArgsBase { public: - BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + BenchmarkArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } Path summarize_text; Path cross_entropy; @@ -127,9 +129,16 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file, } // namespace gcpp int main(int argc, char** argv) { - gcpp::GemmaEnv env(argc, argv); - gcpp::BenchmarkArgs benchmark_args(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + gcpp::BenchmarkArgs benchmark_args(argc, argv, consumed); + if (gcpp::HasHelp(argc, argv)) { + args.Help(); + return 0; + } + consumed.AbortIfUnconsumed(); + gcpp::GemmaEnv env(args); if (!benchmark_args.summarize_text.Empty()) { return BenchmarkSummary(env, benchmark_args.summarize_text); } else if (!benchmark_args.cross_entropy.Empty()) { diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 6fbd3f3..30d364f 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -36,35 +36,29 @@ namespace gcpp { -GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference) +GemmaEnv::GemmaEnv(const GemmaArgs& args) : initializer_value_(gcpp::InternalInit()), - ctx_(threading), + ctx_(args.threading), env_(ctx_), - gemma_(loader, inference, ctx_) { + gemma_(args, ctx_) { const ModelConfig& config = gemma_.Config(); // Only allocate one for starters because GenerateBatch might not be called. 
- kv_caches_.push_back(KVCache(config, inference, ctx_.allocator)); + kv_caches_.push_back(KVCache(config, args.inference, ctx_.allocator)); - if (inference.verbosity >= 2) { - ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(), - ctx_); + if (args.inference.verbosity >= 2) { + ShowConfig(args, config, gemma_.WeightReadMode(), ctx_); } - if (inference.verbosity >= 3) env_.print_best = true; - if (inference.verbosity >= 4) env_.print_config = true; + if (args.inference.verbosity >= 3) env_.print_best = true; + if (args.inference.verbosity >= 4) env_.print_config = true; runtime_config_ = { - .max_generated_tokens = inference.max_generated_tokens, - .temperature = inference.temperature, - .verbosity = inference.verbosity, + .max_generated_tokens = args.inference.max_generated_tokens, + .temperature = args.inference.temperature, + .verbosity = args.inference.verbosity, }; - inference.CopyTo(runtime_config_); + args.inference.CopyTo(runtime_config_); } -GemmaEnv::GemmaEnv(int argc, char** argv) - : GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv), - InferenceArgs(argc, argv)) {} - QueryResult GemmaEnv::QueryModel(const std::vector& tokens) { QueryResult result; @@ -234,19 +228,19 @@ static constexpr const char* CompiledConfig() { } } -void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference, const ModelConfig& config, +void ShowConfig(const GemmaArgs& args, const ModelConfig& config, const WeightsPtrs::Mode weight_read_mode, const ThreadingContext& ctx) { - threading.Print(inference.verbosity); - loader.Print(inference.verbosity); - inference.Print(inference.verbosity); - fprintf( - stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n", - config.Specifier().c_str(), static_cast(loader.to_bf16), - static_cast(loader.map), WeightsPtrs::ToString(weight_read_mode)); + args.threading.Print(args.inference.verbosity); + args.loader.Print(args.inference.verbosity); + args.inference.Print(args.inference.verbosity); + fprintf(stderr, + "Model : %s, to_bf16 %d, mmap %d => %s\n", + config.Specifier().c_str(), static_cast(args.loader.to_bf16), + static_cast(args.loader.map), + WeightsPtrs::ToString(weight_read_mode)); - if (inference.verbosity >= 2) { + if (args.inference.verbosity >= 2) { time_t now = time(nullptr); char* dt = ctime(&now); // NOLINT char cpu100[100] = "unknown"; @@ -259,7 +253,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, "Instruction set : %s (%zu bits)\n" "Compiled config : %s, profiler %d\n" "Memory MiB : %4zu\n", - dt, cpu100, static_cast(threading.bind), + dt, cpu100, static_cast(args.threading.bind), ctx.topology.TopologyString(), ctx.pools.PinString(), CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()), ctx.cache_info.VectorBytes() * 8, CompiledConfig(), @@ -267,22 +261,4 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, } } -void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference) { - std::cerr - << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" - "==========================================================\n\n" - "To run with pre-2025 weights, specify --tokenizer and --weights.\n" - "With the single-file weights format, specify just --weights.\n"; - std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " - "--weights gemma2-2b-it-sfp.sbs\n"; - std::cerr << "\n*Model Loading Arguments*\n\n"; - loader.Help(); - std::cerr << "\n*Threading 
Arguments*\n\n"; - threading.Help(); - std::cerr << "\n*Inference Arguments*\n\n"; - inference.Help(); - std::cerr << "\n"; -} - } // namespace gcpp diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 3f97c21..203174c 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -23,7 +23,7 @@ #include "gemma/configs.h" #include "gemma/gemma.h" -#include "gemma/gemma_args.h" +#include "gemma/gemma_args.h" // IWYU pragma: export #include "gemma/tokenizer.h" // WrapAndTokenize #include "ops/matmul.h" #include "util/threading_context.h" @@ -50,10 +50,8 @@ struct QueryResultAndMetrics { // Convenience class to load a model and run inference. class GemmaEnv { public: - // Calls the other constructor with *Args arguments initialized from argv. - GemmaEnv(int argc, char** argv); - GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference); + explicit GemmaEnv(const GemmaArgs& args); + MatMulEnv& Env() { return env_; } size_t MaxGeneratedTokens() const { @@ -137,12 +135,9 @@ class GemmaEnv { // Logs the inference speed in tokens/sec. void LogSpeedStats(double time_start, size_t total_tokens); -void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference, const ModelConfig& config, +void ShowConfig(const GemmaArgs& args, const ModelConfig& config, WeightsPtrs::Mode weight_read_mode, const ThreadingContext& ctx); -void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference); } // namespace gcpp diff --git a/evals/benchmarks.cc b/evals/benchmarks.cc index 3cb3d3f..f44c62b 100644 --- a/evals/benchmarks.cc +++ b/evals/benchmarks.cc @@ -98,7 +98,11 @@ BENCHMARK(BM_coding_prompt) ->UseRealTime(); int main(int argc, char** argv) { - gcpp::GemmaEnv env(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); + + gcpp::GemmaEnv env(args); env.SetMaxGeneratedTokens(256); gcpp::s_env = &env; diff --git a/evals/debug_prompt.cc b/evals/debug_prompt.cc index 66fa466..a6cf8c4 100644 --- a/evals/debug_prompt.cc +++ b/evals/debug_prompt.cc @@ -31,7 +31,9 @@ namespace gcpp { class PromptArgs : public ArgsBase { public: - PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + PromptArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } Path layers_output; // optional std::string prompt; @@ -51,11 +53,15 @@ class PromptArgs : public ArgsBase { }; int Run(int argc, char** argv) { - PromptArgs prompt_args(argc, argv); + ConsumedArgs consumed(argc, argv); + const GemmaArgs args(argc, argv, consumed); + const PromptArgs prompt_args(argc, argv, consumed); AbortIfInvalidArgs(prompt_args); + consumed.AbortIfUnconsumed(); json json_output; - GemmaEnv env(argc, argv); + GemmaEnv env(args); + env.MutableConfig().layers_output = prompt_args.layers_output.Empty() ? 
LayersOutputFunc() diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index 45531ea..dd9cb45 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -146,7 +146,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) { int main(int argc, char** argv) { fprintf(stderr, "GemmaEnv setup..\n"); - gcpp::GemmaEnv env(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); + + gcpp::GemmaEnv env(args); gcpp::s_env = &env; testing::InitGoogleTest(&argc, argv); diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index abd8c90..a581561 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -22,7 +22,6 @@ #include "evals/benchmark_helper.h" #include "gemma/configs.h" -#include "io/io.h" #include "hwy/base.h" #include "hwy/tests/hwy_gtest.h" @@ -42,7 +41,11 @@ class GemmaTest : public ::testing::Test { // Requires argc/argv, hence do not use `SetUpTestSuite`. static void InitEnv(int argc, char** argv) { HWY_ASSERT(s_env == nullptr); // Should only be called once. - s_env = new GemmaEnv(argc, argv); + ConsumedArgs consumed(argc, argv); + GemmaArgs args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); + + s_env = new GemmaEnv(args); const gcpp::ModelConfig& config = s_env->GetGemma()->Config(); fprintf(stderr, "Using %s\n", config.Specifier().c_str()); } diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc index 04a6e00..c6ce972 100644 --- a/evals/run_mmlu.cc +++ b/evals/run_mmlu.cc @@ -31,7 +31,9 @@ namespace gcpp { struct JsonArgs : public ArgsBase { - JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + JsonArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } Path input; @@ -151,10 +153,14 @@ void Run(GemmaEnv& env, JsonArgs& json) { int main(int argc, char** argv) { { PROFILER_ZONE("Startup.all"); - gcpp::GemmaEnv env(argc, argv); - gcpp::JsonArgs json(argc, argv); - gcpp::AbortIfInvalidArgs(json); - gcpp::Run(env, json); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + gcpp::JsonArgs json_args(argc, argv, consumed); + gcpp::AbortIfInvalidArgs(json_args); + consumed.AbortIfUnconsumed(); + + gcpp::GemmaEnv env(args); + gcpp::Run(env, json_args); } PROFILER_PRINT_RESULTS(); // Must call outside the zone above. return 0; diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index f67324d..9cd8b0e 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -24,20 +24,20 @@ #include #include "gemma/gemma.h" -#include "gemma/gemma_args.h" // LoaderArgs +#include "gemma/gemma_args.h" // GemmaArgs #include "gemma/tokenizer.h" #include "util/args.h" #include "util/threading_context.h" #include "hwy/base.h" int main(int argc, char** argv) { - gcpp::LoaderArgs loader(argc, argv); - gcpp::ThreadingArgs threading(argc, argv); - gcpp::InferenceArgs inference(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); if (gcpp::HasHelp(argc, argv)) { - loader.Help(); + args.Help(); return 0; } + consumed.AbortIfUnconsumed(); // Demonstrate constrained decoding by never outputting certain tokens. 
std::set reject_tokens; @@ -49,10 +49,10 @@ int main(int argc, char** argv) { } // Instantiate model and KV Cache - gcpp::ThreadingContext ctx(threading); + gcpp::ThreadingContext ctx(args.threading); gcpp::MatMulEnv env(ctx); - gcpp::Gemma gemma(loader, inference, ctx); - gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator); + gcpp::Gemma gemma(args, ctx); + gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator); size_t generated = 0; // Tokenize instructions. diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp index e5bb1d8..4a69923 100644 --- a/examples/simplified_gemma/gemma.hpp +++ b/examples/simplified_gemma/gemma.hpp @@ -23,7 +23,7 @@ #include #include "third_party/gemma_cpp/gemma/gemma.h" -#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs +#include "third_party/gemma_cpp/gemma/gemma_args.h" // GemmaArgs #include "third_party/gemma_cpp/gemma/tokenizer.h" #include "third_party/gemma_cpp/ops/matmul.h" #include "third_party/gemma_cpp/util/threading_context.h" @@ -31,18 +31,11 @@ class SimplifiedGemma { public: - SimplifiedGemma(const gcpp::LoaderArgs& loader, - const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(), - const gcpp::InferenceArgs& inference = gcpp::InferenceArgs()) - : ctx_(threading), + SimplifiedGemma(const gcpp::GemmaArgs& args) + : ctx_(args.threading), env_(ctx_), - gemma_(loader, inference, ctx_), - kv_cache_(gemma_.Config(), inference, ctx_.allocator) {} - - SimplifiedGemma(int argc, char** argv) - : SimplifiedGemma(gcpp::LoaderArgs(argc, argv), - gcpp::ThreadingArgs(argc, argv), - gcpp::InferenceArgs(argc, argv)) {} + gemma_(args, ctx_), + kv_cache_(gemma_.Config(), args.inference, ctx_.allocator) {} void Generate(std::string& prompt, size_t max_generated_tokens = 1024, float temperature = 0.7, diff --git a/examples/simplified_gemma/run.cc b/examples/simplified_gemma/run.cc index b7af134..58356d2 100644 --- a/examples/simplified_gemma/run.cc +++ b/examples/simplified_gemma/run.cc @@ -18,28 +18,18 @@ #include #include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp" -#include "gemma/gemma_args.h" // LoaderArgs +#include "gemma/gemma_args.h" int main(int argc, char** argv) { - // Standard usage: LoaderArgs takes argc and argv as input, then parses - // necessary flags. - gcpp::LoaderArgs loader(argc, argv); + // Sets arguments from argc and argv. Note that you can instead pass in + // LoaderArgs, ThreadingArgs, and InferenceArgs directly. + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); - // Optional: LoaderArgs can also take tokenizer and weights paths directly. - // - // gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights", - // "model_identifier"); - - // Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not - // specified, default values will be used. 
- // - // gcpp::InferenceArgs inference(argc, argv); - // gcpp::ThreadingArgs threading(argc, argv); - // SimplifiedGemma gemma(loader, threading, inference); - - SimplifiedGemma gemma(loader); + SimplifiedGemma gemma(args); std::string prompt = "Write a greeting to the world."; gemma.Generate(prompt, 256, 0.6); return 0; -} \ No newline at end of file +} diff --git a/gemma/api_client.cc b/gemma/api_client.cc index 1f64d96..e6ce191 100644 --- a/gemma/api_client.cc +++ b/gemma/api_client.cc @@ -15,18 +15,22 @@ // Test client for API server -#include -#include -#include +#include + #include #include +#include +#include +#include #include "httplib.h" -#include "nlohmann/json.hpp" #include "gemma/gemma_args.h" +#include "nlohmann/json.hpp" using json = nlohmann::json; +namespace gcpp { + // ANSI color codes const std::string RESET = "\033[0m"; const std::string BOLD = "\033[1m"; @@ -37,9 +41,15 @@ const std::string YELLOW = "\033[33m"; const std::string RED = "\033[31m"; class APIClient { -public: - APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b") - : host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) { + public: + APIClient(const std::string& host, int port, const std::string& api_key = "", + const std::string& model = "gemma3-4b") + : host_(host), + port_(port), + api_key_(api_key), + model_(model), + use_https_(port == 443), + interactive_mode_(false) { if (use_https_) { ssl_client_ = std::make_unique(host, port); ssl_client_->set_read_timeout(60, 0); @@ -55,22 +65,25 @@ public: // Unified request processing for both public and local APIs json ProcessRequest(const json& request, bool stream = true) { bool is_public_api = !api_key_.empty(); - + std::string endpoint; if (is_public_api) { - endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse" - : "/v1beta/models/gemini-2.0-flash:generateContent"; + endpoint = + stream + ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse" + : "/v1beta/models/gemini-2.0-flash:generateContent"; } else { - endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent" + endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent" : "/v1beta/models/" + model_ + ":generateContent"; } - + // Only show verbose output in non-interactive mode if (!interactive_mode_) { - std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl; + std::cout << "\n" + << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl; std::cout << "Request: " << request.dump(2) << std::endl; } - + if (stream) { return ProcessStreamingRequest(request, endpoint); } else { @@ -81,21 +94,24 @@ public: void TestGenerateContent(const std::string& prompt, bool stream = true) { json request = CreateAPIRequest(prompt); json response = ProcessRequest(request, stream); - + if (response.contains("error")) { - std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; + std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET + << std::endl; } } void TestListModels() { - std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl; - + std::cout << "\n" + << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl; + httplib::Headers headers; if (!api_key_.empty()) { headers.emplace("X-goog-api-key", api_key_); } - auto res = use_https_ ? 
ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers); - + auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) + : client_->Get("/v1beta/models", headers); + if (res && res->status == 200) { json response = json::parse(res->body); std::cout << GREEN << "✅ Available models:" << RESET << std::endl; @@ -106,49 +122,53 @@ public: } void InteractiveChat() { - std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl; + std::cout << "\n" + << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" + << RESET << std::endl; std::cout << "Type ':gemma %q' to end.\n" << std::endl; - + interactive_mode_ = true; json messages; - + while (true) { std::cout << BOLD << BLUE << "You: " << RESET; std::string input; std::getline(std::cin, input); - + if (input == ":gemma %q") { std::cout << BOLD << YELLOW << "👋 Goodbye!" << RESET << std::endl; break; } - + if (input.empty()) continue; - + // Add user message with proper role json user_message = {{"parts", {{{"text", input}}}}}; if (!api_key_.empty()) { user_message["role"] = "user"; } messages.push_back(user_message); - + // Create request using unified logic json request = CreateAPIRequest("", messages); - + std::cout << BOLD << GREEN << "Assistant: " << RESET; - + // Use unified processing - streaming for real-time output json response = ProcessRequest(request, true); - + if (response.contains("candidates") && !response["candidates"].empty()) { auto& candidate = response["candidates"][0]; - if (candidate.contains("content") && candidate["content"].contains("parts")) { + if (candidate.contains("content") && + candidate["content"].contains("parts")) { for (const auto& part : candidate["content"]["parts"]) { if (part.contains("text")) { std::string assistant_response = part["text"].get(); - + // For streaming, the response is already displayed in real-time // Just add to message history for context - json assistant_message = {{"parts", {{{"text", assistant_response}}}}}; + json assistant_message = { + {"parts", {{{"text", assistant_response}}}}}; if (!api_key_.empty()) { assistant_message["role"] = "model"; } @@ -157,23 +177,21 @@ public: } } } else if (response.contains("error")) { - std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; + std::cerr << RED << "❌ Error: " << response["error"]["message"] + << RESET << std::endl; } - + std::cout << std::endl; } } -private: - json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) { + private: + json CreateAPIRequest(const std::string& prompt, + const json& messages = json::array()) { json request = { - {"generationConfig", { - {"temperature", 0.9}, - {"topK", 1}, - {"maxOutputTokens", 1024} - }} - }; - + {"generationConfig", + {{"temperature", 0.9}, {"topK", 1}, {"maxOutputTokens", 1024}}}}; + if (messages.empty()) { // Single prompt json user_message = {{"parts", {{{"text", prompt}}}}}; @@ -185,44 +203,48 @@ private: // Use provided message history request["contents"] = messages; } - + return request; } - json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) { + json ProcessNonStreamingRequest(const json& request, + const std::string& endpoint) { httplib::Headers headers = {{"Content-Type", "application/json"}}; if (!api_key_.empty()) { headers.emplace("X-goog-api-key", api_key_); } - - auto res = use_https_ ? 
ssl_client_->Post(endpoint, headers, request.dump(), "application/json") - : client_->Post(endpoint, headers, request.dump(), "application/json"); - + + auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), + "application/json") + : client_->Post(endpoint, headers, request.dump(), + "application/json"); + if (res && res->status == 200) { json response = json::parse(res->body); if (!interactive_mode_) { - std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl; + std::cout << "\n" + << BOLD << GREEN << "📥 Response:" << RESET << std::endl; std::cout << response.dump(2) << std::endl; } return response; } else { - json error_response = { - {"error", { - {"message", "Request failed"}, - {"status", res ? res->status : -1} - }} - }; + json error_response = {{"error", + {{"message", "Request failed"}, + {"status", res ? res->status : -1}}}}; if (res && !res->body.empty()) { error_response["error"]["details"] = res->body; } - std::cerr << RED << "❌ Request failed. Status: " << (res ? res->status : -1) << RESET << std::endl; + std::cerr << RED + << "❌ Request failed. Status: " << (res ? res->status : -1) + << RESET << std::endl; return error_response; } } - json ProcessStreamingRequest(const json& request, const std::string& endpoint) { + json ProcessStreamingRequest(const json& request, + const std::string& endpoint) { std::string accumulated_response; - + // Use same SSE logic for both public and local APIs httplib::Request req; req.method = "POST"; @@ -232,72 +254,73 @@ private: req.set_header("X-goog-api-key", api_key_); } req.body = request.dump(); - - req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool { - std::string chunk(data, data_length); - std::istringstream stream(chunk); - std::string line; - - while (std::getline(stream, line)) { - if (line.substr(0, 6) == "data: ") { - std::string event_data = line.substr(6); - - if (event_data == "[DONE]") { - if (!interactive_mode_) { - std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl; - } - } else { - try { - json event = json::parse(event_data); - if (event.contains("candidates") && !event["candidates"].empty()) { - auto& candidate = event["candidates"][0]; - if (candidate.contains("content") && candidate["content"].contains("parts")) { - for (const auto& part : candidate["content"]["parts"]) { - if (part.contains("text")) { - std::string text = part["text"].get(); - std::cout << text << std::flush; - accumulated_response += text; - } + + req.content_receiver = [&accumulated_response, this]( + const char* data, size_t data_length, + uint64_t offset, uint64_t total_length) -> bool { + std::string chunk(data, data_length); + std::istringstream stream(chunk); + std::string line; + + while (std::getline(stream, line)) { + if (line.substr(0, 6) == "data: ") { + std::string event_data = line.substr(6); + + if (event_data == "[DONE]") { + if (!interactive_mode_) { + std::cout << "\n\n" + << GREEN << "✅ Generation complete!" 
<< RESET + << std::endl; + } + } else { + try { + json event = json::parse(event_data); + if (event.contains("candidates") && + !event["candidates"].empty()) { + auto& candidate = event["candidates"][0]; + if (candidate.contains("content") && + candidate["content"].contains("parts")) { + for (const auto& part : candidate["content"]["parts"]) { + if (part.contains("text")) { + std::string text = part["text"].get(); + std::cout << text << std::flush; + accumulated_response += text; } } } - } catch (const json::exception& e) { - // Skip parse errors } + } catch (const json::exception& e) { + // Skip parse errors } } } - return true; - }; - + } + return true; + }; + httplib::Response res; httplib::Error error; - bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error); - + bool success = use_https_ ? ssl_client_->send(req, res, error) + : client_->send(req, res, error); + if (res.status == 200 && !accumulated_response.empty()) { return json{ - {"candidates", {{ - {"content", { - {"parts", {{{"text", accumulated_response}}}} - }} - }}} - }; + {"candidates", + {{{"content", {{"parts", {{{"text", accumulated_response}}}}}}}}}}; } else { json error_response = { - {"error", { - {"message", "Streaming request failed"}, - {"status", res.status} - }} - }; + {"error", + {{"message", "Streaming request failed"}, {"status", res.status}}}}; if (!res.body.empty()) { error_response["error"]["details"] = res.body; } - std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl; + std::cerr << RED << "❌ Streaming request failed. Status: " << res.status + << RESET << std::endl; return error_response; } } -private: + private: std::unique_ptr client_; std::unique_ptr ssl_client_; std::string host_; @@ -308,19 +331,55 @@ private: bool interactive_mode_; }; +struct ClientArgs : public ArgsBase { + ClientArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } + ClientArgs() { Init(); }; + + std::string host; + int port; + std::string api_key; + std::string model; + std::string prompt; + bool interactive; + + template + void ForEach(const Visitor& visitor) { + visitor(host, "host", std::string("localhost"), + "Server host (default: localhost)"); + visitor(port, "port", 8080, "Server port (default: 8080)"); + visitor(api_key, "api_key", std::string(""), + "Use public API with key (changes host to " + "generativelanguage.googleapis.com:443)"); + visitor(model, "model", std::string("gemma3-4b"), + "Model name to use (default: gemma3-4b)"); + visitor(prompt, "prompt", std::string("Hello! How are you?"), + "Prompt for generation (default: 'Hello! 
How are you?')"); + visitor(interactive, "interactive", false, + "Start interactive chat mode (0 = no, 1 = yes)"); + } +}; + +} // namespace gcpp + int main(int argc, char* argv[]) { - gcpp::ClientArgs client_args(argc, argv); - + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::ClientArgs client_args(argc, argv, consumed); + if (gcpp::HasHelp(argc, argv)) { - std::cout << "\nAPI Client for gemma.cpp\n"; - std::cout << "========================\n\n"; + fprintf(stderr, + "\nAPI Client for gemma.cpp\n" + "========================\n\n"); client_args.Help(); - std::cout << std::endl; - std::cout << "Environment Variables:" << std::endl; - std::cout << " GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl; + fprintf(stderr, + "\n*Environment Variables:\n" + " GOOGLE_API_KEY : Automatically use public Google API if set\n"); return 0; } - + + consumed.AbortIfUnconsumed(); + // Check for GOOGLE_API_KEY environment variable const char* env_api_key = std::getenv("GOOGLE_API_KEY"); if (env_api_key != nullptr && strlen(env_api_key) > 0) { @@ -328,32 +387,34 @@ int main(int argc, char* argv[]) { client_args.host = "generativelanguage.googleapis.com"; client_args.port = 443; } - + // Handle API key override if (!client_args.api_key.empty()) { client_args.host = "generativelanguage.googleapis.com"; client_args.port = 443; } - - std::cout << BOLD << YELLOW << "🚀 Testing API Server at " - << client_args.host << ":" << client_args.port << RESET << std::endl; - + + std::cout << BOLD << YELLOW << "🚀 Testing API Server at " << client_args.host + << ":" << client_args.port << RESET << std::endl; + try { - APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model); - + APIClient client(client_args.host, client_args.port, client_args.api_key, + client_args.model); + if (client_args.interactive) { client.InteractiveChat(); } else { client.TestListModels(); client.TestGenerateContent(client_args.prompt, true); } - } catch (const std::exception& e) { std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl; std::cerr << "Make sure the API server is running:" << std::endl; - std::cerr << " ./build/gemma_api_server --tokenizer --weights " << std::endl; + std::cerr + << " ./build/gemma_api_server --tokenizer --weights " + << std::endl; return 1; } - + return 0; } diff --git a/gemma/api_server.cc b/gemma/api_server.cc index f05447b..8f71043 100644 --- a/gemma/api_server.cc +++ b/gemma/api_server.cc @@ -15,22 +15,19 @@ // HTTP API server for gemma.cpp with SSE support -#include #include +#include -#include -#include -#include -#include -#include -#include -#include #include #include -#include -#include +#include +#include #include +#include +#include +#include // NOLINT #include +#include // HTTP server library #undef CPPHTTPLIB_OPENSSL_SUPPORT @@ -38,16 +35,12 @@ #include "httplib.h" // JSON library -#include "nlohmann/json.hpp" - -#include "compression/types.h" #include "gemma/gemma.h" #include "gemma/gemma_args.h" -#include "gemma/tokenizer.h" #include "ops/matmul.h" #include "util/args.h" #include "hwy/base.h" -#include "hwy/profiler.h" +#include "nlohmann/json.hpp" using json = nlohmann::json; @@ -90,7 +83,8 @@ struct ServerState { std::lock_guard lock(sessions_mutex); auto& session = sessions[session_id]; if (!session.kv_cache) { - session.kv_cache = std::make_unique(gemma->Config(), InferenceArgs(), env->ctx.allocator); + session.kv_cache = std::make_unique( + gemma->Config(), InferenceArgs(), env->ctx.allocator); } session.last_access = 
std::chrono::steady_clock::now(); return session; @@ -107,7 +101,8 @@ std::string GenerateSessionId() { return ss.str(); } -// Wraps messages with start_of_turn markers - handles both with and without roles +// Wraps messages with start_of_turn markers - handles both with and without +// roles std::string WrapMessagesWithTurnMarkers(const json& contents) { std::string prompt; @@ -121,12 +116,14 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) { std::string text = part["text"]; if (role == "user") { - prompt += "user\n" + text + "\nmodel\n"; + prompt += + "user\n" + text + "\nmodel\n"; } else if (role == "model") { prompt += text + "\n"; } else if (role.empty()) { // Local format without roles - for now, treat as user input - prompt += "user\n" + text + "\nmodel\n"; + prompt += + "user\n" + text + "\nmodel\n"; } } } @@ -163,18 +160,15 @@ RuntimeConfig ParseGenerationConfig(const json& request) { return config; } -// Unified response formatter - creates consistent format regardless of request type -json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) { +// Unified response formatter - creates consistent format regardless of request +// type +json CreateAPIResponse(const std::string& text, + bool is_streaming_chunk = false) { json response = { - {"candidates", {{ - {"content", { - {"parts", {{{"text", text}}}}, - {"role", "model"} - }}, - {"index", 0} - }}}, - {"promptFeedback", {{"safetyRatings", json::array()}}} - }; + {"candidates", + {{{"content", {{"parts", {{{"text", text}}}}, {"role", "model"}}}, + {"index", 0}}}}, + {"promptFeedback", {{"safetyRatings", json::array()}}}}; // Only add finishReason for non-streaming chunks if (!is_streaming_chunk) { @@ -185,7 +179,9 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) } // Handle generateContent endpoint (non-streaming) -void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { +void HandleGenerateContentNonStreaming(ServerState& state, + const httplib::Request& req, + httplib::Response& res) { try { json request = json::parse(req.body); @@ -199,7 +195,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques prompt = WrapMessagesWithTurnMarkers(request["contents"]); } else { res.status = 400; - res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); + res.set_content( + json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), + "application/json"); return; } @@ -209,12 +207,7 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques // Set up runtime config RuntimeConfig runtime_config = ParseGenerationConfig(request); - // Collect full response - std::string full_response; - runtime_config.stream_token = [&full_response](int token, float) { - // Skip EOS token - return true; - }; + runtime_config.stream_token = [](int token, float) { return true; }; // Tokenize prompt std::vector tokens = WrapAndTokenize( @@ -227,7 +220,8 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques // Temporarily redirect output to capture response std::stringstream output; - runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) { + runtime_config.stream_token = [&output, &state, &session, &tokens]( + int token, float) { // Skip prompt tokens if (session.abs_pos < tokens.size()) { session.abs_pos++; @@ -279,7 +273,9 @@ void 
HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques } // Handle streamGenerateContent endpoint with SSE) -void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { +void HandleGenerateContentStreaming(ServerState& state, + const httplib::Request& req, + httplib::Response& res) { try { json request = json::parse(req.body); @@ -293,7 +289,9 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& prompt = WrapMessagesWithTurnMarkers(request["contents"]); } else { res.status = 400; - res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); + res.set_content( + json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), + "application/json"); return; } @@ -305,88 +303,85 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& // Set up chunked content provider for SSE res.set_chunked_content_provider( - "text/event-stream", - [&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) { - try { - // Lock for inference - std::lock_guard lock(state.inference_mutex); - auto& session = state.GetOrCreateSession(session_id); + "text/event-stream", [&state, request, prompt, session_id]( + size_t offset, httplib::DataSink& sink) { + try { + // Lock for inference + std::lock_guard lock(state.inference_mutex); + auto& session = state.GetOrCreateSession(session_id); - // Set up runtime config - RuntimeConfig runtime_config = ParseGenerationConfig(request); + // Set up runtime config + RuntimeConfig runtime_config = ParseGenerationConfig(request); - // Tokenize prompt - std::vector tokens = WrapAndTokenize( - state.gemma->Tokenizer(), state.gemma->ChatTemplate(), - state.gemma->Config().wrapping, session.abs_pos, prompt); + // Tokenize prompt + std::vector tokens = WrapAndTokenize( + state.gemma->Tokenizer(), state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, session.abs_pos, prompt); + + // Stream token callback + std::string accumulated_text; + auto stream_token = [&](int token, float) { + // Skip prompt tokens + if (session.abs_pos < tokens.size()) { + session.abs_pos++; + return true; + } - // Stream token callback - std::string accumulated_text; - auto stream_token = [&](int token, float) { - // Skip prompt tokens - if (session.abs_pos < tokens.size()) { session.abs_pos++; + + // Check for EOS + if (state.gemma->Config().IsEOS(token)) { + return true; + } + + // Decode token + std::string token_text; + state.gemma->Tokenizer().Decode(std::vector{token}, + &token_text); + accumulated_text += token_text; + + // Send SSE event using unified formatter + json event = CreateAPIResponse(token_text, true); + + std::string sse_data = "data: " + event.dump() + "\n\n"; + sink.write(sse_data.data(), sse_data.size()); + return true; - } + }; - session.abs_pos++; + runtime_config.stream_token = stream_token; - // Check for EOS - if (state.gemma->Config().IsEOS(token)) { - return true; - } + // Run inference with KV cache + TimingInfo timing_info = {.verbosity = 0}; + size_t prefix_end = 0; - // Decode token - std::string token_text; - state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); - accumulated_text += token_text; + state.gemma->Generate(runtime_config, tokens, session.abs_pos, + prefix_end, *session.kv_cache, *state.env, + timing_info); - // Send SSE event using unified formatter - json event = CreateAPIResponse(token_text, true); + // Send final event using unified formatter + json 
final_event = CreateAPIResponse("", false); + final_event["usageMetadata"] = { + {"promptTokenCount", tokens.size()}, + {"candidatesTokenCount", session.abs_pos - tokens.size()}, + {"totalTokenCount", session.abs_pos}}; - std::string sse_data = "data: " + event.dump() + "\n\n"; - sink.write(sse_data.data(), sse_data.size()); + std::string final_sse = "data: " + final_event.dump() + "\n\n"; + sink.write(final_sse.data(), final_sse.size()); - return true; - }; - - runtime_config.stream_token = stream_token; - - // Run inference with KV cache - TimingInfo timing_info = {.verbosity = 0}; - size_t prefix_end = 0; - - state.gemma->Generate(runtime_config, tokens, session.abs_pos, - prefix_end, *session.kv_cache, *state.env, - timing_info); - - // Send final event using unified formatter - json final_event = CreateAPIResponse("", false); - final_event["usageMetadata"] = { - {"promptTokenCount", tokens.size()}, - {"candidatesTokenCount", session.abs_pos - tokens.size()}, - {"totalTokenCount", session.abs_pos} - }; - - std::string final_sse = "data: " + final_event.dump() + "\n\n"; - sink.write(final_sse.data(), final_sse.size()); - - // Send done event - sink.write("data: [DONE]\n\n", 15); - - // Ensure all data is sent - sink.done(); - return false; // End streaming - - } catch (const std::exception& e) { - json error_event = {{"error", {{"message", e.what()}}}}; - std::string error_sse = "data: " + error_event.dump() + "\n\n"; - sink.write(error_sse.data(), error_sse.size()); - return false; - } - } - ); + // Send done event + sink.write("data: [DONE]\n\n", 15); + // Ensure all data is sent + sink.done(); + return false; // End streaming + } catch (const std::exception& e) { + json error_event = {{"error", {{"message", e.what()}}}}; + std::string error_sse = "data: " + error_event.dump() + "\n\n"; + sink.write(error_sse.data(), error_sse.size()); + return false; + } + }); } catch (const json::exception& e) { res.status = 400; res.set_content( @@ -398,20 +393,20 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& } // Handle models list endpoint -void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) { +void HandleListModels(ServerState& state, const InferenceArgs& inference, + const httplib::Request& req, httplib::Response& res) { json response = { - {"models", {{ - {"name", "models/" + inference.model}, - {"version", "001"}, - {"displayName", inference.model}, - {"description", inference.model + " model running locally"}, - {"inputTokenLimit", 8192}, - {"outputTokenLimit", 8192}, - {"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})}, - {"temperature", 1.0}, - {"topK", 1} - }}} - }; + {"models", + {{{"name", "models/" + inference.model}, + {"version", "001"}, + {"displayName", inference.model}, + {"description", inference.model + " model running locally"}, + {"inputTokenLimit", 8192}, + {"outputTokenLimit", 8192}, + {"supportedGenerationMethods", + json::array({"generateContent", "streamGenerateContent"})}, + {"temperature", 1.0}, + {"topK", 1}}}}}; res.set_content(response.dump(), "application/json"); } @@ -421,39 +416,45 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const // server_running = false; // } -void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference) { +void RunServer(const GemmaArgs& args) { std::cerr << "Loading model..." 
<< std::endl; // Initialize model - ThreadingContext ctx(threading); + ThreadingContext ctx(args.threading); MatMulEnv env(ctx); - ServerState state; - state.gemma = std::make_unique(loader, inference, ctx); + state.gemma = std::make_unique(args, ctx); state.env = &env; state.ctx = &ctx; + const InferenceArgs& inference = args.inference; + httplib::Server server; // Set up routes - server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) { - res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain"); - }); + server.Get( + "/", [&inference](const httplib::Request&, httplib::Response& res) { + res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + + inference.model + ":generateContent", + "text/plain"); + }); // API endpoints - server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) { + server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, + httplib::Response& res) { HandleListModels(state, inference, req, res); }); std::string model_endpoint = "/v1beta/models/" + inference.model; - server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) { - HandleGenerateContentNonStreaming(state, req, res); - }); + server.Post(model_endpoint + ":generateContent", + [&state](const httplib::Request& req, httplib::Response& res) { + HandleGenerateContentNonStreaming(state, req, res); + }); - server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) { - HandleGenerateContentStreaming(state, req, res); - }); + server.Post(model_endpoint + ":streamGenerateContent", + [&state](const httplib::Request& req, httplib::Response& res) { + HandleGenerateContentStreaming(state, req, res); + }); // Periodic cleanup of old sessions std::thread cleanup_thread([&state]() { @@ -466,12 +467,15 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, std::cerr << "Starting API server on port " << inference.port << std::endl; std::cerr << "Model loaded successfully" << std::endl; std::cerr << "Endpoints:" << std::endl; - std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl; - std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl; + std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" + << std::endl; + std::cerr << " POST /v1beta/models/" << inference.model + << ":streamGenerateContent (SSE)" << std::endl; std::cerr << " GET /v1beta/models" << std::endl; if (!server.listen("0.0.0.0", inference.port)) { - std::cerr << "Failed to start server on port " << inference.port << std::endl; + std::cerr << "Failed to start server on port " << inference.port + << std::endl; } cleanup_thread.join(); @@ -482,35 +486,27 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, int main(int argc, char** argv) { gcpp::InternalInit(); - gcpp::LoaderArgs loader(argc, argv); - gcpp::ThreadingArgs threading(argc, argv); - gcpp::InferenceArgs inference(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); if (gcpp::HasHelp(argc, argv)) { - std::cerr << "\n\nAPI server for gemma.cpp\n"; - std::cout << "========================\n\n"; - std::cerr << "Usage: " << argv[0] << " --weights --tokenizer [options]\n"; - std::cerr << "\nOptions:\n"; - std::cerr << " --port PORT Server port 
(default: 8080)\n"; - std::cerr << " --model MODEL Model name for endpoints (default: gemma3-4b)\n"; - std::cerr << "\n"; - std::cerr << "\n*Model Loading Arguments*\n\n"; - loader.Help(); - std::cerr << "\n*Threading Arguments*\n\n"; - threading.Help(); - std::cerr << "\n*Inference Arguments*\n\n"; - inference.Help(); - std::cerr << "\n"; + fprintf( + stderr, + "\n\nAPI server for gemma.cpp\n" + "========================\n\n" + " --port PORT Server port (default: 8080)\n" + " --model MODEL Model name for endpoints (default: gemma3-4b)\n"); + args.Help(); return 0; } - // Arguments are now handled by InferenceArgs + consumed.AbortIfUnconsumed(); // // Set up signal handler // signal(SIGINT, gcpp::HandleShutdown); // signal(SIGTERM, gcpp::HandleShutdown); - gcpp::RunServer(loader, threading, inference); + gcpp::RunServer(args); return 0; } diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc index 5741d70..5db3adc 100644 --- a/gemma/bindings/context.cc +++ b/gemma/bindings/context.cc @@ -73,45 +73,38 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path, ThreadingArgs threading_args; threading_args.spin = gcpp::Tristate::kFalse; - LoaderArgs loader(tokenizer_path, weights_path); - LogDebug("LoaderArgs created"); + threading_args.spin = gcpp::Tristate::kFalse; + GemmaArgs args(LoaderArgs(tokenizer_path, weights_path), threading_args); // Initialize cached args LogDebug("Initializing inference args"); - InferenceArgs inference_args; - inference_args.Init(); - inference_args.max_generated_tokens = max_generated_tokens; - inference_args.temperature = 0.7f; - inference_args.top_k = 1; - inference_args.deterministic = false; + args.inference.max_generated_tokens = max_generated_tokens; + args.inference.temperature = 0.7f; + args.inference.top_k = 1; + args.inference.deterministic = false; ss.str(""); ss << "Inference args initialized with max_tokens: " << max_generated_tokens - << ", temperature: " << inference_args.temperature - << ", top_k: " << inference_args.top_k << ", deterministic: " - << (inference_args.deterministic ? "true" : "false"); + << ", temperature: " << args.inference.temperature + << ", top_k: " << args.inference.top_k << ", deterministic: " + << (args.inference.deterministic ? 
"true" : "false"); LogDebug(ss.str().c_str()); - return new GemmaContext(loader, inference_args, threading_args, - max_generated_tokens); + return new GemmaContext(args, max_generated_tokens); } -GemmaContext::GemmaContext(const LoaderArgs& loader, - const InferenceArgs& inference_args, - const ThreadingArgs& threading_args, - int max_generated_tokens) - : inference_args(inference_args), - threading_args(threading_args), - ctx(threading_args), +GemmaContext::GemmaContext(const GemmaArgs& args, int max_generated_tokens) + : args(args), + ctx(args.threading), matmul_env(ctx), active_conversation_name("default"), - model(loader, inference_args, matmul_env.ctx) { + model(args, matmul_env.ctx) { std::stringstream ss; LogDebug("Creating initial ConversationData"); // Create the initial ConversationData object using make_shared active_conversation = std::make_shared( - model.Config(), inference_args, ctx.allocator); + model.Config(), args.inference, ctx.allocator); LogDebug( "Storing initial ConversationData in conversation_cache[\"default\"]"); @@ -172,8 +165,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string, // set up runtime config TimingInfo timing_info = {}; RuntimeConfig runtime_config = {.stream_token = stream_token, - .use_spinning = threading_args.spin}; - inference_args.CopyTo(runtime_config); + .use_spinning = args.threading.spin}; + args.inference.CopyTo(runtime_config); size_t prefix_end = 0; const ModelConfig& model_config = model.Config(); @@ -247,7 +240,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string, timing_info); // prepare for next turn - if (!inference_args.multiturn || + if (!args.inference.multiturn || model_config.wrapping == PromptWrapping::PALIGEMMA) { // If not multiturn, or Paligemma (which handles turns differently), // reset the *active* conversation's position. 
diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h index 00648fc..5aa3412 100644 --- a/gemma/bindings/context.h +++ b/gemma/bindings/context.h @@ -53,8 +53,7 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data); class GemmaContext { private: - GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args, - const ThreadingArgs& threading_args, int max_generated_tokens); + GemmaContext(const GemmaArgs& args, int max_generated_tokens); public: static GemmaContext* Create(const char* tokenizer_path, @@ -81,37 +80,37 @@ class GemmaContext { // Set max generated tokens void SetMaxGeneratedTokens(size_t value) { - inference_args.max_generated_tokens = value; + args.inference.max_generated_tokens = value; LogDebug("Setting max_generated_tokens to configured value"); } // Set multiturn flag (0 = disabled, 1 = enabled) void SetMultiturn(int value) { - inference_args.multiturn = value; + args.inference.multiturn = value; LogDebug("Setting multiturn to configured value"); } // Set temperature for token generation void SetTemperature(float value) { - inference_args.temperature = value; + args.inference.temperature = value; LogDebug("Setting temperature to configured value"); } // Set top_k parameter for sampling void SetTopK(int value) { - inference_args.top_k = value; + args.inference.top_k = value; LogDebug("Setting top_k to configured value"); } // Set deterministic flag void SetDeterministic(bool value) { - inference_args.deterministic = value; + args.inference.deterministic = value; LogDebug("Setting deterministic flag to configured value"); } // Set prefill_tbatch_size void SetPrefillTbatchSize(size_t value) { - inference_args.prefill_tbatch_size = value; + args.inference.prefill_tbatch_size = value; LogDebug("Setting prefill_tbatch_size to configured value"); } @@ -175,7 +174,7 @@ class GemmaContext { active_conversation->abs_pos = 0; // Replace the cache within the current ConversationData object active_conversation->kv_cache = std::make_unique( - model.Config(), inference_args, ctx.allocator); + model.Config(), args.inference, ctx.allocator); LogDebug((log_prefix + "Successfully rewound to initial state.").c_str()); } else { @@ -193,7 +192,7 @@ class GemmaContext { LogDebug("Creating new conversation"); // Create a new ConversationData object using make_shared conversation_cache[name] = std::make_shared( - model.Config(), inference_args, ctx.allocator); + model.Config(), args.inference, ctx.allocator); return true; } @@ -274,8 +273,7 @@ class GemmaContext { std::vector token_buffer; // Cached args (remain global for the context) - InferenceArgs inference_args; - ThreadingArgs threading_args; + GemmaArgs args; ThreadingContext ctx; MatMulEnv matmul_env; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index c58a5a8..0ce6ab3 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -738,17 +738,16 @@ HWY_EXPORT(GenerateSingleT); HWY_EXPORT(GenerateBatchT); HWY_EXPORT(GenerateImageTokensT); -Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, - ThreadingContext& ctx) - : reader_(loader.weights), - model_(reader_, loader.tokenizer, loader.wrapping), +Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx) + : reader_(args.loader.weights), + model_(reader_, args.loader.tokenizer, args.loader.wrapping), weights_(model_.Config()), chat_template_(model_.Tokenizer(), model_.Config().model), - inference_(inference), - aes_ctr_engine_(inference.deterministic) { + inference_(args.inference), + aes_ctr_engine_(args.inference.deterministic) { // 
Negligible CPU time in the ctor body (except ReadFromBlobs).
-  weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
-                                              mat_owners_, ctx);
+  weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
+                                              args.inference, mat_owners_, ctx);
   // Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
   reader_.CloseFile();
 }
diff --git a/gemma/gemma.h b/gemma/gemma.h
index f503a1a..b630a8c 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -130,11 +130,16 @@ struct TimingInfo {
 // separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
 class Gemma {
  public:
-  // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
+  // Reads weights/config/tokenizer from `BlobStore` at `args.loader.weights`.
   // `ctx` is only used to read tensors and not stored. Calls to `Generate*`
   // may reference the same, or other `ThreadingContext` via `MatMulEnv`.
+  Gemma(const GemmaArgs& args, ThreadingContext& ctx);
+
+  // Deprecated prior interface for backwards compatibility.
   Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-        ThreadingContext& ctx);
+        ThreadingContext& ctx)
+      : Gemma(GemmaArgs(loader, ThreadingArgs(), inference), ctx) {}
+
   ~Gemma();
 
   const ModelConfig& Config() const { return model_.Config(); }
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index 0db32d3..ba72db6 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -25,10 +25,11 @@
 #include 
 
 #include "gemma/configs.h"
-#include "io/io.h"  // Path
-#include "util/args.h"
+#include "io/io.h"      // Path
+#include "util/args.h"  // IWYU pragma: export
 #include "util/basics.h"  // Tristate
 #include "util/mat.h"
+#include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"  // HWY_ABORT
 #include "hwy/profiler.h"
@@ -36,7 +37,9 @@ namespace gcpp {
 
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
-  LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  LoaderArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
   LoaderArgs(const std::string& tokenizer_path,
              const std::string& weights_path) {
     Init();  // Init sets to defaults, so assignments must come after Init().
@@ -169,7 +172,9 @@ struct RuntimeConfig {
 };
 
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
-  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
   InferenceArgs() { Init(); };
 
   bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }
@@ -275,33 +280,35 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }
 };
 
-struct ClientArgs : public ArgsBase<ClientArgs> {
-  ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
-  ClientArgs() { Init(); };
+// Bundles all args required to construct a `GemmaEnv` or the equivalent.
+struct GemmaArgs {
+  // For callers that do not parse command line args.
+  GemmaArgs(const LoaderArgs& loader,
+            const ThreadingArgs& threading = ThreadingArgs(),
+            const InferenceArgs& inference = InferenceArgs())
+      : loader(loader), threading(threading), inference(inference) {}
 
-  std::string host;
-  int port;
-  std::string api_key;
-  std::string model;
-  std::string prompt;
-  bool interactive;
+  GemmaArgs(int argc, char** argv, ConsumedArgs& consumed)
+      : loader(argc, argv, consumed),
+        threading(argc, argv, consumed),
+        inference(argc, argv, consumed) {}
 
-  template <class Visitor>
-  void ForEach(const Visitor& visitor) {
-    visitor(host, "host", std::string("localhost"),
-            "Server host (default: localhost)");
-    visitor(port, "port", 8080,
-            "Server port (default: 8080)");
-    visitor(api_key, "api_key", std::string(""),
-            "Use public API with key (changes host to "
-            "generativelanguage.googleapis.com:443)");
-    visitor(model, "model", std::string("gemma3-4b"),
-            "Model name to use (default: gemma3-4b)");
-    visitor(prompt, "prompt", std::string("Hello! How are you?"),
-            "Prompt for generation (default: 'Hello! How are you?')");
-    visitor(interactive, "interactive", false,
-            "Start interactive chat mode (0 = no, 1 = yes)");
+  void Help() {
+    fprintf(stderr,
+            "To run with pre-2025 weights, specify --tokenizer and --weights.\n"
+            "With the single-file weights format, specify just --weights.\n"
+            "\n*Model Loading Arguments*\n");
+    loader.Help();
+    fprintf(stderr, "\n*Threading Arguments*\n");
+    threading.Help();
+    fprintf(stderr, "\n*Inference Arguments*\n");
+    inference.Help();
+    fprintf(stderr, "\n");
   }
+
+  LoaderArgs loader;
+  ThreadingArgs threading;
+  InferenceArgs inference;
 };
 
 }  // namespace gcpp
diff --git a/gemma/gemma_args_test.cc b/gemma/gemma_args_test.cc
new file mode 100644
index 0000000..d9ee8b7
--- /dev/null
+++ b/gemma/gemma_args_test.cc
@@ -0,0 +1,74 @@
+#include "gemma/gemma_args.h"
+
+#include <stddef.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+namespace gcpp {
+
+void FillPtrs(const std::vector<std::string>& args, std::vector<char*>& ptrs) {
+  ptrs.reserve(args.size());
+  for (const std::string& arg : args) {
+    ptrs.push_back(const_cast<char*>(arg.data()));
+  }
+}
+
+static void CheckAllConsumed(const std::vector<std::string>& args) {
+  std::vector<char*> ptrs;
+  FillPtrs(args, ptrs);
+  const int argc = static_cast<int>(args.size());
+  char** argv = const_cast<char**>(ptrs.data());
+
+  ConsumedArgs consumed(argc, argv);
+  GemmaArgs gemma_args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+}
+
+static void CheckUnconsumed(const std::vector<std::string>& args,
+                            size_t expected) {
+  std::vector<char*> ptrs;
+  FillPtrs(args, ptrs);
+  const int argc = static_cast<int>(args.size());
+  char** argv = const_cast<char**>(ptrs.data());
+
+  ConsumedArgs consumed(argc, argv);
+  GemmaArgs gemma_args(argc, argv, consumed);
+  ASSERT_EQ(expected, consumed.FirstUnconsumed());
+}
+
+// Note: do not use --help because that is not actually consumed; it is actually
+// special-cased in `HasHelp`.
+TEST(GemmaArgsTest, AllConsumedArgs) {
+  // Single arg
+  CheckAllConsumed({"gemma", "--weights=x"});
+  // Two args, one with =
+  CheckAllConsumed({"gemma", "--weights=x", "--verbosity=1"});
+  // Two args, one with extra value
+  CheckAllConsumed({"gemma", "--weights=x", "--verbosity", "2"});
+  // Two args with values
+  CheckAllConsumed({"gemma", "--verbosity", "2", "--deterministic=true"});
+}
+
+TEST(GemmaArgsTest, UnconsumedArgs) {
+  // Single unconsumed arg
+  CheckUnconsumed({"gemma", "--UNDEFINED"}, 1);
+  // Single unconsumed arg, no --
+  CheckUnconsumed({"gemma", "UNDEFINED"}, 1);
+  // Single unconsumed arg after valid arg
+  CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED"}, 2);
+  // Single unconsumed arg before valid arg
+  CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x"}, 1);
+  // Single unconsumed arg with = after valid arg
+  CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED=1"}, 2);
+  // Single unconsumed arg with = before valid arg
+  CheckUnconsumed({"gemma", "--UNDEFINED=false", "--weights=x"}, 1);
+  // Multiple unconsumed args
+  CheckUnconsumed({"gemma", "--UNDEFINED", "--XXX"}, 1);
+  // Multiple unconsumed args with valid arg between
+  CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x", "--XXX"}, 1);
+}
+
+}  // namespace gcpp
diff --git a/gemma/run.cc b/gemma/run.cc
index 90da090..6c6f4d0 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -89,9 +89,11 @@ std::string GetPrompt(const InferenceArgs& inference) {
 }
 
 // The main Read-Eval-Print Loop.
-void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
-               const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
+void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
+               MatMulEnv& env) {
   PROFILER_ZONE("Gen.misc");
+  const InferenceArgs& inference = args.inference;
+  const int verbosity = inference.verbosity;
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
   size_t prompt_size = 0;
@@ -113,12 +115,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     HWY_ASSERT(image.ReadPPM(inference.image_file.path));
     const size_t image_size = config.vit_config.image_size;
     image.Resize(image_size, image_size);
-    RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
-                                    .use_spinning = threading.spin};
+    RuntimeConfig runtime_config = {.verbosity = verbosity,
+                                    .use_spinning = args.threading.spin};
     double image_tokens_start = hwy::platform::Now();
     gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
                               image_tokens, env);
-    if (inference.verbosity >= 1) {
+    if (verbosity >= 1) {
       double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
       fprintf(stderr,
               "\n\n[ Timing info ] Image token generation took: %d ms\n",
@@ -189,7 +191,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   TimingInfo timing_info = {.verbosity = inference.verbosity};
   RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                   .batch_stream_token = batch_stream_token,
-                                  .use_spinning = threading.spin};
+                                  .use_spinning = args.threading.spin};
   inference.CopyTo(runtime_config);
   std::vector<int> prompt;
   size_t prefix_end = 0;
@@ -252,14 +254,14 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   }
 }
 
-void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
-         const InferenceArgs& inference) {
+void Run(const GemmaArgs& args) {
   PROFILER_ZONE("Run.misc");
-  ThreadingContext ctx(threading);
+  ThreadingContext ctx(args.threading);
   MatMulEnv env(ctx);
+  const InferenceArgs& inference = args.inference;
   if (inference.verbosity >= 3) env.print_best = true;
 
-  const Gemma gemma(loader, inference, ctx);
+  const Gemma gemma(args, ctx);
   KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
 
   if (inference.verbosity >= 1) {
@@ -287,13 +289,12 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
     if (inference.IsInteractive()) {
       std::cout << "\033[2J\033[1;1H"  // clear screen
                 << kAsciiArtBanner << "\n\n";
-      ShowConfig(loader, threading, inference, gemma.Config(),
-                 gemma.WeightReadMode(), ctx);
+      ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
       std::cout << "\n" << instructions << "\n";
     }
   }
 
-  ReplGemma(threading, inference, gemma, kv_cache, env);
+  ReplGemma(args, gemma, kv_cache, env);
 }
 
 }  // namespace gcpp
@@ -302,17 +303,24 @@ int main(int argc, char** argv) {
   gcpp::InternalInit();
 
   {  // Negligible CPU time.
-    gcpp::LoaderArgs loader(argc, argv);
-    gcpp::ThreadingArgs threading(argc, argv);
-    gcpp::InferenceArgs inference(argc, argv);
+    gcpp::ConsumedArgs consumed(argc, argv);
+    gcpp::GemmaArgs args(argc, argv, consumed);
 
     if (gcpp::HasHelp(argc, argv)) {
      std::cerr << gcpp::kAsciiArtBanner;
-      gcpp::ShowHelp(loader, threading, inference);
+      fprintf(stderr,
+              "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
+              "==========================================================\n\n"
+              "*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
+              "--weights gemma2-2b-it-sfp.sbs\n\n");
+      args.Help();
       return 0;
     }
 
-    gcpp::Run(loader, threading, inference);
+    // After `HasHelp` so that we print --help even if unconsumed args remain.
+    consumed.AbortIfUnconsumed();
+
+    gcpp::Run(args);
   }
   PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
   return 0;
diff --git a/io/migrate_weights.cc b/io/migrate_weights.cc
index d20835f..beb268e 100644
--- a/io/migrate_weights.cc
+++ b/io/migrate_weights.cc
@@ -23,7 +23,9 @@ namespace gcpp {
 namespace {
 
 struct WriterArgs : public ArgsBase<WriterArgs> {
-  WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  WriterArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
 
   Path output_weights;
 
@@ -38,12 +40,15 @@
 }  // namespace gcpp
 
 int main(int argc, char** argv) {
-  gcpp::WriterArgs args(argc, argv);
-  if (args.output_weights.Empty()) {
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  gcpp::WriterArgs writer_args(argc, argv, consumed);
+  if (writer_args.output_weights.Empty()) {
    HWY_ABORT("Missing --output_weights flag, a file for the model weights.");
   }
+  consumed.AbortIfUnconsumed();
 
-  gcpp::GemmaEnv env(argc, argv);
-  env.GetGemma()->Save(args.output_weights, env.Env().ctx);
+  gcpp::GemmaEnv env(args);
+  env.GetGemma()->Save(writer_args.output_weights, env.Env().ctx);
   return 0;
 }
diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc
index 1075a0a..7bfd78c 100644
--- a/paligemma/paligemma_test.cc
+++ b/paligemma/paligemma_test.cc
@@ -21,10 +21,9 @@
 #include "evals/benchmark_helper.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
-#include "io/io.h"
+#include "paligemma/paligemma_helper.h"
 #include "util/allocator.h"
 #include "hwy/tests/hwy_gtest.h"
-#include "paligemma/paligemma_helper.h"
 
 // This test can be run manually with the downloaded PaliGemma weights.
 // It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
@@ -73,7 +72,11 @@ TEST_F(PaliGemmaTest, QueryObjects) {
 
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
-  gcpp::GemmaEnv env(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
   gcpp::s_env = &env;
 
   return RUN_ALL_TESTS();
diff --git a/python/gemma_py.cc b/python/gemma_py.cc
index 1bab194..0d056d9 100644
--- a/python/gemma_py.cc
+++ b/python/gemma_py.cc
@@ -45,10 +45,7 @@ static void RemoveTrailingZeros(std::vector &vec) {
 // Wrapper around GemmaEnv to expose to Python.
 class GemmaModel {
  public:
-  GemmaModel(const gcpp::LoaderArgs& loader,
-             const gcpp::ThreadingArgs& threading,
-             const gcpp::InferenceArgs& inference)
-      : env_(loader, threading, inference), last_prob_(0.0f) {}
+  GemmaModel(const gcpp::GemmaArgs& args) : env_(args), last_prob_(0.0f) {}
 
   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
@@ -254,13 +251,15 @@ PYBIND11_MODULE(gemma, mod) {
   py::class_<GemmaModel>(mod, "GemmaModel")
       .def(py::init([](const std::string& tokenizer, const std::string& weights,
                        size_t max_threads) {
-             const gcpp::LoaderArgs loader(tokenizer, weights);
             gcpp::ThreadingArgs threading;
             threading.max_lps = max_threads;
+
             gcpp::InferenceArgs inference;
             inference.max_generated_tokens = 512;
-            auto gemma =
-                std::make_unique<GemmaModel>(loader, threading, inference);
+
+            const gcpp::GemmaArgs args(gcpp::LoaderArgs(tokenizer, weights),
+                                       threading, inference);
+            auto gemma = std::make_unique<GemmaModel>(args);
             if (!gemma->ModelIsLoaded()) {
               throw std::invalid_argument("Could not load model.");
             }
diff --git a/util/args.h b/util/args.h
index 317be20..8c6423b 100644
--- a/util/args.h
+++ b/util/args.h
@@ -22,6 +22,7 @@
 #include <algorithm>  // std::transform
 
 #include <string>
+#include <vector>
 
 #include "io/io.h"    // Path
 #include "util/basics.h"  // Tristate
@@ -29,6 +30,56 @@
 
 namespace gcpp {
 
+// For checking which args were not matched/consumed. Passed to each `*Args`
+// ctor that parses argc/argv to ensure that their args are tracked, without
+// requiring global state.
+class ConsumedArgs {
+ public:
+  ConsumedArgs(int argc, char** argv) : argv_(argv), consumed_(argc) {
+    // We assume argc >= 1, because argv[0] is the binary name. That allows us
+    // to signal "called AbortIfUnconsumed" with an empty vector.
+    HWY_ASSERT(!consumed_.empty());
+  }
+
+  ~ConsumedArgs() {
+    if (HWY_UNLIKELY(!consumed_.empty())) {
+      HWY_ABORT("AbortIfUnconsumed was not called.");
+    }
+  }
+
+  void NotifyConsumed(size_t idx) {
+    HWY_ASSERT(idx < consumed_.size());
+    HWY_ASSERT(consumed_[idx] == 0);
+    consumed_[idx] = 1;
+  }
+
+  // Returns index of first unconsumed arg, or 0 if none. Also disarms the
+  // warning in the dtor checking whether this/`AbortIfUnconsumed` were called.
+  size_t FirstUnconsumed() {
+    // Ignore argv[0], which is the binary name.
+    for (size_t i = 1; i < consumed_.size(); ++i) {
+      if (HWY_UNLIKELY(consumed_[i] == 0)) {
+        consumed_.clear();
+        return i;
+      }
+    }
+
+    consumed_.clear();
+    return 0;
+  }
+
+  void AbortIfUnconsumed() {
+    const size_t i = FirstUnconsumed();
+    if (HWY_UNLIKELY(i != 0)) {
+      HWY_ABORT("Unrecognized arg %zu: %s\n", i, argv_[i]);
+    }
+  }
+
+ private:
+  char** argv_;
+  std::vector<uint8_t> consumed_;
+};
+
 // Args is a class that provides a ForEach member function which visits each of
 // its member variables. ArgsBase provides functions called by Args to
 // initialize values to their defaults (passed as an argument to the visitor),
@@ -93,7 +144,8 @@ class ArgsBase {
   // consider adding a hash-map to speed this up.
   class ParseVisitor {
    public:
-    ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {}
+    ParseVisitor(int argc, char* argv[], ConsumedArgs& consumed)
+        : argc_(argc), argv_(argv), consumed_(consumed) {}
 
     template <typename T>
     void operator()(T& t, const char* name, const T& /*init*/,
@@ -108,6 +160,8 @@
         if (!SetValue(argv_[i + 1], t)) {
           HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
         }
+        consumed_.NotifyConsumed(i);
+        consumed_.NotifyConsumed(i + 1);
         return;
       }
       if (std::string(argv_[i]).find(prefixed_eq) == 0) {
@@ -115,6 +169,7 @@
         if (!SetValue(value, t)) {
          HWY_ABORT("Invalid value for %s, got %s\n", name, value);
         }
+        consumed_.NotifyConsumed(i);
         return;
       }
     }
@@ -181,8 +236,9 @@
       }
     }
 
-    int argc_;
-    char** argv_;
+    const int argc_;
+    char** const argv_;
+    ConsumedArgs& consumed_;
   };  // ParseVisitor
 
   template
@@ -211,15 +267,15 @@ class ArgsBase {
     ForEach(visitor);
   }
 
-  void Parse(int argc, char* argv[]) {
-    ParseVisitor visitor(argc, argv);
+  void Parse(int argc, char* argv[], ConsumedArgs& consumed) {
+    ParseVisitor visitor(argc, argv, consumed);
     ForEach(visitor);
   }
 
   // For convenience, enables single-line constructor.
-  void InitAndParse(int argc, char* argv[]) {
+  void InitAndParse(int argc, char* argv[], ConsumedArgs& consumed) {
     Init();
-    Parse(argc, argv);
+    Parse(argc, argv, consumed);
   }
 };
diff --git a/util/threading_context.h b/util/threading_context.h
index 07c8089..7e595ba 100644
--- a/util/threading_context.h
+++ b/util/threading_context.h
@@ -38,7 +38,9 @@ namespace gcpp {
 // Optional arguments for `ThreadingContext` from the command line.
 class ThreadingArgs : public ArgsBase<ThreadingArgs> {
  public:
-  ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  ThreadingArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
   ThreadingArgs() { Init(); };
 
   // For BoundedTopology: