Abort if args are unrecognized, refactor argument passing

This catches typos and incorrect usage.
Refactor: group LoaderArgs/ThreadingArgs/InferenceArgs into GemmaArgs.
All *Args constructors now take an extra ConsumedArgs& argument.
PiperOrigin-RevId: 844690553
Jan Wassenberg 2025-12-15 03:18:11 -08:00 committed by Copybara-Service
parent f50550f4ce
commit 0c64987a96
28 changed files with 713 additions and 513 deletions
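
For illustration, the new calling convention looks roughly as follows (a minimal sketch assembled from the main() changes below; the include paths and the HasHelp helper location are assumptions based on these diffs):

#include "gemma/gemma_args.h"  // GemmaArgs, ConsumedArgs
#include "util/args.h"         // HasHelp (assumed location)

int main(int argc, char** argv) {
  // Tracks which argv entries are matched by the *Args parsers below.
  gcpp::ConsumedArgs consumed(argc, argv);
  // GemmaArgs bundles LoaderArgs, ThreadingArgs and InferenceArgs; its ctor
  // parses argc/argv and marks each recognized flag as consumed.
  gcpp::GemmaArgs args(argc, argv, consumed);
  if (gcpp::HasHelp(argc, argv)) {  // --help is special-cased, not consumed
    args.Help();
    return 0;
  }
  // Aborts at the first unrecognized flag, catching typos.
  consumed.AbortIfUnconsumed();
  // ... construct GemmaEnv(args) or Gemma(args, ctx) as before ...
  return 0;
}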

View File

@ -556,12 +556,22 @@ cc_library(
":basics",
":configs",
":mat",
":threading_context",
"//io",
"@highway//:hwy",
"@highway//:profiler",
],
)
cc_test(
name = "gemma_args_test",
srcs = ["gemma/gemma_args_test.cc"],
deps = [
":gemma_args",
"@googletest//:gtest_main", # buildcleaner: keep
],
)
cc_library(
name = "gemma_lib",
srcs = [
@ -666,7 +676,6 @@ cc_library(
":gemma_args",
":gemma_lib",
":matmul_env",
":ops",
":threading_context",
":tokenizer",
"@google_benchmark//:benchmark",

View File

@ -222,6 +222,7 @@ set(GEMMA_TEST_FILES
compression/nuq_test.cc
compression/sfp_test.cc
evals/gemma_test.cc
gemma/gemma_args_test.cc
gemma/flash_attention_test.cc
gemma/tensor_info_test.cc
io/blob_store_test.cc

View File

@ -101,13 +101,11 @@ cc_test(
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":distortion",
":int",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
],
)

View File

@ -23,7 +23,9 @@ using json = nlohmann::json;
class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
public:
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
BenchmarkArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path summarize_text;
Path cross_entropy;
@ -127,9 +129,16 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::BenchmarkArgs benchmark_args(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::BenchmarkArgs benchmark_args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
args.Help();
return 0;
}
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
if (!benchmark_args.summarize_text.Empty()) {
return BenchmarkSummary(env, benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.Empty()) {

View File

@ -36,35 +36,29 @@
namespace gcpp {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
GemmaEnv::GemmaEnv(const GemmaArgs& args)
: initializer_value_(gcpp::InternalInit()),
ctx_(threading),
ctx_(args.threading),
env_(ctx_),
gemma_(loader, inference, ctx_) {
gemma_(args, ctx_) {
const ModelConfig& config = gemma_.Config();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
kv_caches_.push_back(KVCache(config, args.inference, ctx_.allocator));
if (inference.verbosity >= 2) {
ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
ctx_);
if (args.inference.verbosity >= 2) {
ShowConfig(args, config, gemma_.WeightReadMode(), ctx_);
}
if (inference.verbosity >= 3) env_.print_best = true;
if (inference.verbosity >= 4) env_.print_config = true;
if (args.inference.verbosity >= 3) env_.print_best = true;
if (args.inference.verbosity >= 4) env_.print_config = true;
runtime_config_ = {
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.verbosity = inference.verbosity,
.max_generated_tokens = args.inference.max_generated_tokens,
.temperature = args.inference.temperature,
.verbosity = args.inference.verbosity,
};
inference.CopyTo(runtime_config_);
args.inference.CopyTo(runtime_config_);
}
GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
InferenceArgs(argc, argv)) {}
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;
@ -234,19 +228,19 @@ static constexpr const char* CompiledConfig() {
}
}
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference, const ModelConfig& config,
void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
const WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx) {
threading.Print(inference.verbosity);
loader.Print(inference.verbosity);
inference.Print(inference.verbosity);
fprintf(
stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n",
config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode));
args.threading.Print(args.inference.verbosity);
args.loader.Print(args.inference.verbosity);
args.inference.Print(args.inference.verbosity);
fprintf(stderr,
"Model : %s, to_bf16 %d, mmap %d => %s\n",
config.Specifier().c_str(), static_cast<int>(args.loader.to_bf16),
static_cast<int>(args.loader.map),
WeightsPtrs::ToString(weight_read_mode));
if (inference.verbosity >= 2) {
if (args.inference.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
char cpu100[100] = "unknown";
@ -259,7 +253,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
"Instruction set : %s (%zu bits)\n"
"Compiled config : %s, profiler %d\n"
"Memory MiB : %4zu\n",
dt, cpu100, static_cast<int>(threading.bind),
dt, cpu100, static_cast<int>(args.threading.bind),
ctx.topology.TopologyString(), ctx.pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
ctx.cache_info.VectorBytes() * 8, CompiledConfig(),
@ -267,22 +261,4 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
}
}
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run with pre-2025 weights, specify --tokenizer and --weights.\n"
"With the single-file weights format, specify just --weights.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights gemma2-2b-it-sfp.sbs\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n";
}
} // namespace gcpp

View File

@ -23,7 +23,7 @@
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/gemma_args.h" // IWYU pragma: export
#include "gemma/tokenizer.h" // WrapAndTokenize
#include "ops/matmul.h"
#include "util/threading_context.h"
@ -50,10 +50,8 @@ struct QueryResultAndMetrics {
// Convenience class to load a model and run inference.
class GemmaEnv {
public:
// Calls the other constructor with *Args arguments initialized from argv.
GemmaEnv(int argc, char** argv);
GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
explicit GemmaEnv(const GemmaArgs& args);
MatMulEnv& Env() { return env_; }
size_t MaxGeneratedTokens() const {
@ -137,12 +135,9 @@ class GemmaEnv {
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference, const ModelConfig& config,
void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx);
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
} // namespace gcpp

View File

@ -98,7 +98,11 @@ BENCHMARK(BM_coding_prompt)
->UseRealTime();
int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
env.SetMaxGeneratedTokens(256);
gcpp::s_env = &env;

View File

@ -31,7 +31,9 @@ namespace gcpp {
class PromptArgs : public ArgsBase<PromptArgs> {
public:
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
PromptArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path layers_output; // optional
std::string prompt;
@ -51,11 +53,15 @@ class PromptArgs : public ArgsBase<PromptArgs> {
};
int Run(int argc, char** argv) {
PromptArgs prompt_args(argc, argv);
ConsumedArgs consumed(argc, argv);
const GemmaArgs args(argc, argv, consumed);
const PromptArgs prompt_args(argc, argv, consumed);
AbortIfInvalidArgs(prompt_args);
consumed.AbortIfUnconsumed();
json json_output;
GemmaEnv env(argc, argv);
GemmaEnv env(args);
env.MutableConfig().layers_output =
prompt_args.layers_output.Empty()
? LayersOutputFunc()

View File

@ -146,7 +146,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
int main(int argc, char** argv) {
fprintf(stderr, "GemmaEnv setup..\n");
gcpp::GemmaEnv env(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
gcpp::s_env = &env;
testing::InitGoogleTest(&argc, argv);

View File

@ -22,7 +22,6 @@
#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
#include "io/io.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
@ -42,7 +41,11 @@ class GemmaTest : public ::testing::Test {
// Requires argc/argv, hence do not use `SetUpTestSuite`.
static void InitEnv(int argc, char** argv) {
HWY_ASSERT(s_env == nullptr); // Should only be called once.
s_env = new GemmaEnv(argc, argv);
ConsumedArgs consumed(argc, argv);
GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
s_env = new GemmaEnv(args);
const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
}

View File

@ -31,7 +31,9 @@
namespace gcpp {
struct JsonArgs : public ArgsBase<JsonArgs> {
JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
JsonArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path input;
@ -151,10 +153,14 @@ void Run(GemmaEnv& env, JsonArgs& json) {
int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.all");
gcpp::GemmaEnv env(argc, argv);
gcpp::JsonArgs json(argc, argv);
gcpp::AbortIfInvalidArgs(json);
gcpp::Run(env, json);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::JsonArgs json_args(argc, argv, consumed);
gcpp::AbortIfInvalidArgs(json_args);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
gcpp::Run(env, json_args);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;

View File

@ -24,20 +24,20 @@
#include <vector>
#include "gemma/gemma.h"
#include "gemma/gemma_args.h" // LoaderArgs
#include "gemma/gemma_args.h" // GemmaArgs
#include "gemma/tokenizer.h"
#include "util/args.h"
#include "util/threading_context.h"
#include "hwy/base.h"
int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
loader.Help();
args.Help();
return 0;
}
consumed.AbortIfUnconsumed();
// Demonstrate constrained decoding by never outputting certain tokens.
std::set<int> reject_tokens;
@ -49,10 +49,10 @@ int main(int argc, char** argv) {
}
// Instantiate model and KV Cache
gcpp::ThreadingContext ctx(threading);
gcpp::ThreadingContext ctx(args.threading);
gcpp::MatMulEnv env(ctx);
gcpp::Gemma gemma(loader, inference, ctx);
gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
gcpp::Gemma gemma(args, ctx);
gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator);
size_t generated = 0;
// Tokenize instructions.

View File

@ -23,7 +23,7 @@
#include <vector>
#include "third_party/gemma_cpp/gemma/gemma.h"
#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs
#include "third_party/gemma_cpp/gemma/gemma_args.h" // GemmaArgs
#include "third_party/gemma_cpp/gemma/tokenizer.h"
#include "third_party/gemma_cpp/ops/matmul.h"
#include "third_party/gemma_cpp/util/threading_context.h"
@ -31,18 +31,11 @@
class SimplifiedGemma {
public:
SimplifiedGemma(const gcpp::LoaderArgs& loader,
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
: ctx_(threading),
SimplifiedGemma(const gcpp::GemmaArgs& args)
: ctx_(args.threading),
env_(ctx_),
gemma_(loader, inference, ctx_),
kv_cache_(gemma_.Config(), inference, ctx_.allocator) {}
SimplifiedGemma(int argc, char** argv)
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
gcpp::ThreadingArgs(argc, argv),
gcpp::InferenceArgs(argc, argv)) {}
gemma_(args, ctx_),
kv_cache_(gemma_.Config(), args.inference, ctx_.allocator) {}
void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
float temperature = 0.7,

View File

@ -18,26 +18,16 @@
#include <string>
#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
#include "gemma/gemma_args.h" // LoaderArgs
#include "gemma/gemma_args.h"
int main(int argc, char** argv) {
// Standard usage: LoaderArgs takes argc and argv as input, then parses
// necessary flags.
gcpp::LoaderArgs loader(argc, argv);
// Sets arguments from argc and argv. Note that you can instead pass in
// LoaderArgs, ThreadingArgs, and InferenceArgs directly.
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
// Optional: LoaderArgs can also take tokenizer and weights paths directly.
//
// gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
// "model_identifier");
// Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
// specified, default values will be used.
//
// gcpp::InferenceArgs inference(argc, argv);
// gcpp::ThreadingArgs threading(argc, argv);
// SimplifiedGemma gemma(loader, threading, inference);
SimplifiedGemma gemma(loader);
SimplifiedGemma gemma(args);
std::string prompt = "Write a greeting to the world.";
gemma.Generate(prompt, 256, 0.6);

View File

@ -15,18 +15,22 @@
// Test client for API server
#include <iostream>
#include <string>
#include <sstream>
#include <stdio.h>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include "httplib.h"
#include "nlohmann/json.hpp"
#include "gemma/gemma_args.h"
#include "nlohmann/json.hpp"
using json = nlohmann::json;
namespace gcpp {
// ANSI color codes
const std::string RESET = "\033[0m";
const std::string BOLD = "\033[1m";
@ -38,8 +42,14 @@ const std::string RED = "\033[31m";
class APIClient {
public:
APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b")
: host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) {
APIClient(const std::string& host, int port, const std::string& api_key = "",
const std::string& model = "gemma3-4b")
: host_(host),
port_(port),
api_key_(api_key),
model_(model),
use_https_(port == 443),
interactive_mode_(false) {
if (use_https_) {
ssl_client_ = std::make_unique<httplib::SSLClient>(host, port);
ssl_client_->set_read_timeout(60, 0);
@ -58,7 +68,9 @@ public:
std::string endpoint;
if (is_public_api) {
endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
endpoint =
stream
? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
: "/v1beta/models/gemini-2.0-flash:generateContent";
} else {
endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent"
@ -67,7 +79,8 @@ public:
// Only show verbose output in non-interactive mode
if (!interactive_mode_) {
std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
std::cout << "\n"
<< BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
std::cout << "Request: " << request.dump(2) << std::endl;
}
@ -83,18 +96,21 @@ public:
json response = ProcessRequest(request, stream);
if (response.contains("error")) {
std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET
<< std::endl;
}
}
void TestListModels() {
std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
std::cout << "\n"
<< BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
httplib::Headers headers;
if (!api_key_.empty()) {
headers.emplace("X-goog-api-key", api_key_);
}
auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers);
auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers)
: client_->Get("/v1beta/models", headers);
if (res && res->status == 200) {
json response = json::parse(res->body);
@ -106,7 +122,9 @@ public:
}
void InteractiveChat() {
std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl;
std::cout << "\n"
<< BOLD << CYAN << "💬 Interactive Chat Mode (with session)"
<< RESET << std::endl;
std::cout << "Type ':gemma %q' to end.\n" << std::endl;
interactive_mode_ = true;
@ -141,14 +159,16 @@ public:
if (response.contains("candidates") && !response["candidates"].empty()) {
auto& candidate = response["candidates"][0];
if (candidate.contains("content") && candidate["content"].contains("parts")) {
if (candidate.contains("content") &&
candidate["content"].contains("parts")) {
for (const auto& part : candidate["content"]["parts"]) {
if (part.contains("text")) {
std::string assistant_response = part["text"].get<std::string>();
// For streaming, the response is already displayed in real-time
// Just add to message history for context
json assistant_message = {{"parts", {{{"text", assistant_response}}}}};
json assistant_message = {
{"parts", {{{"text", assistant_response}}}}};
if (!api_key_.empty()) {
assistant_message["role"] = "model";
}
@ -157,7 +177,8 @@ public:
}
}
} else if (response.contains("error")) {
std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
std::cerr << RED << "❌ Error: " << response["error"]["message"]
<< RESET << std::endl;
}
std::cout << std::endl;
@ -165,14 +186,11 @@ public:
}
private:
json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) {
json CreateAPIRequest(const std::string& prompt,
const json& messages = json::array()) {
json request = {
{"generationConfig", {
{"temperature", 0.9},
{"topK", 1},
{"maxOutputTokens", 1024}
}}
};
{"generationConfig",
{{"temperature", 0.9}, {"topK", 1}, {"maxOutputTokens", 1024}}}};
if (messages.empty()) {
// Single prompt
@ -189,38 +207,42 @@ private:
return request;
}
json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) {
json ProcessNonStreamingRequest(const json& request,
const std::string& endpoint) {
httplib::Headers headers = {{"Content-Type", "application/json"}};
if (!api_key_.empty()) {
headers.emplace("X-goog-api-key", api_key_);
}
auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json")
: client_->Post(endpoint, headers, request.dump(), "application/json");
auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(),
"application/json")
: client_->Post(endpoint, headers, request.dump(),
"application/json");
if (res && res->status == 200) {
json response = json::parse(res->body);
if (!interactive_mode_) {
std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl;
std::cout << "\n"
<< BOLD << GREEN << "📥 Response:" << RESET << std::endl;
std::cout << response.dump(2) << std::endl;
}
return response;
} else {
json error_response = {
{"error", {
{"message", "Request failed"},
{"status", res ? res->status : -1}
}}
};
json error_response = {{"error",
{{"message", "Request failed"},
{"status", res ? res->status : -1}}}};
if (res && !res->body.empty()) {
error_response["error"]["details"] = res->body;
}
std::cerr << RED << "❌ Request failed. Status: " << (res ? res->status : -1) << RESET << std::endl;
std::cerr << RED
<< "❌ Request failed. Status: " << (res ? res->status : -1)
<< RESET << std::endl;
return error_response;
}
}
json ProcessStreamingRequest(const json& request, const std::string& endpoint) {
json ProcessStreamingRequest(const json& request,
const std::string& endpoint) {
std::string accumulated_response;
// Use same SSE logic for both public and local APIs
@ -233,7 +255,9 @@ private:
}
req.body = request.dump();
req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool {
req.content_receiver = [&accumulated_response, this](
const char* data, size_t data_length,
uint64_t offset, uint64_t total_length) -> bool {
std::string chunk(data, data_length);
std::istringstream stream(chunk);
std::string line;
@ -244,14 +268,18 @@ private:
if (event_data == "[DONE]") {
if (!interactive_mode_) {
std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl;
std::cout << "\n\n"
<< GREEN << "✅ Generation complete!" << RESET
<< std::endl;
}
} else {
try {
json event = json::parse(event_data);
if (event.contains("candidates") && !event["candidates"].empty()) {
if (event.contains("candidates") &&
!event["candidates"].empty()) {
auto& candidate = event["candidates"][0];
if (candidate.contains("content") && candidate["content"].contains("parts")) {
if (candidate.contains("content") &&
candidate["content"].contains("parts")) {
for (const auto& part : candidate["content"]["parts"]) {
if (part.contains("text")) {
std::string text = part["text"].get<std::string>();
@ -272,27 +300,22 @@ private:
httplib::Response res;
httplib::Error error;
bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error);
bool success = use_https_ ? ssl_client_->send(req, res, error)
: client_->send(req, res, error);
if (res.status == 200 && !accumulated_response.empty()) {
return json{
{"candidates", {{
{"content", {
{"parts", {{{"text", accumulated_response}}}}
}}
}}}
};
{"candidates",
{{{"content", {{"parts", {{{"text", accumulated_response}}}}}}}}}};
} else {
json error_response = {
{"error", {
{"message", "Streaming request failed"},
{"status", res.status}
}}
};
{"error",
{{"message", "Streaming request failed"}, {"status", res.status}}}};
if (!res.body.empty()) {
error_response["error"]["details"] = res.body;
}
std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl;
std::cerr << RED << "❌ Streaming request failed. Status: " << res.status
<< RESET << std::endl;
return error_response;
}
}
@ -308,19 +331,55 @@ private:
bool interactive_mode_;
};
struct ClientArgs : public ArgsBase<ClientArgs> {
ClientArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
ClientArgs() { Init(); };
std::string host;
int port;
std::string api_key;
std::string model;
std::string prompt;
bool interactive;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(host, "host", std::string("localhost"),
"Server host (default: localhost)");
visitor(port, "port", 8080, "Server port (default: 8080)");
visitor(api_key, "api_key", std::string(""),
"Use public API with key (changes host to "
"generativelanguage.googleapis.com:443)");
visitor(model, "model", std::string("gemma3-4b"),
"Model name to use (default: gemma3-4b)");
visitor(prompt, "prompt", std::string("Hello! How are you?"),
"Prompt for generation (default: 'Hello! How are you?')");
visitor(interactive, "interactive", false,
"Start interactive chat mode (0 = no, 1 = yes)");
}
};
} // namespace gcpp
int main(int argc, char* argv[]) {
gcpp::ClientArgs client_args(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::ClientArgs client_args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
std::cout << "\nAPI Client for gemma.cpp\n";
std::cout << "========================\n\n";
fprintf(stderr,
"\nAPI Client for gemma.cpp\n"
"========================\n\n");
client_args.Help();
std::cout << std::endl;
std::cout << "Environment Variables:" << std::endl;
std::cout << " GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl;
fprintf(stderr,
"\n*Environment Variables:\n"
" GOOGLE_API_KEY : Automatically use public Google API if set\n");
return 0;
}
consumed.AbortIfUnconsumed();
// Check for GOOGLE_API_KEY environment variable
const char* env_api_key = std::getenv("GOOGLE_API_KEY");
if (env_api_key != nullptr && strlen(env_api_key) > 0) {
@ -335,11 +394,12 @@ int main(int argc, char* argv[]) {
client_args.port = 443;
}
std::cout << BOLD << YELLOW << "🚀 Testing API Server at "
<< client_args.host << ":" << client_args.port << RESET << std::endl;
std::cout << BOLD << YELLOW << "🚀 Testing API Server at " << client_args.host
<< ":" << client_args.port << RESET << std::endl;
try {
APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model);
APIClient client(client_args.host, client_args.port, client_args.api_key,
client_args.model);
if (client_args.interactive) {
client.InteractiveChat();
@ -347,11 +407,12 @@ int main(int argc, char* argv[]) {
client.TestListModels();
client.TestGenerateContent(client_args.prompt, true);
}
} catch (const std::exception& e) {
std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl;
std::cerr << "Make sure the API server is running:" << std::endl;
std::cerr << " ./build/gemma_api_server --tokenizer <path> --weights <path>" << std::endl;
std::cerr
<< " ./build/gemma_api_server --tokenizer <path> --weights <path>"
<< std::endl;
return 1;
}

View File

@ -15,22 +15,19 @@
// HTTP API server for gemma.cpp with SSE support
#include <stdio.h>
#include <signal.h>
#include <stdio.h>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <string_view>
#include <vector>
#include <thread>
#include <atomic>
#include <chrono>
#include <sstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <vector>
// HTTP server library
#undef CPPHTTPLIB_OPENSSL_SUPPORT
@ -38,16 +35,12 @@
#include "httplib.h"
// JSON library
#include "nlohmann/json.hpp"
#include "compression/types.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/tokenizer.h"
#include "ops/matmul.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/profiler.h"
#include "nlohmann/json.hpp"
using json = nlohmann::json;
@ -90,7 +83,8 @@ struct ServerState {
std::lock_guard<std::mutex> lock(sessions_mutex);
auto& session = sessions[session_id];
if (!session.kv_cache) {
session.kv_cache = std::make_unique<KVCache>(gemma->Config(), InferenceArgs(), env->ctx.allocator);
session.kv_cache = std::make_unique<KVCache>(
gemma->Config(), InferenceArgs(), env->ctx.allocator);
}
session.last_access = std::chrono::steady_clock::now();
return session;
@ -107,7 +101,8 @@ std::string GenerateSessionId() {
return ss.str();
}
// Wraps messages with start_of_turn markers - handles both with and without roles
// Wraps messages with start_of_turn markers - handles both with and without
// roles
std::string WrapMessagesWithTurnMarkers(const json& contents) {
std::string prompt;
@ -121,12 +116,14 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) {
std::string text = part["text"];
if (role == "user") {
prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
prompt +=
"<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
} else if (role == "model") {
prompt += text + "\n";
} else if (role.empty()) {
// Local format without roles - for now, treat as user input
prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
prompt +=
"<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
}
}
}
@ -163,18 +160,15 @@ RuntimeConfig ParseGenerationConfig(const json& request) {
return config;
}
// Unified response formatter - creates consistent format regardless of request type
json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) {
// Unified response formatter - creates consistent format regardless of request
// type
json CreateAPIResponse(const std::string& text,
bool is_streaming_chunk = false) {
json response = {
{"candidates", {{
{"content", {
{"parts", {{{"text", text}}}},
{"role", "model"}
}},
{"index", 0}
}}},
{"promptFeedback", {{"safetyRatings", json::array()}}}
};
{"candidates",
{{{"content", {{"parts", {{{"text", text}}}}, {"role", "model"}}},
{"index", 0}}}},
{"promptFeedback", {{"safetyRatings", json::array()}}}};
// Only add finishReason for non-streaming chunks
if (!is_streaming_chunk) {
@ -185,7 +179,9 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false)
}
// Handle generateContent endpoint (non-streaming)
void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
void HandleGenerateContentNonStreaming(ServerState& state,
const httplib::Request& req,
httplib::Response& res) {
try {
json request = json::parse(req.body);
@ -199,7 +195,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
prompt = WrapMessagesWithTurnMarkers(request["contents"]);
} else {
res.status = 400;
res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
res.set_content(
json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
"application/json");
return;
}
@ -209,12 +207,7 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
// Set up runtime config
RuntimeConfig runtime_config = ParseGenerationConfig(request);
// Collect full response
std::string full_response;
runtime_config.stream_token = [&full_response](int token, float) {
// Skip EOS token
return true;
};
runtime_config.stream_token = [](int token, float) { return true; };
// Tokenize prompt
std::vector<int> tokens = WrapAndTokenize(
@ -227,7 +220,8 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
// Temporarily redirect output to capture response
std::stringstream output;
runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) {
runtime_config.stream_token = [&output, &state, &session, &tokens](
int token, float) {
// Skip prompt tokens
if (session.abs_pos < tokens.size()) {
session.abs_pos++;
@ -279,7 +273,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
}
// Handle streamGenerateContent endpoint (with SSE)
void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
void HandleGenerateContentStreaming(ServerState& state,
const httplib::Request& req,
httplib::Response& res) {
try {
json request = json::parse(req.body);
@ -293,7 +289,9 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
prompt = WrapMessagesWithTurnMarkers(request["contents"]);
} else {
res.status = 400;
res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
res.set_content(
json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
"application/json");
return;
}
@ -305,8 +303,8 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
// Set up chunked content provider for SSE
res.set_chunked_content_provider(
"text/event-stream",
[&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) {
"text/event-stream", [&state, request, prompt, session_id](
size_t offset, httplib::DataSink& sink) {
try {
// Lock for inference
std::lock_guard<std::mutex> lock(state.inference_mutex);
@ -338,7 +336,8 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
// Decode token
std::string token_text;
state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
state.gemma->Tokenizer().Decode(std::vector<int>{token},
&token_text);
accumulated_text += token_text;
// Send SSE event using unified formatter
@ -365,8 +364,7 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
final_event["usageMetadata"] = {
{"promptTokenCount", tokens.size()},
{"candidatesTokenCount", session.abs_pos - tokens.size()},
{"totalTokenCount", session.abs_pos}
};
{"totalTokenCount", session.abs_pos}};
std::string final_sse = "data: " + final_event.dump() + "\n\n";
sink.write(final_sse.data(), final_sse.size());
@ -377,16 +375,13 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
// Ensure all data is sent
sink.done();
return false; // End streaming
} catch (const std::exception& e) {
json error_event = {{"error", {{"message", e.what()}}}};
std::string error_sse = "data: " + error_event.dump() + "\n\n";
sink.write(error_sse.data(), error_sse.size());
return false;
}
}
);
});
} catch (const json::exception& e) {
res.status = 400;
res.set_content(
@ -398,20 +393,20 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
}
// Handle models list endpoint
void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) {
void HandleListModels(ServerState& state, const InferenceArgs& inference,
const httplib::Request& req, httplib::Response& res) {
json response = {
{"models", {{
{"name", "models/" + inference.model},
{"models",
{{{"name", "models/" + inference.model},
{"version", "001"},
{"displayName", inference.model},
{"description", inference.model + " model running locally"},
{"inputTokenLimit", 8192},
{"outputTokenLimit", 8192},
{"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})},
{"supportedGenerationMethods",
json::array({"generateContent", "streamGenerateContent"})},
{"temperature", 1.0},
{"topK", 1}
}}}
};
{"topK", 1}}}}};
res.set_content(response.dump(), "application/json");
}
@ -421,37 +416,43 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const
// server_running = false;
// }
void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
void RunServer(const GemmaArgs& args) {
std::cerr << "Loading model..." << std::endl;
// Initialize model
ThreadingContext ctx(threading);
ThreadingContext ctx(args.threading);
MatMulEnv env(ctx);
ServerState state;
state.gemma = std::make_unique<Gemma>(loader, inference, ctx);
state.gemma = std::make_unique<Gemma>(args, ctx);
state.env = &env;
state.ctx = &ctx;
const InferenceArgs& inference = args.inference;
httplib::Server server;
// Set up routes
server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) {
res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain");
server.Get(
"/", [&inference](const httplib::Request&, httplib::Response& res) {
res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" +
inference.model + ":generateContent",
"text/plain");
});
// API endpoints
server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) {
server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req,
httplib::Response& res) {
HandleListModels(state, inference, req, res);
});
std::string model_endpoint = "/v1beta/models/" + inference.model;
server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) {
server.Post(model_endpoint + ":generateContent",
[&state](const httplib::Request& req, httplib::Response& res) {
HandleGenerateContentNonStreaming(state, req, res);
});
server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) {
server.Post(model_endpoint + ":streamGenerateContent",
[&state](const httplib::Request& req, httplib::Response& res) {
HandleGenerateContentStreaming(state, req, res);
});
@ -466,12 +467,15 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
std::cerr << "Starting API server on port " << inference.port << std::endl;
std::cerr << "Model loaded successfully" << std::endl;
std::cerr << "Endpoints:" << std::endl;
std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl;
std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl;
std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent"
<< std::endl;
std::cerr << " POST /v1beta/models/" << inference.model
<< ":streamGenerateContent (SSE)" << std::endl;
std::cerr << " GET /v1beta/models" << std::endl;
if (!server.listen("0.0.0.0", inference.port)) {
std::cerr << "Failed to start server on port " << inference.port << std::endl;
std::cerr << "Failed to start server on port " << inference.port
<< std::endl;
}
cleanup_thread.join();
@ -482,35 +486,27 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
int main(int argc, char** argv) {
gcpp::InternalInit();
gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
std::cerr << "\n\nAPI server for gemma.cpp\n";
std::cout << "========================\n\n";
std::cerr << "Usage: " << argv[0] << " --weights <path> --tokenizer <path> [options]\n";
std::cerr << "\nOptions:\n";
std::cerr << " --port PORT Server port (default: 8080)\n";
std::cerr << " --model MODEL Model name for endpoints (default: gemma3-4b)\n";
std::cerr << "\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n";
fprintf(
stderr,
"\n\nAPI server for gemma.cpp\n"
"========================\n\n"
" --port PORT Server port (default: 8080)\n"
" --model MODEL Model name for endpoints (default: gemma3-4b)\n");
args.Help();
return 0;
}
// Arguments are now handled by InferenceArgs
consumed.AbortIfUnconsumed();
// // Set up signal handler
// signal(SIGINT, gcpp::HandleShutdown);
// signal(SIGTERM, gcpp::HandleShutdown);
gcpp::RunServer(loader, threading, inference);
gcpp::RunServer(args);
return 0;
}

View File

@ -73,45 +73,38 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
ThreadingArgs threading_args;
threading_args.spin = gcpp::Tristate::kFalse;
LoaderArgs loader(tokenizer_path, weights_path);
LogDebug("LoaderArgs created");
threading_args.spin = gcpp::Tristate::kFalse;
GemmaArgs args(LoaderArgs(tokenizer_path, weights_path), threading_args);
// Initialize cached args
LogDebug("Initializing inference args");
InferenceArgs inference_args;
inference_args.Init();
inference_args.max_generated_tokens = max_generated_tokens;
inference_args.temperature = 0.7f;
inference_args.top_k = 1;
inference_args.deterministic = false;
args.inference.max_generated_tokens = max_generated_tokens;
args.inference.temperature = 0.7f;
args.inference.top_k = 1;
args.inference.deterministic = false;
ss.str("");
ss << "Inference args initialized with max_tokens: " << max_generated_tokens
<< ", temperature: " << inference_args.temperature
<< ", top_k: " << inference_args.top_k << ", deterministic: "
<< (inference_args.deterministic ? "true" : "false");
<< ", temperature: " << args.inference.temperature
<< ", top_k: " << args.inference.top_k << ", deterministic: "
<< (args.inference.deterministic ? "true" : "false");
LogDebug(ss.str().c_str());
return new GemmaContext(loader, inference_args, threading_args,
max_generated_tokens);
return new GemmaContext(args, max_generated_tokens);
}
GemmaContext::GemmaContext(const LoaderArgs& loader,
const InferenceArgs& inference_args,
const ThreadingArgs& threading_args,
int max_generated_tokens)
: inference_args(inference_args),
threading_args(threading_args),
ctx(threading_args),
GemmaContext::GemmaContext(const GemmaArgs& args, int max_generated_tokens)
: args(args),
ctx(args.threading),
matmul_env(ctx),
active_conversation_name("default"),
model(loader, inference_args, matmul_env.ctx) {
model(args, matmul_env.ctx) {
std::stringstream ss;
LogDebug("Creating initial ConversationData");
// Create the initial ConversationData object using make_shared
active_conversation = std::make_shared<ConversationData>(
model.Config(), inference_args, ctx.allocator);
model.Config(), args.inference, ctx.allocator);
LogDebug(
"Storing initial ConversationData in conversation_cache[\"default\"]");
@ -172,8 +165,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// set up runtime config
TimingInfo timing_info = {};
RuntimeConfig runtime_config = {.stream_token = stream_token,
.use_spinning = threading_args.spin};
inference_args.CopyTo(runtime_config);
.use_spinning = args.threading.spin};
args.inference.CopyTo(runtime_config);
size_t prefix_end = 0;
const ModelConfig& model_config = model.Config();
@ -247,7 +240,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
timing_info);
// prepare for next turn
if (!inference_args.multiturn ||
if (!args.inference.multiturn ||
model_config.wrapping == PromptWrapping::PALIGEMMA) {
// If not multiturn, or Paligemma (which handles turns differently),
// reset the *active* conversation's position.

View File

@ -53,8 +53,7 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
class GemmaContext {
private:
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
const ThreadingArgs& threading_args, int max_generated_tokens);
GemmaContext(const GemmaArgs& args, int max_generated_tokens);
public:
static GemmaContext* Create(const char* tokenizer_path,
@ -81,37 +80,37 @@ class GemmaContext {
// Set max generated tokens
void SetMaxGeneratedTokens(size_t value) {
inference_args.max_generated_tokens = value;
args.inference.max_generated_tokens = value;
LogDebug("Setting max_generated_tokens to configured value");
}
// Set multiturn flag (0 = disabled, 1 = enabled)
void SetMultiturn(int value) {
inference_args.multiturn = value;
args.inference.multiturn = value;
LogDebug("Setting multiturn to configured value");
}
// Set temperature for token generation
void SetTemperature(float value) {
inference_args.temperature = value;
args.inference.temperature = value;
LogDebug("Setting temperature to configured value");
}
// Set top_k parameter for sampling
void SetTopK(int value) {
inference_args.top_k = value;
args.inference.top_k = value;
LogDebug("Setting top_k to configured value");
}
// Set deterministic flag
void SetDeterministic(bool value) {
inference_args.deterministic = value;
args.inference.deterministic = value;
LogDebug("Setting deterministic flag to configured value");
}
// Set prefill_tbatch_size
void SetPrefillTbatchSize(size_t value) {
inference_args.prefill_tbatch_size = value;
args.inference.prefill_tbatch_size = value;
LogDebug("Setting prefill_tbatch_size to configured value");
}
@ -175,7 +174,7 @@ class GemmaContext {
active_conversation->abs_pos = 0;
// Replace the cache within the current ConversationData object
active_conversation->kv_cache = std::make_unique<KVCache>(
model.Config(), inference_args, ctx.allocator);
model.Config(), args.inference, ctx.allocator);
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
} else {
@ -193,7 +192,7 @@ class GemmaContext {
LogDebug("Creating new conversation");
// Create a new ConversationData object using make_shared
conversation_cache[name] = std::make_shared<ConversationData>(
model.Config(), inference_args, ctx.allocator);
model.Config(), args.inference, ctx.allocator);
return true;
}
@ -274,8 +273,7 @@ class GemmaContext {
std::vector<int> token_buffer;
// Cached args (remain global for the context)
InferenceArgs inference_args;
ThreadingArgs threading_args;
GemmaArgs args;
ThreadingContext ctx;
MatMulEnv matmul_env;

View File

@ -738,17 +738,16 @@ HWY_EXPORT(GenerateSingleT);
HWY_EXPORT(GenerateBatchT);
HWY_EXPORT(GenerateImageTokensT);
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
ThreadingContext& ctx)
: reader_(loader.weights),
model_(reader_, loader.tokenizer, loader.wrapping),
Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
: reader_(args.loader.weights),
model_(reader_, args.loader.tokenizer, args.loader.wrapping),
weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(inference),
aes_ctr_engine_(inference.deterministic) {
inference_(args.inference),
aes_ctr_engine_(args.inference.deterministic) {
// Negligible CPU time in the ctor body (except ReadFromBlobs).
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
mat_owners_, ctx);
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
args.inference, mat_owners_, ctx);
// Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
reader_.CloseFile();
}

View File

@ -130,11 +130,16 @@ struct TimingInfo {
// separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
class Gemma {
public:
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
// Reads weights/config/tokenizer from `BlobStore` at `args.loader.weights`.
// `ctx` is only used to read tensors and not stored. Calls to `Generate*`
// may reference the same, or other `ThreadingContext` via `MatMulEnv`.
Gemma(const GemmaArgs& args, ThreadingContext& ctx);
// Deprecated prior interface for backwards compatibility.
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
ThreadingContext& ctx);
ThreadingContext& ctx)
: Gemma(GemmaArgs(loader, ThreadingArgs(), inference), ctx) {}
~Gemma();
const ModelConfig& Config() const { return model_.Config(); }

View File

@ -26,9 +26,10 @@
#include "gemma/configs.h"
#include "io/io.h" // Path
#include "util/args.h"
#include "util/args.h" // IWYU pragma: export
#include "util/basics.h" // Tristate
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // HWY_ABORT
#include "hwy/profiler.h"
@ -36,7 +37,9 @@
namespace gcpp {
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
LoaderArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
LoaderArgs(const std::string& tokenizer_path,
const std::string& weights_path) {
Init(); // Init sets to defaults, so assignments must come after Init().
@ -169,7 +172,9 @@ struct RuntimeConfig {
};
struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
InferenceArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
InferenceArgs() { Init(); };
bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }
@ -275,33 +280,35 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
}
};
struct ClientArgs : public ArgsBase<ClientArgs> {
ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
ClientArgs() { Init(); };
// Bundles all args required to construct a `GemmaEnv` or the equivalent.
struct GemmaArgs {
// For callers that do not parse command line args.
GemmaArgs(const LoaderArgs& loader,
const ThreadingArgs& threading = ThreadingArgs(),
const InferenceArgs& inference = InferenceArgs())
: loader(loader), threading(threading), inference(inference) {}
std::string host;
int port;
std::string api_key;
std::string model;
std::string prompt;
bool interactive;
GemmaArgs(int argc, char** argv, ConsumedArgs& consumed)
: loader(argc, argv, consumed),
threading(argc, argv, consumed),
inference(argc, argv, consumed) {}
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(host, "host", std::string("localhost"),
"Server host (default: localhost)");
visitor(port, "port", 8080,
"Server port (default: 8080)");
visitor(api_key, "api_key", std::string(""),
"Use public API with key (changes host to "
"generativelanguage.googleapis.com:443)");
visitor(model, "model", std::string("gemma3-4b"),
"Model name to use (default: gemma3-4b)");
visitor(prompt, "prompt", std::string("Hello! How are you?"),
"Prompt for generation (default: 'Hello! How are you?')");
visitor(interactive, "interactive", false,
"Start interactive chat mode (0 = no, 1 = yes)");
void Help() {
fprintf(stderr,
"To run with pre-2025 weights, specify --tokenizer and --weights.\n"
"With the single-file weights format, specify just --weights.\n"
"\n*Model Loading Arguments*\n");
loader.Help();
fprintf(stderr, "\n*Threading Arguments*\n");
threading.Help();
fprintf(stderr, "\n*Inference Arguments*\n");
inference.Help();
fprintf(stderr, "\n");
}
LoaderArgs loader;
ThreadingArgs threading;
InferenceArgs inference;
};
} // namespace gcpp

gemma/gemma_args_test.cc (new file, 74 lines)
View File

@ -0,0 +1,74 @@
#include "gemma/gemma_args.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "gtest/gtest.h"
namespace gcpp {
void FillPtrs(const std::vector<std::string>& args, std::vector<char*>& ptrs) {
ptrs.reserve(args.size());
for (const std::string& arg : args) {
ptrs.push_back(const_cast<char*>(arg.data()));
}
}
static void CheckAllConsumed(const std::vector<std::string>& args) {
std::vector<char*> ptrs;
FillPtrs(args, ptrs);
const int argc = static_cast<int>(args.size());
char** argv = const_cast<char**>(ptrs.data());
ConsumedArgs consumed(argc, argv);
GemmaArgs gemma_args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
}
static void CheckUnconsumed(const std::vector<std::string>& args,
size_t expected) {
std::vector<char*> ptrs;
FillPtrs(args, ptrs);
const int argc = static_cast<int>(args.size());
char** argv = const_cast<char**>(ptrs.data());
ConsumedArgs consumed(argc, argv);
GemmaArgs gemma_args(argc, argv, consumed);
ASSERT_EQ(expected, consumed.FirstUnconsumed());
}
// Note: do not use --help because that is not actually consumed; it is actually
// special-cased in `HasHelp`.
TEST(GemmaArgsTest, AllConsumedArgs) {
// Single arg
CheckAllConsumed({"gemma", "--weights=x"});
// Two args, one with =
CheckAllConsumed({"gemma", "--weights=x", "--verbosity=1"});
// Two args, one with extra value
CheckAllConsumed({"gemma", "--weights=x", "--verbosity", "2"});
// Two args with values
CheckAllConsumed({"gemma", "--verbosity", "2", "--deterministic=true"});
}
TEST(GemmaArgsTest, UnconsumedArgs) {
// Single unconsumed arg
CheckUnconsumed({"gemma", "--UNDEFINED"}, 1);
// Single unconsumed arg, no --
CheckUnconsumed({"gemma", "UNDEFINED"}, 1);
// Single unconsumed arg after valid arg
CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED"}, 2);
// Single unconsumed arg before valid arg
CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x"}, 1);
// Single unconsumed arg with = after valid arg
CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED=1"}, 2);
// Single unconsumed arg with = before valid arg
CheckUnconsumed({"gemma", "--UNDEFINED=false", "--weights=x"}, 1);
// Multiple unconsumed args
CheckUnconsumed({"gemma", "--UNDEFINED", "--XXX"}, 1);
// Multiple unconsumed args with valid arg between
CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x", "--XXX"}, 1);
}
} // namespace gcpp

View File

@ -89,9 +89,11 @@ std::string GetPrompt(const InferenceArgs& inference) {
}
// The main Read-Eval-Print Loop.
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
MatMulEnv& env) {
PROFILER_ZONE("Gen.misc");
const InferenceArgs& inference = args.inference;
const int verbosity = inference.verbosity;
size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0;
@ -113,12 +115,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
.use_spinning = threading.spin};
RuntimeConfig runtime_config = {.verbosity = verbosity,
.use_spinning = args.threading.spin};
double image_tokens_start = hwy::platform::Now();
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
image_tokens, env);
if (inference.verbosity >= 1) {
if (verbosity >= 1) {
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
fprintf(stderr,
"\n\n[ Timing info ] Image token generation took: %d ms\n",
@ -189,7 +191,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
.batch_stream_token = batch_stream_token,
.use_spinning = threading.spin};
.use_spinning = args.threading.spin};
inference.CopyTo(runtime_config);
std::vector<int> prompt;
size_t prefix_end = 0;
@ -252,14 +254,14 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
}
}
void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
void Run(const GemmaArgs& args) {
PROFILER_ZONE("Run.misc");
ThreadingContext ctx(threading);
ThreadingContext ctx(args.threading);
MatMulEnv env(ctx);
const InferenceArgs& inference = args.inference;
if (inference.verbosity >= 3) env.print_best = true;
const Gemma gemma(loader, inference, ctx);
const Gemma gemma(args, ctx);
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
if (inference.verbosity >= 1) {
@ -287,13 +289,12 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
if (inference.IsInteractive()) {
std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n";
ShowConfig(loader, threading, inference, gemma.Config(),
gemma.WeightReadMode(), ctx);
ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
std::cout << "\n" << instructions << "\n";
}
}
ReplGemma(threading, inference, gemma, kv_cache, env);
ReplGemma(args, gemma, kv_cache, env);
}
} // namespace gcpp
@ -302,17 +303,24 @@ int main(int argc, char** argv) {
gcpp::InternalInit();
{
// Negligible CPU time.
gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, threading, inference);
fprintf(stderr,
"\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights gemma2-2b-it-sfp.sbs\n\n");
args.Help();
return 0;
}
gcpp::Run(loader, threading, inference);
// After `HasHelp` so that we print --help even if unconsumed args remain.
consumed.AbortIfUnconsumed();
gcpp::Run(args);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;

View File

@ -23,7 +23,9 @@ namespace gcpp {
namespace {
struct WriterArgs : public ArgsBase<WriterArgs> {
WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
WriterArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path output_weights;
@ -38,12 +40,15 @@ struct WriterArgs : public ArgsBase<WriterArgs> {
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::WriterArgs args(argc, argv);
if (args.output_weights.Empty()) {
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::WriterArgs writer_args(argc, argv, consumed);
if (writer_args.output_weights.Empty()) {
HWY_ABORT("Missing --output_weights flag, a file for the model weights.");
}
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(argc, argv);
env.GetGemma()->Save(args.output_weights, env.Env().ctx);
gcpp::GemmaEnv env(args);
env.GetGemma()->Save(writer_args.output_weights, env.Env().ctx);
return 0;
}

View File

@ -21,10 +21,9 @@
#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "io/io.h"
#include "paligemma/paligemma_helper.h"
#include "util/allocator.h"
#include "hwy/tests/hwy_gtest.h"
#include "paligemma/paligemma_helper.h"
// This test can be run manually with the downloaded PaliGemma weights.
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
@ -73,7 +72,11 @@ TEST_F(PaliGemmaTest, QueryObjects) {
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
gcpp::GemmaEnv env(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
gcpp::s_env = &env;
return RUN_ALL_TESTS();

View File

@ -45,10 +45,7 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
// Wrapper around GemmaEnv to expose to Python.
class GemmaModel {
public:
GemmaModel(const gcpp::LoaderArgs& loader,
const gcpp::ThreadingArgs& threading,
const gcpp::InferenceArgs& inference)
: env_(loader, threading, inference), last_prob_(0.0f) {}
GemmaModel(const gcpp::GemmaArgs& args) : env_(args), last_prob_(0.0f) {}
// Generates a single example, given a prompt and a callback to stream the
// generated tokens.
@ -254,13 +251,15 @@ PYBIND11_MODULE(gemma, mod) {
py::class_<GemmaModel>(mod, "GemmaModel")
.def(py::init([](const std::string& tokenizer, const std::string& weights,
size_t max_threads) {
const gcpp::LoaderArgs loader(tokenizer, weights);
gcpp::ThreadingArgs threading;
threading.max_lps = max_threads;
gcpp::InferenceArgs inference;
inference.max_generated_tokens = 512;
auto gemma =
std::make_unique<GemmaModel>(loader, threading, inference);
const gcpp::GemmaArgs args(gcpp::LoaderArgs(tokenizer, weights),
threading, inference);
auto gemma = std::make_unique<GemmaModel>(args);
if (!gemma->ModelIsLoaded()) {
throw std::invalid_argument("Could not load model.");
}

View File

@ -22,6 +22,7 @@
#include <algorithm> // std::transform
#include <string>
#include <vector>
#include "io/io.h" // Path
#include "util/basics.h" // Tristate
@ -29,6 +30,56 @@
namespace gcpp {
// For checking which args were not matched/consumed. Passed to each `*Args`
// ctor that parses argc/argv to ensure that their args are tracked, without
// requiring global state.
class ConsumedArgs {
public:
ConsumedArgs(int argc, char** argv) : argv_(argv), consumed_(argc) {
// We assume argc >= 1, because argv[0] is the binary name. That allows us
// to signal "called AbortIfUnconsumed" with an empty vector.
HWY_ASSERT(!consumed_.empty());
}
~ConsumedArgs() {
if (HWY_UNLIKELY(!consumed_.empty())) {
HWY_ABORT("AbortIfUnconsumed was not called.");
}
}
void NotifyConsumed(size_t idx) {
HWY_ASSERT(idx < consumed_.size());
HWY_ASSERT(consumed_[idx] == 0);
consumed_[idx] = 1;
}
// Returns index of first unconsumed arg, or 0 if none. Also disarms the
// warning in the dtor checking whether this/`AbortIfUnconsumed` were called.
size_t FirstUnconsumed() {
// Ignore argv[0], which is the binary name.
for (size_t i = 1; i < consumed_.size(); ++i) {
if (HWY_UNLIKELY(consumed_[i] == 0)) {
consumed_.clear();
return i;
}
}
consumed_.clear();
return 0;
}
void AbortIfUnconsumed() {
const size_t i = FirstUnconsumed();
if (HWY_UNLIKELY(i != 0)) {
HWY_ABORT("Unrecognized arg %zu: %s\n", i, argv_[i]);
}
}
private:
char** argv_;
std::vector<uint8_t> consumed_;
};
// Args is a class that provides a ForEach member function which visits each of
// its member variables. ArgsBase provides functions called by Args to
// initialize values to their defaults (passed as an argument to the visitor),
@ -93,7 +144,8 @@ class ArgsBase {
// consider adding a hash-map to speed this up.
class ParseVisitor {
public:
ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {}
ParseVisitor(int argc, char* argv[], ConsumedArgs& consumed)
: argc_(argc), argv_(argv), consumed_(consumed) {}
template <typename T>
void operator()(T& t, const char* name, const T& /*init*/,
@ -108,6 +160,8 @@ class ArgsBase {
if (!SetValue(argv_[i + 1], t)) {
HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
}
consumed_.NotifyConsumed(i);
consumed_.NotifyConsumed(i + 1);
return;
}
if (std::string(argv_[i]).find(prefixed_eq) == 0) {
@ -115,6 +169,7 @@ class ArgsBase {
if (!SetValue(value, t)) {
HWY_ABORT("Invalid value for %s, got %s\n", name, value);
}
consumed_.NotifyConsumed(i);
return;
}
}
@ -181,8 +236,9 @@ class ArgsBase {
}
}
int argc_;
char** argv_;
const int argc_;
char** const argv_;
ConsumedArgs& consumed_;
}; // ParseVisitor
template <class Visitor>
@ -211,15 +267,15 @@ class ArgsBase {
ForEach(visitor);
}
void Parse(int argc, char* argv[]) {
ParseVisitor visitor(argc, argv);
void Parse(int argc, char* argv[], ConsumedArgs& consumed) {
ParseVisitor visitor(argc, argv, consumed);
ForEach(visitor);
}
// For convenience, enables single-line constructor.
void InitAndParse(int argc, char* argv[]) {
void InitAndParse(int argc, char* argv[], ConsumedArgs& consumed) {
Init();
Parse(argc, argv);
Parse(argc, argv, consumed);
}
};

View File

@ -38,7 +38,9 @@ namespace gcpp {
// Optional arguments for `ThreadingContext` from the command line.
class ThreadingArgs : public ArgsBase<ThreadingArgs> {
public:
ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
ThreadingArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
ThreadingArgs() { Init(); };
// For BoundedTopology: