Abort if args are unrecognized, refactor argument passing

This catches typos/incorrect usage.
Refactor: group Loader/Threading/Inference into GemmaArgs.
All *Args ctors now have an extra ConsumedArgs& argument.
PiperOrigin-RevId: 844690553
Jan Wassenberg 2025-12-15 03:18:11 -08:00 committed by Copybara-Service
parent f50550f4ce
commit 0c64987a96
28 changed files with 713 additions and 513 deletions


@ -556,12 +556,22 @@ cc_library(
":basics", ":basics",
":configs", ":configs",
":mat", ":mat",
":threading_context",
"//io", "//io",
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler", "@highway//:profiler",
], ],
) )
cc_test(
name = "gemma_args_test",
srcs = ["gemma/gemma_args_test.cc"],
deps = [
":gemma_args",
"@googletest//:gtest_main", # buildcleaner: keep
],
)
cc_library( cc_library(
name = "gemma_lib", name = "gemma_lib",
srcs = [ srcs = [
@ -666,7 +676,6 @@ cc_library(
":gemma_args", ":gemma_args",
":gemma_lib", ":gemma_lib",
":matmul_env", ":matmul_env",
":ops",
":threading_context", ":threading_context",
":tokenizer", ":tokenizer",
"@google_benchmark//:benchmark", "@google_benchmark//:benchmark",


@ -222,6 +222,7 @@ set(GEMMA_TEST_FILES
compression/nuq_test.cc compression/nuq_test.cc
compression/sfp_test.cc compression/sfp_test.cc
evals/gemma_test.cc evals/gemma_test.cc
gemma/gemma_args_test.cc
gemma/flash_attention_test.cc gemma/flash_attention_test.cc
gemma/tensor_info_test.cc gemma/tensor_info_test.cc
io/blob_store_test.cc io/blob_store_test.cc


@ -101,13 +101,11 @@ cc_test(
# for test_suite. # for test_suite.
tags = ["hwy_ops_test"], tags = ["hwy_ops_test"],
deps = [ deps = [
":distortion",
":int", ":int",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//:test_util", "//:test_util",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:nanobenchmark",
], ],
) )


@ -23,7 +23,9 @@ using json = nlohmann::json;
class BenchmarkArgs : public ArgsBase<BenchmarkArgs> { class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
public: public:
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } BenchmarkArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path summarize_text; Path summarize_text;
Path cross_entropy; Path cross_entropy;
@ -127,9 +129,16 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
} // namespace gcpp } // namespace gcpp
int main(int argc, char** argv) { int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::BenchmarkArgs benchmark_args(argc, argv); gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::BenchmarkArgs benchmark_args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
args.Help();
return 0;
}
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
if (!benchmark_args.summarize_text.Empty()) { if (!benchmark_args.summarize_text.Empty()) {
return BenchmarkSummary(env, benchmark_args.summarize_text); return BenchmarkSummary(env, benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.Empty()) { } else if (!benchmark_args.cross_entropy.Empty()) {


@ -36,35 +36,29 @@
namespace gcpp { namespace gcpp {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, GemmaEnv::GemmaEnv(const GemmaArgs& args)
const InferenceArgs& inference)
: initializer_value_(gcpp::InternalInit()), : initializer_value_(gcpp::InternalInit()),
ctx_(threading), ctx_(args.threading),
env_(ctx_), env_(ctx_),
gemma_(loader, inference, ctx_) { gemma_(args, ctx_) {
const ModelConfig& config = gemma_.Config(); const ModelConfig& config = gemma_.Config();
// Only allocate one for starters because GenerateBatch might not be called. // Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator)); kv_caches_.push_back(KVCache(config, args.inference, ctx_.allocator));
if (inference.verbosity >= 2) { if (args.inference.verbosity >= 2) {
ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(), ShowConfig(args, config, gemma_.WeightReadMode(), ctx_);
ctx_);
} }
if (inference.verbosity >= 3) env_.print_best = true; if (args.inference.verbosity >= 3) env_.print_best = true;
if (inference.verbosity >= 4) env_.print_config = true; if (args.inference.verbosity >= 4) env_.print_config = true;
runtime_config_ = { runtime_config_ = {
.max_generated_tokens = inference.max_generated_tokens, .max_generated_tokens = args.inference.max_generated_tokens,
.temperature = inference.temperature, .temperature = args.inference.temperature,
.verbosity = inference.verbosity, .verbosity = args.inference.verbosity,
}; };
inference.CopyTo(runtime_config_); args.inference.CopyTo(runtime_config_);
} }
GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
InferenceArgs(argc, argv)) {}
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) { QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result; QueryResult result;
@ -234,19 +228,19 @@ static constexpr const char* CompiledConfig() {
} }
} }
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
const InferenceArgs& inference, const ModelConfig& config,
const WeightsPtrs::Mode weight_read_mode, const WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx) { const ThreadingContext& ctx) {
threading.Print(inference.verbosity); args.threading.Print(args.inference.verbosity);
loader.Print(inference.verbosity); args.loader.Print(args.inference.verbosity);
inference.Print(inference.verbosity); args.inference.Print(args.inference.verbosity);
fprintf( fprintf(stderr,
stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n", "Model : %s, to_bf16 %d, mmap %d => %s\n",
config.Specifier().c_str(), static_cast<int>(loader.to_bf16), config.Specifier().c_str(), static_cast<int>(args.loader.to_bf16),
static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode)); static_cast<int>(args.loader.map),
WeightsPtrs::ToString(weight_read_mode));
if (inference.verbosity >= 2) { if (args.inference.verbosity >= 2) {
time_t now = time(nullptr); time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT char* dt = ctime(&now); // NOLINT
char cpu100[100] = "unknown"; char cpu100[100] = "unknown";
@ -259,7 +253,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
"Instruction set : %s (%zu bits)\n" "Instruction set : %s (%zu bits)\n"
"Compiled config : %s, profiler %d\n" "Compiled config : %s, profiler %d\n"
"Memory MiB : %4zu\n", "Memory MiB : %4zu\n",
dt, cpu100, static_cast<int>(threading.bind), dt, cpu100, static_cast<int>(args.threading.bind),
ctx.topology.TopologyString(), ctx.pools.PinString(), ctx.topology.TopologyString(), ctx.pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()), CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
ctx.cache_info.VectorBytes() * 8, CompiledConfig(), ctx.cache_info.VectorBytes() * 8, CompiledConfig(),
@ -267,22 +261,4 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
} }
} }
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run with pre-2025 weights, specify --tokenizer and --weights.\n"
"With the single-file weights format, specify just --weights.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights gemma2-2b-it-sfp.sbs\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n";
}
} // namespace gcpp } // namespace gcpp


@ -23,7 +23,7 @@
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h" // IWYU pragma: export
#include "gemma/tokenizer.h" // WrapAndTokenize #include "gemma/tokenizer.h" // WrapAndTokenize
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/threading_context.h" #include "util/threading_context.h"
@ -50,10 +50,8 @@ struct QueryResultAndMetrics {
// Convenience class to load a model and run inference. // Convenience class to load a model and run inference.
class GemmaEnv { class GemmaEnv {
public: public:
// Calls the other constructor with *Args arguments initialized from argv. explicit GemmaEnv(const GemmaArgs& args);
GemmaEnv(int argc, char** argv);
GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
MatMulEnv& Env() { return env_; } MatMulEnv& Env() { return env_; }
size_t MaxGeneratedTokens() const { size_t MaxGeneratedTokens() const {
@ -137,12 +135,9 @@ class GemmaEnv {
// Logs the inference speed in tokens/sec. // Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens); void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
const InferenceArgs& inference, const ModelConfig& config,
WeightsPtrs::Mode weight_read_mode, WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx); const ThreadingContext& ctx);
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
} // namespace gcpp } // namespace gcpp


@ -98,7 +98,11 @@ BENCHMARK(BM_coding_prompt)
->UseRealTime(); ->UseRealTime();
int main(int argc, char** argv) { int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
env.SetMaxGeneratedTokens(256); env.SetMaxGeneratedTokens(256);
gcpp::s_env = &env; gcpp::s_env = &env;


@ -31,7 +31,9 @@ namespace gcpp {
class PromptArgs : public ArgsBase<PromptArgs> { class PromptArgs : public ArgsBase<PromptArgs> {
public: public:
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } PromptArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path layers_output; // optional Path layers_output; // optional
std::string prompt; std::string prompt;
@ -51,11 +53,15 @@ class PromptArgs : public ArgsBase<PromptArgs> {
}; };
int Run(int argc, char** argv) { int Run(int argc, char** argv) {
PromptArgs prompt_args(argc, argv); ConsumedArgs consumed(argc, argv);
const GemmaArgs args(argc, argv, consumed);
const PromptArgs prompt_args(argc, argv, consumed);
AbortIfInvalidArgs(prompt_args); AbortIfInvalidArgs(prompt_args);
consumed.AbortIfUnconsumed();
json json_output; json json_output;
GemmaEnv env(argc, argv); GemmaEnv env(args);
env.MutableConfig().layers_output = env.MutableConfig().layers_output =
prompt_args.layers_output.Empty() prompt_args.layers_output.Empty()
? LayersOutputFunc() ? LayersOutputFunc()


@ -146,7 +146,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
int main(int argc, char** argv) { int main(int argc, char** argv) {
fprintf(stderr, "GemmaEnv setup..\n"); fprintf(stderr, "GemmaEnv setup..\n");
gcpp::GemmaEnv env(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
gcpp::s_env = &env; gcpp::s_env = &env;
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);


@ -22,7 +22,6 @@
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "io/io.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
@ -42,7 +41,11 @@ class GemmaTest : public ::testing::Test {
// Requires argc/argv, hence do not use `SetUpTestSuite`. // Requires argc/argv, hence do not use `SetUpTestSuite`.
static void InitEnv(int argc, char** argv) { static void InitEnv(int argc, char** argv) {
HWY_ASSERT(s_env == nullptr); // Should only be called once. HWY_ASSERT(s_env == nullptr); // Should only be called once.
s_env = new GemmaEnv(argc, argv); ConsumedArgs consumed(argc, argv);
GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
s_env = new GemmaEnv(args);
const gcpp::ModelConfig& config = s_env->GetGemma()->Config(); const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
fprintf(stderr, "Using %s\n", config.Specifier().c_str()); fprintf(stderr, "Using %s\n", config.Specifier().c_str());
} }


@ -31,7 +31,9 @@
namespace gcpp { namespace gcpp {
struct JsonArgs : public ArgsBase<JsonArgs> { struct JsonArgs : public ArgsBase<JsonArgs> {
JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } JsonArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path input; Path input;
@ -151,10 +153,14 @@ void Run(GemmaEnv& env, JsonArgs& json) {
int main(int argc, char** argv) { int main(int argc, char** argv) {
{ {
PROFILER_ZONE("Startup.all"); PROFILER_ZONE("Startup.all");
gcpp::GemmaEnv env(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::JsonArgs json(argc, argv); gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::AbortIfInvalidArgs(json); gcpp::JsonArgs json_args(argc, argv, consumed);
gcpp::Run(env, json); gcpp::AbortIfInvalidArgs(json_args);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
gcpp::Run(env, json_args);
} }
PROFILER_PRINT_RESULTS(); // Must call outside the zone above. PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0; return 0;


@ -24,20 +24,20 @@
#include <vector> #include <vector>
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/gemma_args.h" // LoaderArgs #include "gemma/gemma_args.h" // GemmaArgs
#include "gemma/tokenizer.h" #include "gemma/tokenizer.h"
#include "util/args.h" #include "util/args.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/base.h" #include "hwy/base.h"
int main(int argc, char** argv) { int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::ThreadingArgs threading(argc, argv); gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::InferenceArgs inference(argc, argv);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
loader.Help(); args.Help();
return 0; return 0;
} }
consumed.AbortIfUnconsumed();
// Demonstrate constrained decoding by never outputting certain tokens. // Demonstrate constrained decoding by never outputting certain tokens.
std::set<int> reject_tokens; std::set<int> reject_tokens;
@ -49,10 +49,10 @@ int main(int argc, char** argv) {
} }
// Instantiate model and KV Cache // Instantiate model and KV Cache
gcpp::ThreadingContext ctx(threading); gcpp::ThreadingContext ctx(args.threading);
gcpp::MatMulEnv env(ctx); gcpp::MatMulEnv env(ctx);
gcpp::Gemma gemma(loader, inference, ctx); gcpp::Gemma gemma(args, ctx);
gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator); gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator);
size_t generated = 0; size_t generated = 0;
// Tokenize instructions. // Tokenize instructions.


@ -23,7 +23,7 @@
#include <vector> #include <vector>
#include "third_party/gemma_cpp/gemma/gemma.h" #include "third_party/gemma_cpp/gemma/gemma.h"
#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs #include "third_party/gemma_cpp/gemma/gemma_args.h" // GemmaArgs
#include "third_party/gemma_cpp/gemma/tokenizer.h" #include "third_party/gemma_cpp/gemma/tokenizer.h"
#include "third_party/gemma_cpp/ops/matmul.h" #include "third_party/gemma_cpp/ops/matmul.h"
#include "third_party/gemma_cpp/util/threading_context.h" #include "third_party/gemma_cpp/util/threading_context.h"
@ -31,18 +31,11 @@
class SimplifiedGemma { class SimplifiedGemma {
public: public:
SimplifiedGemma(const gcpp::LoaderArgs& loader, SimplifiedGemma(const gcpp::GemmaArgs& args)
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(), : ctx_(args.threading),
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
: ctx_(threading),
env_(ctx_), env_(ctx_),
gemma_(loader, inference, ctx_), gemma_(args, ctx_),
kv_cache_(gemma_.Config(), inference, ctx_.allocator) {} kv_cache_(gemma_.Config(), args.inference, ctx_.allocator) {}
SimplifiedGemma(int argc, char** argv)
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
gcpp::ThreadingArgs(argc, argv),
gcpp::InferenceArgs(argc, argv)) {}
void Generate(std::string& prompt, size_t max_generated_tokens = 1024, void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
float temperature = 0.7, float temperature = 0.7,


@ -18,26 +18,16 @@
#include <string> #include <string>
#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp" #include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
#include "gemma/gemma_args.h" // LoaderArgs #include "gemma/gemma_args.h"
int main(int argc, char** argv) { int main(int argc, char** argv) {
// Standard usage: LoaderArgs takes argc and argv as input, then parses // Sets arguments from argc and argv. Note that you can instead pass in
// necessary flags. // LoaderArgs, ThreadingArgs, and InferenceArgs directly.
gcpp::LoaderArgs loader(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
// Optional: LoaderArgs can also take tokenizer and weights paths directly. SimplifiedGemma gemma(args);
//
// gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
// "model_identifier");
// Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
// specified, default values will be used.
//
// gcpp::InferenceArgs inference(argc, argv);
// gcpp::ThreadingArgs threading(argc, argv);
// SimplifiedGemma gemma(loader, threading, inference);
SimplifiedGemma gemma(loader);
std::string prompt = "Write a greeting to the world."; std::string prompt = "Write a greeting to the world.";
gemma.Generate(prompt, 256, 0.6); gemma.Generate(prompt, 256, 0.6);


@ -15,18 +15,22 @@
// Test client for API server // Test client for API server
#include <iostream> #include <stdio.h>
#include <string>
#include <sstream>
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include "httplib.h" #include "httplib.h"
#include "nlohmann/json.hpp"
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h"
#include "nlohmann/json.hpp"
using json = nlohmann::json; using json = nlohmann::json;
namespace gcpp {
// ANSI color codes // ANSI color codes
const std::string RESET = "\033[0m"; const std::string RESET = "\033[0m";
const std::string BOLD = "\033[1m"; const std::string BOLD = "\033[1m";
@ -37,9 +41,15 @@ const std::string YELLOW = "\033[33m";
const std::string RED = "\033[31m"; const std::string RED = "\033[31m";
class APIClient { class APIClient {
public: public:
APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b") APIClient(const std::string& host, int port, const std::string& api_key = "",
: host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) { const std::string& model = "gemma3-4b")
: host_(host),
port_(port),
api_key_(api_key),
model_(model),
use_https_(port == 443),
interactive_mode_(false) {
if (use_https_) { if (use_https_) {
ssl_client_ = std::make_unique<httplib::SSLClient>(host, port); ssl_client_ = std::make_unique<httplib::SSLClient>(host, port);
ssl_client_->set_read_timeout(60, 0); ssl_client_->set_read_timeout(60, 0);
@ -58,8 +68,10 @@ public:
std::string endpoint; std::string endpoint;
if (is_public_api) { if (is_public_api) {
endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse" endpoint =
: "/v1beta/models/gemini-2.0-flash:generateContent"; stream
? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
: "/v1beta/models/gemini-2.0-flash:generateContent";
} else { } else {
endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent" endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent"
: "/v1beta/models/" + model_ + ":generateContent"; : "/v1beta/models/" + model_ + ":generateContent";
@ -67,7 +79,8 @@ public:
// Only show verbose output in non-interactive mode // Only show verbose output in non-interactive mode
if (!interactive_mode_) { if (!interactive_mode_) {
std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl; std::cout << "\n"
<< BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
std::cout << "Request: " << request.dump(2) << std::endl; std::cout << "Request: " << request.dump(2) << std::endl;
} }
@ -83,18 +96,21 @@ public:
json response = ProcessRequest(request, stream); json response = ProcessRequest(request, stream);
if (response.contains("error")) { if (response.contains("error")) {
std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET
<< std::endl;
} }
} }
void TestListModels() { void TestListModels() {
std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl; std::cout << "\n"
<< BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
httplib::Headers headers; httplib::Headers headers;
if (!api_key_.empty()) { if (!api_key_.empty()) {
headers.emplace("X-goog-api-key", api_key_); headers.emplace("X-goog-api-key", api_key_);
} }
auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers); auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers)
: client_->Get("/v1beta/models", headers);
if (res && res->status == 200) { if (res && res->status == 200) {
json response = json::parse(res->body); json response = json::parse(res->body);
@ -106,7 +122,9 @@ public:
} }
void InteractiveChat() { void InteractiveChat() {
std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl; std::cout << "\n"
<< BOLD << CYAN << "💬 Interactive Chat Mode (with session)"
<< RESET << std::endl;
std::cout << "Type ':gemma %q' to end.\n" << std::endl; std::cout << "Type ':gemma %q' to end.\n" << std::endl;
interactive_mode_ = true; interactive_mode_ = true;
@ -141,14 +159,16 @@ public:
if (response.contains("candidates") && !response["candidates"].empty()) { if (response.contains("candidates") && !response["candidates"].empty()) {
auto& candidate = response["candidates"][0]; auto& candidate = response["candidates"][0];
if (candidate.contains("content") && candidate["content"].contains("parts")) { if (candidate.contains("content") &&
candidate["content"].contains("parts")) {
for (const auto& part : candidate["content"]["parts"]) { for (const auto& part : candidate["content"]["parts"]) {
if (part.contains("text")) { if (part.contains("text")) {
std::string assistant_response = part["text"].get<std::string>(); std::string assistant_response = part["text"].get<std::string>();
// For streaming, the response is already displayed in real-time // For streaming, the response is already displayed in real-time
// Just add to message history for context // Just add to message history for context
json assistant_message = {{"parts", {{{"text", assistant_response}}}}}; json assistant_message = {
{"parts", {{{"text", assistant_response}}}}};
if (!api_key_.empty()) { if (!api_key_.empty()) {
assistant_message["role"] = "model"; assistant_message["role"] = "model";
} }
@ -157,22 +177,20 @@ public:
} }
} }
} else if (response.contains("error")) { } else if (response.contains("error")) {
std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; std::cerr << RED << "❌ Error: " << response["error"]["message"]
<< RESET << std::endl;
} }
std::cout << std::endl; std::cout << std::endl;
} }
} }
private: private:
json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) { json CreateAPIRequest(const std::string& prompt,
const json& messages = json::array()) {
json request = { json request = {
{"generationConfig", { {"generationConfig",
{"temperature", 0.9}, {{"temperature", 0.9}, {"topK", 1}, {"maxOutputTokens", 1024}}}};
{"topK", 1},
{"maxOutputTokens", 1024}
}}
};
if (messages.empty()) { if (messages.empty()) {
// Single prompt // Single prompt
@ -189,38 +207,42 @@ private:
return request; return request;
} }
json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) { json ProcessNonStreamingRequest(const json& request,
const std::string& endpoint) {
httplib::Headers headers = {{"Content-Type", "application/json"}}; httplib::Headers headers = {{"Content-Type", "application/json"}};
if (!api_key_.empty()) { if (!api_key_.empty()) {
headers.emplace("X-goog-api-key", api_key_); headers.emplace("X-goog-api-key", api_key_);
} }
auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json") auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(),
: client_->Post(endpoint, headers, request.dump(), "application/json"); "application/json")
: client_->Post(endpoint, headers, request.dump(),
"application/json");
if (res && res->status == 200) { if (res && res->status == 200) {
json response = json::parse(res->body); json response = json::parse(res->body);
if (!interactive_mode_) { if (!interactive_mode_) {
std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl; std::cout << "\n"
<< BOLD << GREEN << "📥 Response:" << RESET << std::endl;
std::cout << response.dump(2) << std::endl; std::cout << response.dump(2) << std::endl;
} }
return response; return response;
} else { } else {
json error_response = { json error_response = {{"error",
{"error", { {{"message", "Request failed"},
{"message", "Request failed"}, {"status", res ? res->status : -1}}}};
{"status", res ? res->status : -1}
}}
};
if (res && !res->body.empty()) { if (res && !res->body.empty()) {
error_response["error"]["details"] = res->body; error_response["error"]["details"] = res->body;
} }
std::cerr << RED << "❌ Request failed. Status: " << (res ? res->status : -1) << RESET << std::endl; std::cerr << RED
<< "❌ Request failed. Status: " << (res ? res->status : -1)
<< RESET << std::endl;
return error_response; return error_response;
} }
} }
json ProcessStreamingRequest(const json& request, const std::string& endpoint) { json ProcessStreamingRequest(const json& request,
const std::string& endpoint) {
std::string accumulated_response; std::string accumulated_response;
// Use same SSE logic for both public and local APIs // Use same SSE logic for both public and local APIs
@ -233,71 +255,72 @@ private:
} }
req.body = request.dump(); req.body = request.dump();
req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool { req.content_receiver = [&accumulated_response, this](
std::string chunk(data, data_length); const char* data, size_t data_length,
std::istringstream stream(chunk); uint64_t offset, uint64_t total_length) -> bool {
std::string line; std::string chunk(data, data_length);
std::istringstream stream(chunk);
std::string line;
while (std::getline(stream, line)) { while (std::getline(stream, line)) {
if (line.substr(0, 6) == "data: ") { if (line.substr(0, 6) == "data: ") {
std::string event_data = line.substr(6); std::string event_data = line.substr(6);
if (event_data == "[DONE]") { if (event_data == "[DONE]") {
if (!interactive_mode_) { if (!interactive_mode_) {
std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl; std::cout << "\n\n"
} << GREEN << "✅ Generation complete!" << RESET
} else { << std::endl;
try { }
json event = json::parse(event_data); } else {
if (event.contains("candidates") && !event["candidates"].empty()) { try {
auto& candidate = event["candidates"][0]; json event = json::parse(event_data);
if (candidate.contains("content") && candidate["content"].contains("parts")) { if (event.contains("candidates") &&
for (const auto& part : candidate["content"]["parts"]) { !event["candidates"].empty()) {
if (part.contains("text")) { auto& candidate = event["candidates"][0];
std::string text = part["text"].get<std::string>(); if (candidate.contains("content") &&
std::cout << text << std::flush; candidate["content"].contains("parts")) {
accumulated_response += text; for (const auto& part : candidate["content"]["parts"]) {
} if (part.contains("text")) {
std::string text = part["text"].get<std::string>();
std::cout << text << std::flush;
accumulated_response += text;
} }
} }
} }
} catch (const json::exception& e) {
// Skip parse errors
} }
} catch (const json::exception& e) {
// Skip parse errors
} }
} }
} }
return true; }
}; return true;
};
httplib::Response res; httplib::Response res;
httplib::Error error; httplib::Error error;
bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error); bool success = use_https_ ? ssl_client_->send(req, res, error)
: client_->send(req, res, error);
if (res.status == 200 && !accumulated_response.empty()) { if (res.status == 200 && !accumulated_response.empty()) {
return json{ return json{
{"candidates", {{ {"candidates",
{"content", { {{{"content", {{"parts", {{{"text", accumulated_response}}}}}}}}}};
{"parts", {{{"text", accumulated_response}}}}
}}
}}}
};
} else { } else {
json error_response = { json error_response = {
{"error", { {"error",
{"message", "Streaming request failed"}, {{"message", "Streaming request failed"}, {"status", res.status}}}};
{"status", res.status}
}}
};
if (!res.body.empty()) { if (!res.body.empty()) {
error_response["error"]["details"] = res.body; error_response["error"]["details"] = res.body;
} }
std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl; std::cerr << RED << "❌ Streaming request failed. Status: " << res.status
<< RESET << std::endl;
return error_response; return error_response;
} }
} }
private: private:
std::unique_ptr<httplib::Client> client_; std::unique_ptr<httplib::Client> client_;
std::unique_ptr<httplib::SSLClient> ssl_client_; std::unique_ptr<httplib::SSLClient> ssl_client_;
std::string host_; std::string host_;
@ -308,19 +331,55 @@ private:
bool interactive_mode_; bool interactive_mode_;
}; };
struct ClientArgs : public ArgsBase<ClientArgs> {
ClientArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
ClientArgs() { Init(); };
std::string host;
int port;
std::string api_key;
std::string model;
std::string prompt;
bool interactive;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(host, "host", std::string("localhost"),
"Server host (default: localhost)");
visitor(port, "port", 8080, "Server port (default: 8080)");
visitor(api_key, "api_key", std::string(""),
"Use public API with key (changes host to "
"generativelanguage.googleapis.com:443)");
visitor(model, "model", std::string("gemma3-4b"),
"Model name to use (default: gemma3-4b)");
visitor(prompt, "prompt", std::string("Hello! How are you?"),
"Prompt for generation (default: 'Hello! How are you?')");
visitor(interactive, "interactive", false,
"Start interactive chat mode (0 = no, 1 = yes)");
}
};
} // namespace gcpp
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gcpp::ClientArgs client_args(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::ClientArgs client_args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
std::cout << "\nAPI Client for gemma.cpp\n"; fprintf(stderr,
std::cout << "========================\n\n"; "\nAPI Client for gemma.cpp\n"
"========================\n\n");
client_args.Help(); client_args.Help();
std::cout << std::endl; fprintf(stderr,
std::cout << "Environment Variables:" << std::endl; "\n*Environment Variables:\n"
std::cout << " GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl; " GOOGLE_API_KEY : Automatically use public Google API if set\n");
return 0; return 0;
} }
consumed.AbortIfUnconsumed();
// Check for GOOGLE_API_KEY environment variable // Check for GOOGLE_API_KEY environment variable
const char* env_api_key = std::getenv("GOOGLE_API_KEY"); const char* env_api_key = std::getenv("GOOGLE_API_KEY");
if (env_api_key != nullptr && strlen(env_api_key) > 0) { if (env_api_key != nullptr && strlen(env_api_key) > 0) {
@ -335,11 +394,12 @@ int main(int argc, char* argv[]) {
client_args.port = 443; client_args.port = 443;
} }
std::cout << BOLD << YELLOW << "🚀 Testing API Server at " std::cout << BOLD << YELLOW << "🚀 Testing API Server at " << client_args.host
<< client_args.host << ":" << client_args.port << RESET << std::endl; << ":" << client_args.port << RESET << std::endl;
try { try {
APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model); APIClient client(client_args.host, client_args.port, client_args.api_key,
client_args.model);
if (client_args.interactive) { if (client_args.interactive) {
client.InteractiveChat(); client.InteractiveChat();
@ -347,11 +407,12 @@ int main(int argc, char* argv[]) {
client.TestListModels(); client.TestListModels();
client.TestGenerateContent(client_args.prompt, true); client.TestGenerateContent(client_args.prompt, true);
} }
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl; std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl;
std::cerr << "Make sure the API server is running:" << std::endl; std::cerr << "Make sure the API server is running:" << std::endl;
std::cerr << " ./build/gemma_api_server --tokenizer <path> --weights <path>" << std::endl; std::cerr
<< " ./build/gemma_api_server --tokenizer <path> --weights <path>"
<< std::endl;
return 1; return 1;
} }


@ -15,22 +15,19 @@
// HTTP API server for gemma.cpp with SSE support // HTTP API server for gemma.cpp with SSE support
#include <stdio.h>
#include <signal.h> #include <signal.h>
#include <stdio.h>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <string_view>
#include <vector>
#include <thread>
#include <atomic> #include <atomic>
#include <chrono> #include <chrono>
#include <sstream> #include <iostream>
#include <iomanip> #include <memory>
#include <mutex> #include <mutex>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <unordered_map> #include <unordered_map>
#include <vector>
// HTTP server library // HTTP server library
#undef CPPHTTPLIB_OPENSSL_SUPPORT #undef CPPHTTPLIB_OPENSSL_SUPPORT
@ -38,16 +35,12 @@
#include "httplib.h" #include "httplib.h"
// JSON library // JSON library
#include "nlohmann/json.hpp"
#include "compression/types.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h"
#include "gemma/tokenizer.h"
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/profiler.h" #include "nlohmann/json.hpp"
using json = nlohmann::json; using json = nlohmann::json;
@ -90,7 +83,8 @@ struct ServerState {
std::lock_guard<std::mutex> lock(sessions_mutex); std::lock_guard<std::mutex> lock(sessions_mutex);
auto& session = sessions[session_id]; auto& session = sessions[session_id];
if (!session.kv_cache) { if (!session.kv_cache) {
session.kv_cache = std::make_unique<KVCache>(gemma->Config(), InferenceArgs(), env->ctx.allocator); session.kv_cache = std::make_unique<KVCache>(
gemma->Config(), InferenceArgs(), env->ctx.allocator);
} }
session.last_access = std::chrono::steady_clock::now(); session.last_access = std::chrono::steady_clock::now();
return session; return session;
@ -107,7 +101,8 @@ std::string GenerateSessionId() {
return ss.str(); return ss.str();
} }
// Wraps messages with start_of_turn markers - handles both with and without roles // Wraps messages with start_of_turn markers - handles both with and without
// roles
std::string WrapMessagesWithTurnMarkers(const json& contents) { std::string WrapMessagesWithTurnMarkers(const json& contents) {
std::string prompt; std::string prompt;
@ -121,12 +116,14 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) {
std::string text = part["text"]; std::string text = part["text"];
if (role == "user") { if (role == "user") {
prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n"; prompt +=
"<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
} else if (role == "model") { } else if (role == "model") {
prompt += text + "\n"; prompt += text + "\n";
} else if (role.empty()) { } else if (role.empty()) {
// Local format without roles - for now, treat as user input // Local format without roles - for now, treat as user input
prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n"; prompt +=
"<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
} }
} }
} }
@ -163,18 +160,15 @@ RuntimeConfig ParseGenerationConfig(const json& request) {
return config; return config;
} }
// Unified response formatter - creates consistent format regardless of request type // Unified response formatter - creates consistent format regardless of request
json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) { // type
json CreateAPIResponse(const std::string& text,
bool is_streaming_chunk = false) {
json response = { json response = {
{"candidates", {{ {"candidates",
{"content", { {{{"content", {{"parts", {{{"text", text}}}}, {"role", "model"}}},
{"parts", {{{"text", text}}}}, {"index", 0}}}},
{"role", "model"} {"promptFeedback", {{"safetyRatings", json::array()}}}};
}},
{"index", 0}
}}},
{"promptFeedback", {{"safetyRatings", json::array()}}}
};
// Only add finishReason for non-streaming chunks // Only add finishReason for non-streaming chunks
if (!is_streaming_chunk) { if (!is_streaming_chunk) {
@ -185,7 +179,9 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false)
} }
// Handle generateContent endpoint (non-streaming) // Handle generateContent endpoint (non-streaming)
void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { void HandleGenerateContentNonStreaming(ServerState& state,
const httplib::Request& req,
httplib::Response& res) {
try { try {
json request = json::parse(req.body); json request = json::parse(req.body);
@ -199,7 +195,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
prompt = WrapMessagesWithTurnMarkers(request["contents"]); prompt = WrapMessagesWithTurnMarkers(request["contents"]);
} else { } else {
res.status = 400; res.status = 400;
res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); res.set_content(
json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
"application/json");
return; return;
} }
@ -209,12 +207,7 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
// Set up runtime config // Set up runtime config
RuntimeConfig runtime_config = ParseGenerationConfig(request); RuntimeConfig runtime_config = ParseGenerationConfig(request);
// Collect full response runtime_config.stream_token = [](int token, float) { return true; };
std::string full_response;
runtime_config.stream_token = [&full_response](int token, float) {
// Skip EOS token
return true;
};
// Tokenize prompt // Tokenize prompt
std::vector<int> tokens = WrapAndTokenize( std::vector<int> tokens = WrapAndTokenize(
@ -227,7 +220,8 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
// Temporarily redirect output to capture response // Temporarily redirect output to capture response
std::stringstream output; std::stringstream output;
runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) { runtime_config.stream_token = [&output, &state, &session, &tokens](
int token, float) {
// Skip prompt tokens // Skip prompt tokens
if (session.abs_pos < tokens.size()) { if (session.abs_pos < tokens.size()) {
session.abs_pos++; session.abs_pos++;
@ -279,7 +273,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
} }
// Handle streamGenerateContent endpoint with SSE) // Handle streamGenerateContent endpoint with SSE)
void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { void HandleGenerateContentStreaming(ServerState& state,
const httplib::Request& req,
httplib::Response& res) {
try { try {
json request = json::parse(req.body); json request = json::parse(req.body);
@ -293,7 +289,9 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
prompt = WrapMessagesWithTurnMarkers(request["contents"]); prompt = WrapMessagesWithTurnMarkers(request["contents"]);
} else { } else {
res.status = 400; res.status = 400;
res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); res.set_content(
json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
"application/json");
return; return;
} }
@ -305,88 +303,85 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
// Set up chunked content provider for SSE // Set up chunked content provider for SSE
res.set_chunked_content_provider( res.set_chunked_content_provider(
"text/event-stream", "text/event-stream", [&state, request, prompt, session_id](
[&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) { size_t offset, httplib::DataSink& sink) {
try { try {
// Lock for inference // Lock for inference
std::lock_guard<std::mutex> lock(state.inference_mutex); std::lock_guard<std::mutex> lock(state.inference_mutex);
auto& session = state.GetOrCreateSession(session_id); auto& session = state.GetOrCreateSession(session_id);
// Set up runtime config // Set up runtime config
RuntimeConfig runtime_config = ParseGenerationConfig(request); RuntimeConfig runtime_config = ParseGenerationConfig(request);
// Tokenize prompt // Tokenize prompt
std::vector<int> tokens = WrapAndTokenize( std::vector<int> tokens = WrapAndTokenize(
state.gemma->Tokenizer(), state.gemma->ChatTemplate(), state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
state.gemma->Config().wrapping, session.abs_pos, prompt); state.gemma->Config().wrapping, session.abs_pos, prompt);
// Stream token callback
std::string accumulated_text;
auto stream_token = [&](int token, float) {
// Skip prompt tokens
if (session.abs_pos < tokens.size()) {
session.abs_pos++;
return true;
}
// Stream token callback
std::string accumulated_text;
auto stream_token = [&](int token, float) {
// Skip prompt tokens
if (session.abs_pos < tokens.size()) {
session.abs_pos++; session.abs_pos++;
// Check for EOS
if (state.gemma->Config().IsEOS(token)) {
return true;
}
// Decode token
std::string token_text;
state.gemma->Tokenizer().Decode(std::vector<int>{token},
&token_text);
accumulated_text += token_text;
// Send SSE event using unified formatter
json event = CreateAPIResponse(token_text, true);
std::string sse_data = "data: " + event.dump() + "\n\n";
sink.write(sse_data.data(), sse_data.size());
return true; return true;
} };
session.abs_pos++; runtime_config.stream_token = stream_token;
// Check for EOS // Run inference with KV cache
if (state.gemma->Config().IsEOS(token)) { TimingInfo timing_info = {.verbosity = 0};
return true; size_t prefix_end = 0;
}
// Decode token state.gemma->Generate(runtime_config, tokens, session.abs_pos,
std::string token_text; prefix_end, *session.kv_cache, *state.env,
state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text); timing_info);
accumulated_text += token_text;
// Send SSE event using unified formatter // Send final event using unified formatter
json event = CreateAPIResponse(token_text, true); json final_event = CreateAPIResponse("", false);
final_event["usageMetadata"] = {
{"promptTokenCount", tokens.size()},
{"candidatesTokenCount", session.abs_pos - tokens.size()},
{"totalTokenCount", session.abs_pos}};
std::string sse_data = "data: " + event.dump() + "\n\n"; std::string final_sse = "data: " + final_event.dump() + "\n\n";
sink.write(sse_data.data(), sse_data.size()); sink.write(final_sse.data(), final_sse.size());
return true; // Send done event
}; sink.write("data: [DONE]\n\n", 15);
runtime_config.stream_token = stream_token;
// Run inference with KV cache
TimingInfo timing_info = {.verbosity = 0};
size_t prefix_end = 0;
state.gemma->Generate(runtime_config, tokens, session.abs_pos,
prefix_end, *session.kv_cache, *state.env,
timing_info);
// Send final event using unified formatter
json final_event = CreateAPIResponse("", false);
final_event["usageMetadata"] = {
{"promptTokenCount", tokens.size()},
{"candidatesTokenCount", session.abs_pos - tokens.size()},
{"totalTokenCount", session.abs_pos}
};
std::string final_sse = "data: " + final_event.dump() + "\n\n";
sink.write(final_sse.data(), final_sse.size());
// Send done event
sink.write("data: [DONE]\n\n", 15);
// Ensure all data is sent
sink.done();
return false; // End streaming
} catch (const std::exception& e) {
json error_event = {{"error", {{"message", e.what()}}}};
std::string error_sse = "data: " + error_event.dump() + "\n\n";
sink.write(error_sse.data(), error_sse.size());
return false;
}
}
);
// Ensure all data is sent
sink.done();
return false; // End streaming
} catch (const std::exception& e) {
json error_event = {{"error", {{"message", e.what()}}}};
std::string error_sse = "data: " + error_event.dump() + "\n\n";
sink.write(error_sse.data(), error_sse.size());
return false;
}
});
} catch (const json::exception& e) { } catch (const json::exception& e) {
res.status = 400; res.status = 400;
res.set_content( res.set_content(
@ -398,20 +393,20 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
} }
// Handle models list endpoint // Handle models list endpoint
void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) { void HandleListModels(ServerState& state, const InferenceArgs& inference,
const httplib::Request& req, httplib::Response& res) {
json response = { json response = {
{"models", {{ {"models",
{"name", "models/" + inference.model}, {{{"name", "models/" + inference.model},
{"version", "001"}, {"version", "001"},
{"displayName", inference.model}, {"displayName", inference.model},
{"description", inference.model + " model running locally"}, {"description", inference.model + " model running locally"},
{"inputTokenLimit", 8192}, {"inputTokenLimit", 8192},
{"outputTokenLimit", 8192}, {"outputTokenLimit", 8192},
{"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})}, {"supportedGenerationMethods",
{"temperature", 1.0}, json::array({"generateContent", "streamGenerateContent"})},
{"topK", 1} {"temperature", 1.0},
}}} {"topK", 1}}}}};
};
res.set_content(response.dump(), "application/json"); res.set_content(response.dump(), "application/json");
} }
@ -421,39 +416,45 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const
// server_running = false; // server_running = false;
// } // }
void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, void RunServer(const GemmaArgs& args) {
const InferenceArgs& inference) {
std::cerr << "Loading model..." << std::endl; std::cerr << "Loading model..." << std::endl;
// Initialize model // Initialize model
ThreadingContext ctx(threading); ThreadingContext ctx(args.threading);
MatMulEnv env(ctx); MatMulEnv env(ctx);
ServerState state; ServerState state;
state.gemma = std::make_unique<Gemma>(loader, inference, ctx); state.gemma = std::make_unique<Gemma>(args, ctx);
state.env = &env; state.env = &env;
state.ctx = &ctx; state.ctx = &ctx;
const InferenceArgs& inference = args.inference;
httplib::Server server; httplib::Server server;
// Set up routes // Set up routes
server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) { server.Get(
res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain"); "/", [&inference](const httplib::Request&, httplib::Response& res) {
}); res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" +
inference.model + ":generateContent",
"text/plain");
});
// API endpoints // API endpoints
server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) { server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req,
httplib::Response& res) {
HandleListModels(state, inference, req, res); HandleListModels(state, inference, req, res);
}); });
std::string model_endpoint = "/v1beta/models/" + inference.model; std::string model_endpoint = "/v1beta/models/" + inference.model;
server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) { server.Post(model_endpoint + ":generateContent",
HandleGenerateContentNonStreaming(state, req, res); [&state](const httplib::Request& req, httplib::Response& res) {
}); HandleGenerateContentNonStreaming(state, req, res);
});
server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) { server.Post(model_endpoint + ":streamGenerateContent",
HandleGenerateContentStreaming(state, req, res); [&state](const httplib::Request& req, httplib::Response& res) {
}); HandleGenerateContentStreaming(state, req, res);
});
// Periodic cleanup of old sessions // Periodic cleanup of old sessions
std::thread cleanup_thread([&state]() { std::thread cleanup_thread([&state]() {
@ -466,12 +467,15 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
std::cerr << "Starting API server on port " << inference.port << std::endl; std::cerr << "Starting API server on port " << inference.port << std::endl;
std::cerr << "Model loaded successfully" << std::endl; std::cerr << "Model loaded successfully" << std::endl;
std::cerr << "Endpoints:" << std::endl; std::cerr << "Endpoints:" << std::endl;
std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl; std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent"
std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl; << std::endl;
std::cerr << " POST /v1beta/models/" << inference.model
<< ":streamGenerateContent (SSE)" << std::endl;
std::cerr << " GET /v1beta/models" << std::endl; std::cerr << " GET /v1beta/models" << std::endl;
if (!server.listen("0.0.0.0", inference.port)) { if (!server.listen("0.0.0.0", inference.port)) {
std::cerr << "Failed to start server on port " << inference.port << std::endl; std::cerr << "Failed to start server on port " << inference.port
<< std::endl;
} }
cleanup_thread.join(); cleanup_thread.join();
@ -482,35 +486,27 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
int main(int argc, char** argv) { int main(int argc, char** argv) {
gcpp::InternalInit(); gcpp::InternalInit();
gcpp::LoaderArgs loader(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::ThreadingArgs threading(argc, argv); gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::InferenceArgs inference(argc, argv);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
std::cerr << "\n\nAPI server for gemma.cpp\n"; fprintf(
std::cout << "========================\n\n"; stderr,
std::cerr << "Usage: " << argv[0] << " --weights <path> --tokenizer <path> [options]\n"; "\n\nAPI server for gemma.cpp\n"
std::cerr << "\nOptions:\n"; "========================\n\n"
std::cerr << " --port PORT Server port (default: 8080)\n"; " --port PORT Server port (default: 8080)\n"
std::cerr << " --model MODEL Model name for endpoints (default: gemma3-4b)\n"; " --model MODEL Model name for endpoints (default: gemma3-4b)\n");
std::cerr << "\n"; args.Help();
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n";
return 0; return 0;
} }
// Arguments are now handled by InferenceArgs consumed.AbortIfUnconsumed();
// // Set up signal handler // // Set up signal handler
// signal(SIGINT, gcpp::HandleShutdown); // signal(SIGINT, gcpp::HandleShutdown);
// signal(SIGTERM, gcpp::HandleShutdown); // signal(SIGTERM, gcpp::HandleShutdown);
gcpp::RunServer(loader, threading, inference); gcpp::RunServer(args);
return 0; return 0;
} }


@ -73,45 +73,38 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
ThreadingArgs threading_args; ThreadingArgs threading_args;
threading_args.spin = gcpp::Tristate::kFalse; threading_args.spin = gcpp::Tristate::kFalse;
LoaderArgs loader(tokenizer_path, weights_path); threading_args.spin = gcpp::Tristate::kFalse;
LogDebug("LoaderArgs created"); GemmaArgs args(LoaderArgs(tokenizer_path, weights_path), threading_args);
// Initialize cached args // Initialize cached args
LogDebug("Initializing inference args"); LogDebug("Initializing inference args");
InferenceArgs inference_args; args.inference.max_generated_tokens = max_generated_tokens;
inference_args.Init(); args.inference.temperature = 0.7f;
inference_args.max_generated_tokens = max_generated_tokens; args.inference.top_k = 1;
inference_args.temperature = 0.7f; args.inference.deterministic = false;
inference_args.top_k = 1;
inference_args.deterministic = false;
ss.str(""); ss.str("");
ss << "Inference args initialized with max_tokens: " << max_generated_tokens ss << "Inference args initialized with max_tokens: " << max_generated_tokens
<< ", temperature: " << inference_args.temperature << ", temperature: " << args.inference.temperature
<< ", top_k: " << inference_args.top_k << ", deterministic: " << ", top_k: " << args.inference.top_k << ", deterministic: "
<< (inference_args.deterministic ? "true" : "false"); << (args.inference.deterministic ? "true" : "false");
LogDebug(ss.str().c_str()); LogDebug(ss.str().c_str());
return new GemmaContext(loader, inference_args, threading_args, return new GemmaContext(args, max_generated_tokens);
max_generated_tokens);
} }
GemmaContext::GemmaContext(const LoaderArgs& loader, GemmaContext::GemmaContext(const GemmaArgs& args, int max_generated_tokens)
const InferenceArgs& inference_args, : args(args),
const ThreadingArgs& threading_args, ctx(args.threading),
int max_generated_tokens)
: inference_args(inference_args),
threading_args(threading_args),
ctx(threading_args),
matmul_env(ctx), matmul_env(ctx),
active_conversation_name("default"), active_conversation_name("default"),
model(loader, inference_args, matmul_env.ctx) { model(args, matmul_env.ctx) {
std::stringstream ss; std::stringstream ss;
LogDebug("Creating initial ConversationData"); LogDebug("Creating initial ConversationData");
// Create the initial ConversationData object using make_shared // Create the initial ConversationData object using make_shared
active_conversation = std::make_shared<ConversationData>( active_conversation = std::make_shared<ConversationData>(
model.Config(), inference_args, ctx.allocator); model.Config(), args.inference, ctx.allocator);
LogDebug( LogDebug(
"Storing initial ConversationData in conversation_cache[\"default\"]"); "Storing initial ConversationData in conversation_cache[\"default\"]");
@ -172,8 +165,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// set up runtime config // set up runtime config
TimingInfo timing_info = {}; TimingInfo timing_info = {};
RuntimeConfig runtime_config = {.stream_token = stream_token, RuntimeConfig runtime_config = {.stream_token = stream_token,
.use_spinning = threading_args.spin}; .use_spinning = args.threading.spin};
inference_args.CopyTo(runtime_config); args.inference.CopyTo(runtime_config);
size_t prefix_end = 0; size_t prefix_end = 0;
const ModelConfig& model_config = model.Config(); const ModelConfig& model_config = model.Config();
@ -247,7 +240,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
timing_info); timing_info);
// prepare for next turn // prepare for next turn
if (!inference_args.multiturn || if (!args.inference.multiturn ||
model_config.wrapping == PromptWrapping::PALIGEMMA) { model_config.wrapping == PromptWrapping::PALIGEMMA) {
// If not multiturn, or Paligemma (which handles turns differently), // If not multiturn, or Paligemma (which handles turns differently),
// reset the *active* conversation's position. // reset the *active* conversation's position.

View File

@ -53,8 +53,7 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
class GemmaContext { class GemmaContext {
private: private:
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args, GemmaContext(const GemmaArgs& args, int max_generated_tokens);
const ThreadingArgs& threading_args, int max_generated_tokens);
public: public:
static GemmaContext* Create(const char* tokenizer_path, static GemmaContext* Create(const char* tokenizer_path,
@ -81,37 +80,37 @@ class GemmaContext {
// Set max generated tokens // Set max generated tokens
void SetMaxGeneratedTokens(size_t value) { void SetMaxGeneratedTokens(size_t value) {
inference_args.max_generated_tokens = value; args.inference.max_generated_tokens = value;
LogDebug("Setting max_generated_tokens to configured value"); LogDebug("Setting max_generated_tokens to configured value");
} }
// Set multiturn flag (0 = disabled, 1 = enabled) // Set multiturn flag (0 = disabled, 1 = enabled)
void SetMultiturn(int value) { void SetMultiturn(int value) {
inference_args.multiturn = value; args.inference.multiturn = value;
LogDebug("Setting multiturn to configured value"); LogDebug("Setting multiturn to configured value");
} }
// Set temperature for token generation // Set temperature for token generation
void SetTemperature(float value) { void SetTemperature(float value) {
inference_args.temperature = value; args.inference.temperature = value;
LogDebug("Setting temperature to configured value"); LogDebug("Setting temperature to configured value");
} }
// Set top_k parameter for sampling // Set top_k parameter for sampling
void SetTopK(int value) { void SetTopK(int value) {
inference_args.top_k = value; args.inference.top_k = value;
LogDebug("Setting top_k to configured value"); LogDebug("Setting top_k to configured value");
} }
// Set deterministic flag // Set deterministic flag
void SetDeterministic(bool value) { void SetDeterministic(bool value) {
inference_args.deterministic = value; args.inference.deterministic = value;
LogDebug("Setting deterministic flag to configured value"); LogDebug("Setting deterministic flag to configured value");
} }
// Set prefill_tbatch_size // Set prefill_tbatch_size
void SetPrefillTbatchSize(size_t value) { void SetPrefillTbatchSize(size_t value) {
inference_args.prefill_tbatch_size = value; args.inference.prefill_tbatch_size = value;
LogDebug("Setting prefill_tbatch_size to configured value"); LogDebug("Setting prefill_tbatch_size to configured value");
} }
@ -175,7 +174,7 @@ class GemmaContext {
active_conversation->abs_pos = 0; active_conversation->abs_pos = 0;
// Replace the cache within the current ConversationData object // Replace the cache within the current ConversationData object
active_conversation->kv_cache = std::make_unique<KVCache>( active_conversation->kv_cache = std::make_unique<KVCache>(
model.Config(), inference_args, ctx.allocator); model.Config(), args.inference, ctx.allocator);
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str()); LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
} else { } else {
@ -193,7 +192,7 @@ class GemmaContext {
LogDebug("Creating new conversation"); LogDebug("Creating new conversation");
// Create a new ConversationData object using make_shared // Create a new ConversationData object using make_shared
conversation_cache[name] = std::make_shared<ConversationData>( conversation_cache[name] = std::make_shared<ConversationData>(
model.Config(), inference_args, ctx.allocator); model.Config(), args.inference, ctx.allocator);
return true; return true;
} }
@ -274,8 +273,7 @@ class GemmaContext {
std::vector<int> token_buffer; std::vector<int> token_buffer;
// Cached args (remain global for the context) // Cached args (remain global for the context)
InferenceArgs inference_args; GemmaArgs args;
ThreadingArgs threading_args;
ThreadingContext ctx; ThreadingContext ctx;
MatMulEnv matmul_env; MatMulEnv matmul_env;

View File

@ -738,17 +738,16 @@ HWY_EXPORT(GenerateSingleT);
HWY_EXPORT(GenerateBatchT); HWY_EXPORT(GenerateBatchT);
HWY_EXPORT(GenerateImageTokensT); HWY_EXPORT(GenerateImageTokensT);
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
ThreadingContext& ctx) : reader_(args.loader.weights),
: reader_(loader.weights), model_(reader_, args.loader.tokenizer, args.loader.wrapping),
model_(reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config()), weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model), chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(inference), inference_(args.inference),
aes_ctr_engine_(inference.deterministic) { aes_ctr_engine_(args.inference.deterministic) {
// Negligible CPU time in the ctor body (except ReadFromBlobs). // Negligible CPU time in the ctor body (except ReadFromBlobs).
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference, weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
mat_owners_, ctx); args.inference, mat_owners_, ctx);
// Read everything into memory, or `weights_.mapped_` keeps the mapping alive. // Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
reader_.CloseFile(); reader_.CloseFile();
} }

View File

@ -130,11 +130,16 @@ struct TimingInfo {
// separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`. // separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
class Gemma { class Gemma {
public: public:
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`. // Reads weights/config/tokenizer from `BlobStore` at `args.loader.weights`.
// `ctx` is only used to read tensors and not stored. Calls to `Generate*` // `ctx` is only used to read tensors and not stored. Calls to `Generate*`
// may reference the same, or other `ThreadingContext` via `MatMulEnv`. // may reference the same, or other `ThreadingContext` via `MatMulEnv`.
Gemma(const GemmaArgs& args, ThreadingContext& ctx);
// Deprecated prior interface for backwards compatibility.
Gemma(const LoaderArgs& loader, const InferenceArgs& inference, Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
ThreadingContext& ctx); ThreadingContext& ctx)
: Gemma(GemmaArgs(loader, ThreadingArgs(), inference), ctx) {}
~Gemma(); ~Gemma();
const ModelConfig& Config() const { return model_.Config(); } const ModelConfig& Config() const { return model_.Config(); }

View File

@ -25,10 +25,11 @@
#include <string> #include <string>
#include "gemma/configs.h" #include "gemma/configs.h"
#include "io/io.h" // Path #include "io/io.h" // Path
#include "util/args.h" #include "util/args.h" // IWYU pragma: export
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
#include "util/mat.h" #include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // HWY_ABORT #include "hwy/base.h" // HWY_ABORT
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -36,7 +37,9 @@
namespace gcpp { namespace gcpp {
struct LoaderArgs : public ArgsBase<LoaderArgs> { struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } LoaderArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
LoaderArgs(const std::string& tokenizer_path, LoaderArgs(const std::string& tokenizer_path,
const std::string& weights_path) { const std::string& weights_path) {
Init(); // Init sets to defaults, so assignments must come after Init(). Init(); // Init sets to defaults, so assignments must come after Init().
@ -169,7 +172,9 @@ struct RuntimeConfig {
}; };
struct InferenceArgs : public ArgsBase<InferenceArgs> { struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } InferenceArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
InferenceArgs() { Init(); }; InferenceArgs() { Init(); };
bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); } bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }
@ -275,33 +280,35 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
} }
}; };
struct ClientArgs : public ArgsBase<ClientArgs> { // Bundles all args required to construct a `GemmaEnv` or the equivalent.
ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } struct GemmaArgs {
ClientArgs() { Init(); }; // For callers that do not parse command line args.
GemmaArgs(const LoaderArgs& loader,
const ThreadingArgs& threading = ThreadingArgs(),
const InferenceArgs& inference = InferenceArgs())
: loader(loader), threading(threading), inference(inference) {}
std::string host; GemmaArgs(int argc, char** argv, ConsumedArgs& consumed)
int port; : loader(argc, argv, consumed),
std::string api_key; threading(argc, argv, consumed),
std::string model; inference(argc, argv, consumed) {}
std::string prompt;
bool interactive;
template <class Visitor> void Help() {
void ForEach(const Visitor& visitor) { fprintf(stderr,
visitor(host, "host", std::string("localhost"), "To run with pre-2025 weights, specify --tokenizer and --weights.\n"
"Server host (default: localhost)"); "With the single-file weights format, specify just --weights.\n"
visitor(port, "port", 8080, "\n*Model Loading Arguments*\n");
"Server port (default: 8080)"); loader.Help();
visitor(api_key, "api_key", std::string(""), fprintf(stderr, "\n*Threading Arguments*\n");
"Use public API with key (changes host to " threading.Help();
"generativelanguage.googleapis.com:443)"); fprintf(stderr, "\n*Inference Arguments*\n");
visitor(model, "model", std::string("gemma3-4b"), inference.Help();
"Model name to use (default: gemma3-4b)"); fprintf(stderr, "\n");
visitor(prompt, "prompt", std::string("Hello! How are you?"),
"Prompt for generation (default: 'Hello! How are you?')");
visitor(interactive, "interactive", false,
"Start interactive chat mode (0 = no, 1 = yes)");
} }
LoaderArgs loader;
ThreadingArgs threading;
InferenceArgs inference;
}; };
} // namespace gcpp } // namespace gcpp
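A brief sketch of the two construction paths for `GemmaArgs` (the weights filename and field values are illustrative only): programmatic, where `ThreadingArgs` and `InferenceArgs` default-construct to their `Init()` values, or from argc/argv with a `ConsumedArgs` tracker so that typos abort.

// Sketch only; assumes gemma_args.h as declared above.
#include "gemma/gemma_args.h"

int main(int argc, char** argv) {
  // Programmatic: only the loader paths are required.
  gcpp::GemmaArgs from_code(gcpp::LoaderArgs("tokenizer.spm", "weights.sbs"));
  from_code.inference.max_generated_tokens = 256;

  // From the command line: each recognized flag is marked as consumed.
  gcpp::ConsumedArgs consumed(argc, argv);
  gcpp::GemmaArgs from_cli(argc, argv, consumed);
  consumed.AbortIfUnconsumed();  // aborts on e.g. --UNDEFINED
  return 0;
}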

gemma/gemma_args_test.cc Normal file
View File

@ -0,0 +1,74 @@
#include "gemma/gemma_args.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "gtest/gtest.h"
namespace gcpp {
void FillPtrs(const std::vector<std::string>& args, std::vector<char*>& ptrs) {
ptrs.reserve(args.size());
for (const std::string& arg : args) {
ptrs.push_back(const_cast<char*>(arg.data()));
}
}
static void CheckAllConsumed(const std::vector<std::string>& args) {
std::vector<char*> ptrs;
FillPtrs(args, ptrs);
const int argc = static_cast<int>(args.size());
char** argv = const_cast<char**>(ptrs.data());
ConsumedArgs consumed(argc, argv);
GemmaArgs gemma_args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
}
static void CheckUnconsumed(const std::vector<std::string>& args,
size_t expected) {
std::vector<char*> ptrs;
FillPtrs(args, ptrs);
const int argc = static_cast<int>(args.size());
char** argv = const_cast<char**>(ptrs.data());
ConsumedArgs consumed(argc, argv);
GemmaArgs gemma_args(argc, argv, consumed);
ASSERT_EQ(expected, consumed.FirstUnconsumed());
}
// Note: do not use --help here because it is not consumed; it is
// special-cased in `HasHelp`.
TEST(GemmaArgsTest, AllConsumedArgs) {
// Single arg
CheckAllConsumed({"gemma", "--weights=x"});
// Two args, both with =
CheckAllConsumed({"gemma", "--weights=x", "--verbosity=1"});
// Two args, one value passed as a separate token
CheckAllConsumed({"gemma", "--weights=x", "--verbosity", "2"});
// Two args with values
CheckAllConsumed({"gemma", "--verbosity", "2", "--deterministic=true"});
}
TEST(GemmaArgsTest, UnconsumedArgs) {
// Single unconsumed arg
CheckUnconsumed({"gemma", "--UNDEFINED"}, 1);
// Single unconsumed arg, no --
CheckUnconsumed({"gemma", "UNDEFINED"}, 1);
// Single unconsumed arg after valid arg
CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED"}, 2);
// Single unconsumed arg before valid arg
CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x"}, 1);
// Single unconsumed arg with = after valid arg
CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED=1"}, 2);
// Single unconsumed arg with = before valid arg
CheckUnconsumed({"gemma", "--UNDEFINED=false", "--weights=x"}, 1);
// Multiple unconsumed args
CheckUnconsumed({"gemma", "--UNDEFINED", "--XXX"}, 1);
// Multiple unconsumed args with valid arg between
CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x", "--XXX"}, 1);
}
} // namespace gcpp

View File

@ -89,9 +89,11 @@ std::string GetPrompt(const InferenceArgs& inference) {
} }
// The main Read-Eval-Print Loop. // The main Read-Eval-Print Loop.
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) { MatMulEnv& env) {
PROFILER_ZONE("Gen.misc"); PROFILER_ZONE("Gen.misc");
const InferenceArgs& inference = args.inference;
const int verbosity = inference.verbosity;
size_t abs_pos = 0; // across turns size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0; size_t prompt_size = 0;
@ -113,12 +115,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
HWY_ASSERT(image.ReadPPM(inference.image_file.path)); HWY_ASSERT(image.ReadPPM(inference.image_file.path));
const size_t image_size = config.vit_config.image_size; const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size); image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.verbosity = inference.verbosity, RuntimeConfig runtime_config = {.verbosity = verbosity,
.use_spinning = threading.spin}; .use_spinning = args.threading.spin};
double image_tokens_start = hwy::platform::Now(); double image_tokens_start = hwy::platform::Now();
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image, gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
image_tokens, env); image_tokens, env);
if (inference.verbosity >= 1) { if (verbosity >= 1) {
double image_tokens_duration = hwy::platform::Now() - image_tokens_start; double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
fprintf(stderr, fprintf(stderr,
"\n\n[ Timing info ] Image token generation took: %d ms\n", "\n\n[ Timing info ] Image token generation took: %d ms\n",
@ -189,7 +191,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
TimingInfo timing_info = {.verbosity = inference.verbosity}; TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.verbosity = inference.verbosity, RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
.batch_stream_token = batch_stream_token, .batch_stream_token = batch_stream_token,
.use_spinning = threading.spin}; .use_spinning = args.threading.spin};
inference.CopyTo(runtime_config); inference.CopyTo(runtime_config);
std::vector<int> prompt; std::vector<int> prompt;
size_t prefix_end = 0; size_t prefix_end = 0;
@ -252,14 +254,14 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
} }
} }
void Run(const LoaderArgs& loader, const ThreadingArgs& threading, void Run(const GemmaArgs& args) {
const InferenceArgs& inference) {
PROFILER_ZONE("Run.misc"); PROFILER_ZONE("Run.misc");
ThreadingContext ctx(threading); ThreadingContext ctx(args.threading);
MatMulEnv env(ctx); MatMulEnv env(ctx);
const InferenceArgs& inference = args.inference;
if (inference.verbosity >= 3) env.print_best = true; if (inference.verbosity >= 3) env.print_best = true;
const Gemma gemma(loader, inference, ctx); const Gemma gemma(args, ctx);
KVCache kv_cache(gemma.Config(), inference, ctx.allocator); KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
if (inference.verbosity >= 1) { if (inference.verbosity >= 1) {
@ -287,13 +289,12 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
if (inference.IsInteractive()) { if (inference.IsInteractive()) {
std::cout << "\033[2J\033[1;1H" // clear screen std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n"; << kAsciiArtBanner << "\n\n";
ShowConfig(loader, threading, inference, gemma.Config(), ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
gemma.WeightReadMode(), ctx);
std::cout << "\n" << instructions << "\n"; std::cout << "\n" << instructions << "\n";
} }
} }
ReplGemma(threading, inference, gemma, kv_cache, env); ReplGemma(args, gemma, kv_cache, env);
} }
} // namespace gcpp } // namespace gcpp
@ -302,17 +303,24 @@ int main(int argc, char** argv) {
gcpp::InternalInit(); gcpp::InternalInit();
{ {
// Negligible CPU time. // Negligible CPU time.
gcpp::LoaderArgs loader(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::ThreadingArgs threading(argc, argv); gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::InferenceArgs inference(argc, argv);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
std::cerr << gcpp::kAsciiArtBanner; std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, threading, inference); fprintf(stderr,
"\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights gemma2-2b-it-sfp.sbs\n\n");
args.Help();
return 0; return 0;
} }
gcpp::Run(loader, threading, inference); // After `HasHelp` so that we print --help even if unconsumed args remain.
consumed.AbortIfUnconsumed();
gcpp::Run(args);
} }
PROFILER_PRINT_RESULTS(); // Must call outside the zone above. PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0; return 0;

View File

@ -23,7 +23,9 @@ namespace gcpp {
namespace { namespace {
struct WriterArgs : public ArgsBase<WriterArgs> { struct WriterArgs : public ArgsBase<WriterArgs> {
WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } WriterArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path output_weights; Path output_weights;
@ -38,12 +40,15 @@ struct WriterArgs : public ArgsBase<WriterArgs> {
} // namespace gcpp } // namespace gcpp
int main(int argc, char** argv) { int main(int argc, char** argv) {
gcpp::WriterArgs args(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
if (args.output_weights.Empty()) { gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::WriterArgs writer_args(argc, argv, consumed);
if (writer_args.output_weights.Empty()) {
HWY_ABORT("Missing --output_weights flag, a file for the model weights."); HWY_ABORT("Missing --output_weights flag, a file for the model weights.");
} }
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(argc, argv); gcpp::GemmaEnv env(args);
env.GetGemma()->Save(args.output_weights, env.Env().ctx); env.GetGemma()->Save(writer_args.output_weights, env.Env().ctx);
return 0; return 0;
} }

View File

@ -21,10 +21,9 @@
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "io/io.h" #include "paligemma/paligemma_helper.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
#include "paligemma/paligemma_helper.h"
// This test can be run manually with the downloaded PaliGemma weights. // This test can be run manually with the downloaded PaliGemma weights.
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`. // It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
@ -73,7 +72,11 @@ TEST_F(PaliGemmaTest, QueryObjects) {
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
gcpp::GemmaEnv env(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
gcpp::s_env = &env; gcpp::s_env = &env;
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();

View File

@ -45,10 +45,7 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
// Wrapper around GemmaEnv to expose to Python. // Wrapper around GemmaEnv to expose to Python.
class GemmaModel { class GemmaModel {
public: public:
GemmaModel(const gcpp::LoaderArgs& loader, GemmaModel(const gcpp::GemmaArgs& args) : env_(args), last_prob_(0.0f) {}
const gcpp::ThreadingArgs& threading,
const gcpp::InferenceArgs& inference)
: env_(loader, threading, inference), last_prob_(0.0f) {}
// Generates a single example, given a prompt and a callback to stream the // Generates a single example, given a prompt and a callback to stream the
// generated tokens. // generated tokens.
@ -254,13 +251,15 @@ PYBIND11_MODULE(gemma, mod) {
py::class_<GemmaModel>(mod, "GemmaModel") py::class_<GemmaModel>(mod, "GemmaModel")
.def(py::init([](const std::string& tokenizer, const std::string& weights, .def(py::init([](const std::string& tokenizer, const std::string& weights,
size_t max_threads) { size_t max_threads) {
const gcpp::LoaderArgs loader(tokenizer, weights);
gcpp::ThreadingArgs threading; gcpp::ThreadingArgs threading;
threading.max_lps = max_threads; threading.max_lps = max_threads;
gcpp::InferenceArgs inference; gcpp::InferenceArgs inference;
inference.max_generated_tokens = 512; inference.max_generated_tokens = 512;
auto gemma =
std::make_unique<GemmaModel>(loader, threading, inference); const gcpp::GemmaArgs args(gcpp::LoaderArgs(tokenizer, weights),
threading, inference);
auto gemma = std::make_unique<GemmaModel>(args);
if (!gemma->ModelIsLoaded()) { if (!gemma->ModelIsLoaded()) {
throw std::invalid_argument("Could not load model."); throw std::invalid_argument("Could not load model.");
} }

View File

@ -22,6 +22,7 @@
#include <algorithm> // std::transform #include <algorithm> // std::transform
#include <string> #include <string>
#include <vector>
#include "io/io.h" // Path #include "io/io.h" // Path
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
@ -29,6 +30,56 @@
namespace gcpp { namespace gcpp {
// For checking which args were not matched/consumed. Passed to each `*Args`
// ctor that parses argc/argv to ensure that their args are tracked, without
// requiring global state.
class ConsumedArgs {
public:
ConsumedArgs(int argc, char** argv) : argv_(argv), consumed_(argc) {
// We assume argc >= 1, because argv[0] is the binary name. That allows us
// to signal "called AbortIfUnconsumed" with an empty vector.
HWY_ASSERT(!consumed_.empty());
}
~ConsumedArgs() {
if (HWY_UNLIKELY(!consumed_.empty())) {
HWY_ABORT("AbortIfUnconsumed was not called.");
}
}
void NotifyConsumed(size_t idx) {
HWY_ASSERT(idx < consumed_.size());
HWY_ASSERT(consumed_[idx] == 0);
consumed_[idx] = 1;
}
// Returns index of first unconsumed arg, or 0 if none. Also disarms the
// warning in the dtor checking whether this/`AbortIfUnconsumed` were called.
size_t FirstUnconsumed() {
// Ignore argv[0], which is the binary name.
for (size_t i = 1; i < consumed_.size(); ++i) {
if (HWY_UNLIKELY(consumed_[i] == 0)) {
consumed_.clear();
return i;
}
}
consumed_.clear();
return 0;
}
void AbortIfUnconsumed() {
const size_t i = FirstUnconsumed();
if (HWY_UNLIKELY(i != 0)) {
HWY_ABORT("Unrecognized arg %zu: %s\n", i, argv_[i]);
}
}
private:
char** argv_;
std::vector<uint8_t> consumed_;
};
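A minimal usage sketch of the tracker itself (`SomeArgs` stands in for any `ArgsBase`-derived struct and is hypothetical): callers must reach `FirstUnconsumed` or `AbortIfUnconsumed` before the tracker is destroyed, otherwise the destructor aborts.

// Sketch only; SomeArgs is not a real class.
#include <cstddef>
#include <cstdio>

int ParseOrWarn(int argc, char** argv) {
  gcpp::ConsumedArgs consumed(argc, argv);
  gcpp::SomeArgs some_args(argc, argv, consumed);   // marks matched argv slots
  const size_t first = consumed.FirstUnconsumed();  // 0 if everything matched
  if (first != 0) {
    fprintf(stderr, "Ignoring unrecognized arg: %s\n", argv[first]);
    return 1;
  }
  return 0;  // FirstUnconsumed also disarms the destructor check
}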
// Args is a class that provides a ForEach member function which visits each of // Args is a class that provides a ForEach member function which visits each of
// its member variables. ArgsBase provides functions called by Args to // its member variables. ArgsBase provides functions called by Args to
// initialize values to their defaults (passed as an argument to the visitor), // initialize values to their defaults (passed as an argument to the visitor),
@ -93,7 +144,8 @@ class ArgsBase {
// consider adding a hash-map to speed this up. // consider adding a hash-map to speed this up.
class ParseVisitor { class ParseVisitor {
public: public:
ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {} ParseVisitor(int argc, char* argv[], ConsumedArgs& consumed)
: argc_(argc), argv_(argv), consumed_(consumed) {}
template <typename T> template <typename T>
void operator()(T& t, const char* name, const T& /*init*/, void operator()(T& t, const char* name, const T& /*init*/,
@ -108,6 +160,8 @@ class ArgsBase {
if (!SetValue(argv_[i + 1], t)) { if (!SetValue(argv_[i + 1], t)) {
HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]); HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
} }
consumed_.NotifyConsumed(i);
consumed_.NotifyConsumed(i + 1);
return; return;
} }
if (std::string(argv_[i]).find(prefixed_eq) == 0) { if (std::string(argv_[i]).find(prefixed_eq) == 0) {
@ -115,6 +169,7 @@ class ArgsBase {
if (!SetValue(value, t)) { if (!SetValue(value, t)) {
HWY_ABORT("Invalid value for %s, got %s\n", name, value); HWY_ABORT("Invalid value for %s, got %s\n", name, value);
} }
consumed_.NotifyConsumed(i);
return; return;
} }
} }
@ -181,8 +236,9 @@ class ArgsBase {
} }
} }
int argc_; const int argc_;
char** argv_; char** const argv_;
ConsumedArgs& consumed_;
}; // ParseVisitor }; // ParseVisitor
template <class Visitor> template <class Visitor>
@ -211,15 +267,15 @@ class ArgsBase {
ForEach(visitor); ForEach(visitor);
} }
void Parse(int argc, char* argv[]) { void Parse(int argc, char* argv[], ConsumedArgs& consumed) {
ParseVisitor visitor(argc, argv); ParseVisitor visitor(argc, argv, consumed);
ForEach(visitor); ForEach(visitor);
} }
// For convenience, enables single-line constructor. // For convenience, enables single-line constructor.
void InitAndParse(int argc, char* argv[]) { void InitAndParse(int argc, char* argv[], ConsumedArgs& consumed) {
Init(); Init();
Parse(argc, argv); Parse(argc, argv, consumed);
} }
}; };
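To illustrate the pattern described above, a sketch of an `ArgsBase`-derived struct under the new interface (the `DemoArgs` name and its flags are invented; compare the removed `ClientArgs` earlier in this commit): `ForEach` declares each flag once, and `InitAndParse` now both sets defaults and reports matched argv entries through `ConsumedArgs`.

// Sketch only; DemoArgs and its flags are hypothetical.
#include <string>
#include "util/args.h"

struct DemoArgs : public gcpp::ArgsBase<DemoArgs> {
  DemoArgs(int argc, char* argv[], gcpp::ConsumedArgs& consumed) {
    InitAndParse(argc, argv, consumed);
  }

  std::string name;
  int count;

  // Visited to set defaults, to parse "--count 3" or "--count=3", and for Help.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(name, "name", std::string("demo"), "Example string flag");
    visitor(count, "count", 1, "Example integer flag");
  }
};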

View File

@ -38,7 +38,9 @@ namespace gcpp {
// Optional arguments for `ThreadingContext` from the command line. // Optional arguments for `ThreadingContext` from the command line.
class ThreadingArgs : public ArgsBase<ThreadingArgs> { class ThreadingArgs : public ArgsBase<ThreadingArgs> {
public: public:
ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } ThreadingArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
ThreadingArgs() { Init(); }; ThreadingArgs() { Init(); };
// For BoundedTopology: // For BoundedTopology: