mirror of https://github.com/google/gemma.cpp.git
Abort if args are unrecognized, refactor argument passing
This catches typos and incorrect usage.

Refactor: group Loader/Threading/Inference args into GemmaArgs. All *Args
constructors now take an extra ConsumedArgs& argument.

PiperOrigin-RevId: 844690553
Parent: f50550f4ce
Commit: 0c64987a96
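In practice, the refactor converts every tool's main() to one parsing pattern: parse into the grouped GemmaArgs, then abort on any leftover flag. Below is a minimal sketch assembled from the hunks that follow; this exact main() is illustrative, not a file in the commit:

    #include "gemma/gemma_args.h"  // GemmaArgs, ConsumedArgs, HasHelp

    int main(int argc, char** argv) {
      // Tracks which argv entries some *Args parser has claimed.
      gcpp::ConsumedArgs consumed(argc, argv);
      // Parses the grouped Loader/Threading/Inference flags; each
      // recognized flag is marked as consumed.
      gcpp::GemmaArgs args(argc, argv, consumed);
      if (gcpp::HasHelp(argc, argv)) {
        args.Help();
        return 0;
      }
      // Any flag no parser claimed is a typo or unknown option: abort.
      consumed.AbortIfUnconsumed();

      gcpp::GemmaEnv env(args);  // previously GemmaEnv(argc, argv)
      // ... run inference as before ...
      return 0;
    }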
BUILD.bazel (11 lines changed)

@@ -556,12 +556,22 @@ cc_library(
         ":basics",
         ":configs",
         ":mat",
         ":threading_context",
         "//io",
         "@highway//:hwy",
         "@highway//:profiler",
     ],
 )
 
+cc_test(
+    name = "gemma_args_test",
+    srcs = ["gemma/gemma_args_test.cc"],
+    deps = [
+        ":gemma_args",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+    ],
+)
+
 cc_library(
     name = "gemma_lib",
     srcs = [

@@ -666,7 +676,6 @@ cc_library(
         ":gemma_args",
         ":gemma_lib",
         ":matmul_env",
         ":ops",
         ":threading_context",
         ":tokenizer",
         "@google_benchmark//:benchmark",
@@ -222,6 +222,7 @@ set(GEMMA_TEST_FILES
   compression/nuq_test.cc
   compression/sfp_test.cc
   evals/gemma_test.cc
+  gemma/gemma_args_test.cc
   gemma/flash_attention_test.cc
   gemma/tensor_info_test.cc
   io/blob_store_test.cc
@@ -101,13 +101,11 @@ cc_test(
     # for test_suite.
     tags = ["hwy_ops_test"],
     deps = [
         ":distortion",
         ":int",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//:test_util",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
         "@highway//:nanobenchmark",
     ],
 )
@@ -23,7 +23,9 @@ using json = nlohmann::json;
 
 class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
  public:
-  BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  BenchmarkArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
 
   Path summarize_text;
   Path cross_entropy;

@@ -127,9 +129,16 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
 }  // namespace gcpp
 
 int main(int argc, char** argv) {
-  gcpp::GemmaEnv env(argc, argv);
-  gcpp::BenchmarkArgs benchmark_args(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  gcpp::BenchmarkArgs benchmark_args(argc, argv, consumed);
+  if (gcpp::HasHelp(argc, argv)) {
+    args.Help();
+    return 0;
+  }
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
   if (!benchmark_args.summarize_text.Empty()) {
     return BenchmarkSummary(env, benchmark_args.summarize_text);
   } else if (!benchmark_args.cross_entropy.Empty()) {
@@ -36,35 +36,29 @@
 
 namespace gcpp {
 
-GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
-                   const InferenceArgs& inference)
+GemmaEnv::GemmaEnv(const GemmaArgs& args)
     : initializer_value_(gcpp::InternalInit()),
-      ctx_(threading),
+      ctx_(args.threading),
       env_(ctx_),
-      gemma_(loader, inference, ctx_) {
+      gemma_(args, ctx_) {
   const ModelConfig& config = gemma_.Config();
   // Only allocate one for starters because GenerateBatch might not be called.
-  kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
+  kv_caches_.push_back(KVCache(config, args.inference, ctx_.allocator));
 
-  if (inference.verbosity >= 2) {
-    ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
-               ctx_);
+  if (args.inference.verbosity >= 2) {
+    ShowConfig(args, config, gemma_.WeightReadMode(), ctx_);
   }
-  if (inference.verbosity >= 3) env_.print_best = true;
-  if (inference.verbosity >= 4) env_.print_config = true;
+  if (args.inference.verbosity >= 3) env_.print_best = true;
+  if (args.inference.verbosity >= 4) env_.print_config = true;
 
   runtime_config_ = {
-      .max_generated_tokens = inference.max_generated_tokens,
-      .temperature = inference.temperature,
-      .verbosity = inference.verbosity,
+      .max_generated_tokens = args.inference.max_generated_tokens,
+      .temperature = args.inference.temperature,
+      .verbosity = args.inference.verbosity,
   };
-  inference.CopyTo(runtime_config_);
+  args.inference.CopyTo(runtime_config_);
 }
 
-GemmaEnv::GemmaEnv(int argc, char** argv)
-    : GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
-               InferenceArgs(argc, argv)) {}
-
 QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
   QueryResult result;

@@ -234,19 +228,19 @@ static constexpr const char* CompiledConfig() {
   }
 }
 
-void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config,
+void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
                const WeightsPtrs::Mode weight_read_mode,
                const ThreadingContext& ctx) {
-  threading.Print(inference.verbosity);
-  loader.Print(inference.verbosity);
-  inference.Print(inference.verbosity);
-  fprintf(
-      stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n",
-      config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
-      static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode));
+  args.threading.Print(args.inference.verbosity);
+  args.loader.Print(args.inference.verbosity);
+  args.inference.Print(args.inference.verbosity);
+  fprintf(stderr,
+          "Model : %s, to_bf16 %d, mmap %d => %s\n",
+          config.Specifier().c_str(), static_cast<int>(args.loader.to_bf16),
+          static_cast<int>(args.loader.map),
+          WeightsPtrs::ToString(weight_read_mode));
 
-  if (inference.verbosity >= 2) {
+  if (args.inference.verbosity >= 2) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
     char cpu100[100] = "unknown";

@@ -259,7 +253,7 @@ void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
             "Instruction set : %s (%zu bits)\n"
             "Compiled config : %s, profiler %d\n"
             "Memory MiB : %4zu\n",
-            dt, cpu100, static_cast<int>(threading.bind),
+            dt, cpu100, static_cast<int>(args.threading.bind),
             ctx.topology.TopologyString(), ctx.pools.PinString(),
             CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
             ctx.cache_info.VectorBytes() * 8, CompiledConfig(),

@@ -267,22 +261,4 @@ void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
   }
 }
 
-void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
-              const InferenceArgs& inference) {
-  std::cerr
-      << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
-         "==========================================================\n\n"
-         "To run with pre-2025 weights, specify --tokenizer and --weights.\n"
-         "With the single-file weights format, specify just --weights.\n";
-  std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
-               "--weights gemma2-2b-it-sfp.sbs\n";
-  std::cerr << "\n*Model Loading Arguments*\n\n";
-  loader.Help();
-  std::cerr << "\n*Threading Arguments*\n\n";
-  threading.Help();
-  std::cerr << "\n*Inference Arguments*\n\n";
-  inference.Help();
-  std::cerr << "\n";
-}
-
 }  // namespace gcpp
@@ -23,7 +23,7 @@
 
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
-#include "gemma/gemma_args.h"
+#include "gemma/gemma_args.h"  // IWYU pragma: export
 #include "gemma/tokenizer.h"  // WrapAndTokenize
 #include "ops/matmul.h"
 #include "util/threading_context.h"

@@ -50,10 +50,8 @@ struct QueryResultAndMetrics {
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
   // Calls the other constructor with *Args arguments initialized from argv.
   GemmaEnv(int argc, char** argv);
-  GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
-           const InferenceArgs& inference);
+  explicit GemmaEnv(const GemmaArgs& args);
 
   MatMulEnv& Env() { return env_; }
 
   size_t MaxGeneratedTokens() const {

@@ -137,12 +135,9 @@ class GemmaEnv {
 // Logs the inference speed in tokens/sec.
 void LogSpeedStats(double time_start, size_t total_tokens);
 
-void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config,
+void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
                WeightsPtrs::Mode weight_read_mode,
                const ThreadingContext& ctx);
-void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
-              const InferenceArgs& inference);
 
 }  // namespace gcpp
@@ -98,7 +98,11 @@ BENCHMARK(BM_coding_prompt)
     ->UseRealTime();
 
 int main(int argc, char** argv) {
-  gcpp::GemmaEnv env(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
   env.SetMaxGeneratedTokens(256);
   gcpp::s_env = &env;
@@ -31,7 +31,9 @@ namespace gcpp {
 
 class PromptArgs : public ArgsBase<PromptArgs> {
  public:
-  PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  PromptArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
 
   Path layers_output;  // optional
   std::string prompt;

@@ -51,11 +53,15 @@ class PromptArgs : public ArgsBase<PromptArgs> {
 };
 
 int Run(int argc, char** argv) {
-  PromptArgs prompt_args(argc, argv);
+  ConsumedArgs consumed(argc, argv);
+  const GemmaArgs args(argc, argv, consumed);
+  const PromptArgs prompt_args(argc, argv, consumed);
   AbortIfInvalidArgs(prompt_args);
+  consumed.AbortIfUnconsumed();
 
   json json_output;
-  GemmaEnv env(argc, argv);
+  GemmaEnv env(args);
 
   env.MutableConfig().layers_output =
       prompt_args.layers_output.Empty()
          ? LayersOutputFunc()
@@ -146,7 +146,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
 
 int main(int argc, char** argv) {
   fprintf(stderr, "GemmaEnv setup..\n");
-  gcpp::GemmaEnv env(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
   gcpp::s_env = &env;
 
   testing::InitGoogleTest(&argc, argv);
@@ -22,7 +22,6 @@
 
 #include "evals/benchmark_helper.h"
 #include "gemma/configs.h"
 #include "io/io.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"

@@ -42,7 +41,11 @@ class GemmaTest : public ::testing::Test {
   // Requires argc/argv, hence do not use `SetUpTestSuite`.
   static void InitEnv(int argc, char** argv) {
     HWY_ASSERT(s_env == nullptr);  // Should only be called once.
-    s_env = new GemmaEnv(argc, argv);
+    ConsumedArgs consumed(argc, argv);
+    GemmaArgs args(argc, argv, consumed);
+    consumed.AbortIfUnconsumed();
+
+    s_env = new GemmaEnv(args);
     const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
     fprintf(stderr, "Using %s\n", config.Specifier().c_str());
   }
@@ -31,7 +31,9 @@
 namespace gcpp {
 
 struct JsonArgs : public ArgsBase<JsonArgs> {
-  JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  JsonArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
 
   Path input;

@@ -151,10 +153,14 @@ void Run(GemmaEnv& env, JsonArgs& json) {
 int main(int argc, char** argv) {
   {
     PROFILER_ZONE("Startup.all");
-    gcpp::GemmaEnv env(argc, argv);
-    gcpp::JsonArgs json(argc, argv);
-    gcpp::AbortIfInvalidArgs(json);
-    gcpp::Run(env, json);
+    gcpp::ConsumedArgs consumed(argc, argv);
+    gcpp::GemmaArgs args(argc, argv, consumed);
+    gcpp::JsonArgs json_args(argc, argv, consumed);
+    gcpp::AbortIfInvalidArgs(json_args);
+    consumed.AbortIfUnconsumed();
+
+    gcpp::GemmaEnv env(args);
+    gcpp::Run(env, json_args);
   }
   PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
   return 0;
@@ -24,20 +24,20 @@
 #include <vector>
 
 #include "gemma/gemma.h"
-#include "gemma/gemma_args.h"  // LoaderArgs
+#include "gemma/gemma_args.h"  // GemmaArgs
 #include "gemma/tokenizer.h"
 #include "util/args.h"
 #include "util/threading_context.h"
 #include "hwy/base.h"
 
 int main(int argc, char** argv) {
-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::ThreadingArgs threading(argc, argv);
-  gcpp::InferenceArgs inference(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
   if (gcpp::HasHelp(argc, argv)) {
-    loader.Help();
+    args.Help();
     return 0;
   }
+  consumed.AbortIfUnconsumed();
 
   // Demonstrate constrained decoding by never outputting certain tokens.
   std::set<int> reject_tokens;

@@ -49,10 +49,10 @@ int main(int argc, char** argv) {
   }
 
   // Instantiate model and KV Cache
-  gcpp::ThreadingContext ctx(threading);
+  gcpp::ThreadingContext ctx(args.threading);
   gcpp::MatMulEnv env(ctx);
-  gcpp::Gemma gemma(loader, inference, ctx);
-  gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
+  gcpp::Gemma gemma(args, ctx);
+  gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator);
   size_t generated = 0;
 
   // Tokenize instructions.
@@ -23,7 +23,7 @@
 #include <vector>
 
 #include "third_party/gemma_cpp/gemma/gemma.h"
-#include "third_party/gemma_cpp/gemma/gemma_args.h"  // LoaderArgs
+#include "third_party/gemma_cpp/gemma/gemma_args.h"  // GemmaArgs
 #include "third_party/gemma_cpp/gemma/tokenizer.h"
 #include "third_party/gemma_cpp/ops/matmul.h"
 #include "third_party/gemma_cpp/util/threading_context.h"

@@ -31,18 +31,11 @@
 
 class SimplifiedGemma {
  public:
-  SimplifiedGemma(const gcpp::LoaderArgs& loader,
-                  const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
-                  const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : ctx_(threading),
+  SimplifiedGemma(const gcpp::GemmaArgs& args)
+      : ctx_(args.threading),
        env_(ctx_),
-        gemma_(loader, inference, ctx_),
-        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {}
-
-  SimplifiedGemma(int argc, char** argv)
-      : SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
-                        gcpp::ThreadingArgs(argc, argv),
-                        gcpp::InferenceArgs(argc, argv)) {}
+        gemma_(args, ctx_),
+        kv_cache_(gemma_.Config(), args.inference, ctx_.allocator) {}
 
   void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
                 float temperature = 0.7,
@@ -18,26 +18,16 @@
 #include <string>
 
 #include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
-#include "gemma/gemma_args.h"  // LoaderArgs
+#include "gemma/gemma_args.h"
 
 int main(int argc, char** argv) {
-  // Standard usage: LoaderArgs takes argc and argv as input, then parses
-  // necessary flags.
-  gcpp::LoaderArgs loader(argc, argv);
+  // Sets arguments from argc and argv. Note that you can instead pass in
+  // LoaderArgs, ThreadingArgs, and InferenceArgs directly.
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
 
-  // Optional: LoaderArgs can also take tokenizer and weights paths directly.
-  //
-  // gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
-  //                         "model_identifier");
-
-  // Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
-  // specified, default values will be used.
-  //
-  // gcpp::InferenceArgs inference(argc, argv);
-  // gcpp::ThreadingArgs threading(argc, argv);
-  // SimplifiedGemma gemma(loader, threading, inference);
-
-  SimplifiedGemma gemma(loader);
+  SimplifiedGemma gemma(args);
   std::string prompt = "Write a greeting to the world.";
   gemma.Generate(prompt, 256, 0.6);
@@ -15,18 +15,22 @@
 
 // Test client for API server
 
-#include <iostream>
-#include <string>
-#include <sstream>
 #include <stdio.h>
 
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <sstream>
+#include <string>
+
 #include "httplib.h"
-#include "nlohmann/json.hpp"
+#include "gemma/gemma_args.h"
+#include "nlohmann/json.hpp"
 
 using json = nlohmann::json;
 
 namespace gcpp {
 
 // ANSI color codes
 const std::string RESET = "\033[0m";
 const std::string BOLD = "\033[1m";

@@ -37,9 +41,15 @@ const std::string YELLOW = "\033[33m";
 const std::string RED = "\033[31m";
 
 class APIClient {
-public:
-  APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b")
-    : host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) {
+ public:
+  APIClient(const std::string& host, int port, const std::string& api_key = "",
+            const std::string& model = "gemma3-4b")
+      : host_(host),
+        port_(port),
+        api_key_(api_key),
+        model_(model),
+        use_https_(port == 443),
+        interactive_mode_(false) {
     if (use_https_) {
       ssl_client_ = std::make_unique<httplib::SSLClient>(host, port);
       ssl_client_->set_read_timeout(60, 0);

@@ -58,7 +68,9 @@ public:
 
    std::string endpoint;
    if (is_public_api) {
-      endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
+      endpoint =
+          stream
+              ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
              : "/v1beta/models/gemini-2.0-flash:generateContent";
    } else {
      endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent"

@@ -67,7 +79,8 @@ public:
 
    // Only show verbose output in non-interactive mode
    if (!interactive_mode_) {
-      std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
+      std::cout << "\n"
+                << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
      std::cout << "Request: " << request.dump(2) << std::endl;
    }

@@ -83,18 +96,21 @@ public:
    json response = ProcessRequest(request, stream);
 
    if (response.contains("error")) {
-      std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
+      std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET
+                << std::endl;
    }
  }
 
  void TestListModels() {
-    std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
+    std::cout << "\n"
+              << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
 
    httplib::Headers headers;
    if (!api_key_.empty()) {
      headers.emplace("X-goog-api-key", api_key_);
    }
-    auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers);
+    auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers)
+                          : client_->Get("/v1beta/models", headers);
 
    if (res && res->status == 200) {
      json response = json::parse(res->body);

@@ -106,7 +122,9 @@ public:
  }
 
  void InteractiveChat() {
-    std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl;
+    std::cout << "\n"
+              << BOLD << CYAN << "💬 Interactive Chat Mode (with session)"
+              << RESET << std::endl;
    std::cout << "Type ':gemma %q' to end.\n" << std::endl;
 
    interactive_mode_ = true;

@@ -141,14 +159,16 @@ public:
 
      if (response.contains("candidates") && !response["candidates"].empty()) {
        auto& candidate = response["candidates"][0];
-        if (candidate.contains("content") && candidate["content"].contains("parts")) {
+        if (candidate.contains("content") &&
+            candidate["content"].contains("parts")) {
          for (const auto& part : candidate["content"]["parts"]) {
            if (part.contains("text")) {
              std::string assistant_response = part["text"].get<std::string>();
 
              // For streaming, the response is already displayed in real-time
              // Just add to message history for context
-              json assistant_message = {{"parts", {{{"text", assistant_response}}}}};
+              json assistant_message = {
+                  {"parts", {{{"text", assistant_response}}}}};
              if (!api_key_.empty()) {
                assistant_message["role"] = "model";
              }

@@ -157,22 +177,20 @@ public:
        }
      } else if (response.contains("error")) {
-        std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
+        std::cerr << RED << "❌ Error: " << response["error"]["message"]
+                  << RESET << std::endl;
      }
 
      std::cout << std::endl;
    }
  }
 
-private:
-  json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) {
+ private:
+  json CreateAPIRequest(const std::string& prompt,
+                        const json& messages = json::array()) {
    json request = {
-      {"generationConfig", {
-        {"temperature", 0.9},
-        {"topK", 1},
-        {"maxOutputTokens", 1024}
-      }}
-    };
+        {"generationConfig",
+         {{"temperature", 0.9}, {"topK", 1}, {"maxOutputTokens", 1024}}}};
 
    if (messages.empty()) {
      // Single prompt

@@ -189,38 +207,42 @@ private:
    return request;
  }
 
-  json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) {
+  json ProcessNonStreamingRequest(const json& request,
+                                  const std::string& endpoint) {
    httplib::Headers headers = {{"Content-Type", "application/json"}};
    if (!api_key_.empty()) {
      headers.emplace("X-goog-api-key", api_key_);
    }
 
-    auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json")
-                          : client_->Post(endpoint, headers, request.dump(), "application/json");
+    auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(),
+                                              "application/json")
+                          : client_->Post(endpoint, headers, request.dump(),
+                                          "application/json");
 
    if (res && res->status == 200) {
      json response = json::parse(res->body);
      if (!interactive_mode_) {
-        std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl;
+        std::cout << "\n"
+                  << BOLD << GREEN << "📥 Response:" << RESET << std::endl;
        std::cout << response.dump(2) << std::endl;
      }
      return response;
    } else {
-      json error_response = {
-        {"error", {
-          {"message", "Request failed"},
-          {"status", res ? res->status : -1}
-        }}
-      };
+      json error_response = {{"error",
+                              {{"message", "Request failed"},
+                               {"status", res ? res->status : -1}}}};
      if (res && !res->body.empty()) {
        error_response["error"]["details"] = res->body;
      }
-      std::cerr << RED << "❌ Request failed. Status: " << (res ? res->status : -1) << RESET << std::endl;
+      std::cerr << RED
+                << "❌ Request failed. Status: " << (res ? res->status : -1)
+                << RESET << std::endl;
      return error_response;
    }
  }
 
-  json ProcessStreamingRequest(const json& request, const std::string& endpoint) {
+  json ProcessStreamingRequest(const json& request,
+                               const std::string& endpoint) {
    std::string accumulated_response;
 
    // Use same SSE logic for both public and local APIs

@@ -233,7 +255,9 @@ private:
    }
    req.body = request.dump();
 
-    req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool {
+    req.content_receiver = [&accumulated_response, this](
+                               const char* data, size_t data_length,
+                               uint64_t offset, uint64_t total_length) -> bool {
      std::string chunk(data, data_length);
      std::istringstream stream(chunk);
      std::string line;

@@ -244,14 +268,18 @@ private:
 
        if (event_data == "[DONE]") {
          if (!interactive_mode_) {
-            std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl;
+            std::cout << "\n\n"
+                      << GREEN << "✅ Generation complete!" << RESET
+                      << std::endl;
          }
        } else {
          try {
            json event = json::parse(event_data);
-            if (event.contains("candidates") && !event["candidates"].empty()) {
+            if (event.contains("candidates") &&
+                !event["candidates"].empty()) {
              auto& candidate = event["candidates"][0];
-              if (candidate.contains("content") && candidate["content"].contains("parts")) {
+              if (candidate.contains("content") &&
+                  candidate["content"].contains("parts")) {
                for (const auto& part : candidate["content"]["parts"]) {
                  if (part.contains("text")) {
                    std::string text = part["text"].get<std::string>();

@@ -272,32 +300,27 @@ private:
 
    httplib::Response res;
    httplib::Error error;
-    bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error);
+    bool success = use_https_ ? ssl_client_->send(req, res, error)
+                              : client_->send(req, res, error);
 
    if (res.status == 200 && !accumulated_response.empty()) {
      return json{
-        {"candidates", {{
-          {"content", {
-            {"parts", {{{"text", accumulated_response}}}}
-          }}
-        }}}
-      };
+          {"candidates",
+           {{{"content", {{"parts", {{{"text", accumulated_response}}}}}}}}}};
    } else {
      json error_response = {
-        {"error", {
-          {"message", "Streaming request failed"},
-          {"status", res.status}
-        }}
-      };
+          {"error",
+           {{"message", "Streaming request failed"}, {"status", res.status}}}};
      if (!res.body.empty()) {
        error_response["error"]["details"] = res.body;
      }
-      std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl;
+      std::cerr << RED << "❌ Streaming request failed. Status: " << res.status
+                << RESET << std::endl;
      return error_response;
    }
  }
 
-private:
+ private:
  std::unique_ptr<httplib::Client> client_;
  std::unique_ptr<httplib::SSLClient> ssl_client_;
  std::string host_;

@@ -308,19 +331,55 @@ private:
  bool interactive_mode_;
 };
 
+struct ClientArgs : public ArgsBase<ClientArgs> {
+  ClientArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
+  ClientArgs() { Init(); };
+
+  std::string host;
+  int port;
+  std::string api_key;
+  std::string model;
+  std::string prompt;
+  bool interactive;
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(host, "host", std::string("localhost"),
+            "Server host (default: localhost)");
+    visitor(port, "port", 8080, "Server port (default: 8080)");
+    visitor(api_key, "api_key", std::string(""),
+            "Use public API with key (changes host to "
+            "generativelanguage.googleapis.com:443)");
+    visitor(model, "model", std::string("gemma3-4b"),
+            "Model name to use (default: gemma3-4b)");
+    visitor(prompt, "prompt", std::string("Hello! How are you?"),
+            "Prompt for generation (default: 'Hello! How are you?')");
+    visitor(interactive, "interactive", false,
+            "Start interactive chat mode (0 = no, 1 = yes)");
+  }
+};
+
 }  // namespace gcpp
 
 int main(int argc, char* argv[]) {
-  gcpp::ClientArgs client_args(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::ClientArgs client_args(argc, argv, consumed);
 
  if (gcpp::HasHelp(argc, argv)) {
-    std::cout << "\nAPI Client for gemma.cpp\n";
-    std::cout << "========================\n\n";
+    fprintf(stderr,
+            "\nAPI Client for gemma.cpp\n"
+            "========================\n\n");
    client_args.Help();
-    std::cout << std::endl;
-    std::cout << "Environment Variables:" << std::endl;
-    std::cout << "  GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl;
+    fprintf(stderr,
+            "\n*Environment Variables:\n"
+            "  GOOGLE_API_KEY : Automatically use public Google API if set\n");
    return 0;
  }
 
+  consumed.AbortIfUnconsumed();
+
  // Check for GOOGLE_API_KEY environment variable
  const char* env_api_key = std::getenv("GOOGLE_API_KEY");
  if (env_api_key != nullptr && strlen(env_api_key) > 0) {

@@ -335,11 +394,12 @@ int main(int argc, char* argv[]) {
    client_args.port = 443;
  }
 
-  std::cout << BOLD << YELLOW << "🚀 Testing API Server at "
-            << client_args.host << ":" << client_args.port << RESET << std::endl;
+  std::cout << BOLD << YELLOW << "🚀 Testing API Server at " << client_args.host
+            << ":" << client_args.port << RESET << std::endl;
 
  try {
-    APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model);
+    APIClient client(client_args.host, client_args.port, client_args.api_key,
+                     client_args.model);
 
    if (client_args.interactive) {
      client.InteractiveChat();

@@ -347,11 +407,12 @@ int main(int argc, char* argv[]) {
      client.TestListModels();
      client.TestGenerateContent(client_args.prompt, true);
    }
 
  } catch (const std::exception& e) {
    std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl;
    std::cerr << "Make sure the API server is running:" << std::endl;
-    std::cerr << "  ./build/gemma_api_server --tokenizer <path> --weights <path>" << std::endl;
+    std::cerr
+        << "  ./build/gemma_api_server --tokenizer <path> --weights <path>"
+        << std::endl;
    return 1;
  }
@@ -15,22 +15,19 @@
 
 // HTTP API server for gemma.cpp with SSE support
 
-#include <stdio.h>
 #include <signal.h>
+#include <stdio.h>
 
-#include <iostream>
-#include <memory>
-#include <random>
-#include <string>
-#include <string_view>
-#include <vector>
-#include <thread>
-#include <atomic>
-#include <chrono>
-#include <sstream>
-#include <iomanip>
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <thread>  // NOLINT
+#include <unordered_map>
+#include <vector>
 
 // HTTP server library
 #undef CPPHTTPLIB_OPENSSL_SUPPORT

@@ -38,16 +35,12 @@
 #include "httplib.h"
 
-// JSON library
-#include "nlohmann/json.hpp"
-
 #include "compression/types.h"
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
 #include "gemma/tokenizer.h"
 #include "ops/matmul.h"
-#include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/profiler.h"
+#include "nlohmann/json.hpp"
 
 using json = nlohmann::json;

@@ -90,7 +83,8 @@ struct ServerState {
    std::lock_guard<std::mutex> lock(sessions_mutex);
    auto& session = sessions[session_id];
    if (!session.kv_cache) {
-      session.kv_cache = std::make_unique<KVCache>(gemma->Config(), InferenceArgs(), env->ctx.allocator);
+      session.kv_cache = std::make_unique<KVCache>(
+          gemma->Config(), InferenceArgs(), env->ctx.allocator);
    }
    session.last_access = std::chrono::steady_clock::now();
    return session;

@@ -107,7 +101,8 @@ std::string GenerateSessionId() {
  return ss.str();
 }
 
-// Wraps messages with start_of_turn markers - handles both with and without roles
+// Wraps messages with start_of_turn markers - handles both with and without
+// roles
 std::string WrapMessagesWithTurnMarkers(const json& contents) {
  std::string prompt;

@@ -121,12 +116,14 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) {
      std::string text = part["text"];
 
      if (role == "user") {
-        prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
+        prompt +=
+            "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
      } else if (role == "model") {
        prompt += text + "\n";
      } else if (role.empty()) {
        // Local format without roles - for now, treat as user input
-        prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
+        prompt +=
+            "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
      }
    }
  }

@@ -163,18 +160,15 @@ RuntimeConfig ParseGenerationConfig(const json& request) {
  return config;
 }
 
-// Unified response formatter - creates consistent format regardless of request type
-json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) {
+// Unified response formatter - creates consistent format regardless of request
+// type
+json CreateAPIResponse(const std::string& text,
+                       bool is_streaming_chunk = false) {
  json response = {
-    {"candidates", {{
-      {"content", {
-        {"parts", {{{"text", text}}}},
-        {"role", "model"}
-      }},
-      {"index", 0}
-    }}},
-    {"promptFeedback", {{"safetyRatings", json::array()}}}
-  };
+      {"candidates",
+       {{{"content", {{"parts", {{{"text", text}}}}, {"role", "model"}}},
+         {"index", 0}}}},
+      {"promptFeedback", {{"safetyRatings", json::array()}}}};
 
  // Only add finishReason for non-streaming chunks
  if (!is_streaming_chunk) {

@@ -185,7 +179,9 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false)
 }
 
 // Handle generateContent endpoint (non-streaming)
-void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
+void HandleGenerateContentNonStreaming(ServerState& state,
+                                       const httplib::Request& req,
+                                       httplib::Response& res) {
  try {
    json request = json::parse(req.body);

@@ -199,7 +195,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
      prompt = WrapMessagesWithTurnMarkers(request["contents"]);
    } else {
      res.status = 400;
-      res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
+      res.set_content(
+          json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
+          "application/json");
      return;
    }

@@ -209,12 +207,7 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
    // Set up runtime config
    RuntimeConfig runtime_config = ParseGenerationConfig(request);
 
-    // Collect full response
-    std::string full_response;
-    runtime_config.stream_token = [&full_response](int token, float) {
-      // Skip EOS token
-      return true;
-    };
+    runtime_config.stream_token = [](int token, float) { return true; };
 
    // Tokenize prompt
    std::vector<int> tokens = WrapAndTokenize(

@@ -227,7 +220,8 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
 
    // Temporarily redirect output to capture response
    std::stringstream output;
-    runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) {
+    runtime_config.stream_token = [&output, &state, &session, &tokens](
+                                      int token, float) {
      // Skip prompt tokens
      if (session.abs_pos < tokens.size()) {
        session.abs_pos++;

@@ -279,7 +273,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
 }
 
 // Handle streamGenerateContent endpoint with SSE)
-void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
+void HandleGenerateContentStreaming(ServerState& state,
+                                    const httplib::Request& req,
+                                    httplib::Response& res) {
  try {
    json request = json::parse(req.body);

@@ -293,7 +289,9 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
      prompt = WrapMessagesWithTurnMarkers(request["contents"]);
    } else {
      res.status = 400;
-      res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
+      res.set_content(
+          json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
+          "application/json");
      return;
    }

@@ -305,8 +303,8 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
 
    // Set up chunked content provider for SSE
    res.set_chunked_content_provider(
-        "text/event-stream",
-        [&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) {
+        "text/event-stream", [&state, request, prompt, session_id](
+                                 size_t offset, httplib::DataSink& sink) {
          try {
            // Lock for inference
            std::lock_guard<std::mutex> lock(state.inference_mutex);

@@ -338,7 +336,8 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
 
              // Decode token
              std::string token_text;
-              state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
+              state.gemma->Tokenizer().Decode(std::vector<int>{token},
+                                              &token_text);
              accumulated_text += token_text;
 
              // Send SSE event using unified formatter

@@ -365,8 +364,7 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
            final_event["usageMetadata"] = {
                {"promptTokenCount", tokens.size()},
                {"candidatesTokenCount", session.abs_pos - tokens.size()},
-                {"totalTokenCount", session.abs_pos}
-            };
+                {"totalTokenCount", session.abs_pos}};
 
            std::string final_sse = "data: " + final_event.dump() + "\n\n";
            sink.write(final_sse.data(), final_sse.size());

@@ -377,16 +375,13 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
            // Ensure all data is sent
            sink.done();
            return false;  // End streaming
-
          } catch (const std::exception& e) {
            json error_event = {{"error", {{"message", e.what()}}}};
            std::string error_sse = "data: " + error_event.dump() + "\n\n";
            sink.write(error_sse.data(), error_sse.size());
            return false;
          }
-        }
-    );
-
+        });
  } catch (const json::exception& e) {
    res.status = 400;
    res.set_content(

@@ -398,20 +393,20 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
 }
 
 // Handle models list endpoint
-void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) {
+void HandleListModels(ServerState& state, const InferenceArgs& inference,
+                      const httplib::Request& req, httplib::Response& res) {
  json response = {
-    {"models", {{
-      {"name", "models/" + inference.model},
+      {"models",
+       {{{"name", "models/" + inference.model},
        {"version", "001"},
        {"displayName", inference.model},
        {"description", inference.model + " model running locally"},
        {"inputTokenLimit", 8192},
        {"outputTokenLimit", 8192},
-        {"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})},
+        {"supportedGenerationMethods",
+         json::array({"generateContent", "streamGenerateContent"})},
        {"temperature", 1.0},
-        {"topK", 1}
-    }}}
-  };
+        {"topK", 1}}}}};
 
  res.set_content(response.dump(), "application/json");
 }

@@ -421,37 +416,43 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const
 //   server_running = false;
 // }
 
-void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
-               const InferenceArgs& inference) {
+void RunServer(const GemmaArgs& args) {
  std::cerr << "Loading model..." << std::endl;
 
  // Initialize model
-  ThreadingContext ctx(threading);
+  ThreadingContext ctx(args.threading);
  MatMulEnv env(ctx);
 
  ServerState state;
-  state.gemma = std::make_unique<Gemma>(loader, inference, ctx);
+  state.gemma = std::make_unique<Gemma>(args, ctx);
  state.env = &env;
  state.ctx = &ctx;
 
+  const InferenceArgs& inference = args.inference;
+
  httplib::Server server;
 
  // Set up routes
-  server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) {
-    res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain");
-  });
+  server.Get(
+      "/", [&inference](const httplib::Request&, httplib::Response& res) {
+        res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" +
+                            inference.model + ":generateContent",
+                        "text/plain");
+      });
 
  // API endpoints
-  server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) {
+  server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req,
+                                                    httplib::Response& res) {
    HandleListModels(state, inference, req, res);
  });
 
  std::string model_endpoint = "/v1beta/models/" + inference.model;
-  server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) {
-    HandleGenerateContentNonStreaming(state, req, res);
-  });
+  server.Post(model_endpoint + ":generateContent",
+              [&state](const httplib::Request& req, httplib::Response& res) {
+                HandleGenerateContentNonStreaming(state, req, res);
+              });
 
-  server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) {
-    HandleGenerateContentStreaming(state, req, res);
-  });
+  server.Post(model_endpoint + ":streamGenerateContent",
+              [&state](const httplib::Request& req, httplib::Response& res) {
+                HandleGenerateContentStreaming(state, req, res);
+              });

@@ -466,12 +467,15 @@ void RunServer(const GemmaArgs& args) {
  std::cerr << "Starting API server on port " << inference.port << std::endl;
  std::cerr << "Model loaded successfully" << std::endl;
  std::cerr << "Endpoints:" << std::endl;
-  std::cerr << "  POST /v1beta/models/" << inference.model << ":generateContent" << std::endl;
-  std::cerr << "  POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl;
+  std::cerr << "  POST /v1beta/models/" << inference.model << ":generateContent"
+            << std::endl;
+  std::cerr << "  POST /v1beta/models/" << inference.model
+            << ":streamGenerateContent (SSE)" << std::endl;
  std::cerr << "  GET /v1beta/models" << std::endl;
 
  if (!server.listen("0.0.0.0", inference.port)) {
-    std::cerr << "Failed to start server on port " << inference.port << std::endl;
+    std::cerr << "Failed to start server on port " << inference.port
+              << std::endl;
  }
 
  cleanup_thread.join();

@@ -482,35 +486,27 @@ void RunServer(const GemmaArgs& args) {
 int main(int argc, char** argv) {
  gcpp::InternalInit();
 
-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::ThreadingArgs threading(argc, argv);
-  gcpp::InferenceArgs inference(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
 
  if (gcpp::HasHelp(argc, argv)) {
-    std::cerr << "\n\nAPI server for gemma.cpp\n";
-    std::cout << "========================\n\n";
-    std::cerr << "Usage: " << argv[0] << " --weights <path> --tokenizer <path> [options]\n";
-    std::cerr << "\nOptions:\n";
-    std::cerr << "  --port PORT Server port (default: 8080)\n";
-    std::cerr << "  --model MODEL Model name for endpoints (default: gemma3-4b)\n";
-    std::cerr << "\n";
-    std::cerr << "\n*Model Loading Arguments*\n\n";
-    loader.Help();
-    std::cerr << "\n*Threading Arguments*\n\n";
-    threading.Help();
-    std::cerr << "\n*Inference Arguments*\n\n";
-    inference.Help();
-    std::cerr << "\n";
+    fprintf(
+        stderr,
+        "\n\nAPI server for gemma.cpp\n"
+        "========================\n\n"
+        "  --port PORT Server port (default: 8080)\n"
+        "  --model MODEL Model name for endpoints (default: gemma3-4b)\n");
+    args.Help();
    return 0;
  }
 
-  // Arguments are now handled by InferenceArgs
+  consumed.AbortIfUnconsumed();
 
  // // Set up signal handler
 // signal(SIGINT, gcpp::HandleShutdown);
 // signal(SIGTERM, gcpp::HandleShutdown);
 
-  gcpp::RunServer(loader, threading, inference);
+  gcpp::RunServer(args);
 
  return 0;
 }
@@ -73,45 +73,38 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
   ThreadingArgs threading_args;
-  threading_args.spin = gcpp::Tristate::kFalse;
-
-  LoaderArgs loader(tokenizer_path, weights_path);
-  LogDebug("LoaderArgs created");
+  threading_args.spin = gcpp::Tristate::kFalse;
+  GemmaArgs args(LoaderArgs(tokenizer_path, weights_path), threading_args);
 
   // Initialize cached args
   LogDebug("Initializing inference args");
-  InferenceArgs inference_args;
-  inference_args.Init();
-  inference_args.max_generated_tokens = max_generated_tokens;
-  inference_args.temperature = 0.7f;
-  inference_args.top_k = 1;
-  inference_args.deterministic = false;
+  args.inference.max_generated_tokens = max_generated_tokens;
+  args.inference.temperature = 0.7f;
+  args.inference.top_k = 1;
+  args.inference.deterministic = false;
 
   ss.str("");
   ss << "Inference args initialized with max_tokens: " << max_generated_tokens
-     << ", temperature: " << inference_args.temperature
-     << ", top_k: " << inference_args.top_k << ", deterministic: "
-     << (inference_args.deterministic ? "true" : "false");
+     << ", temperature: " << args.inference.temperature
+     << ", top_k: " << args.inference.top_k << ", deterministic: "
+     << (args.inference.deterministic ? "true" : "false");
   LogDebug(ss.str().c_str());
 
-  return new GemmaContext(loader, inference_args, threading_args,
-                          max_generated_tokens);
+  return new GemmaContext(args, max_generated_tokens);
 }
 
-GemmaContext::GemmaContext(const LoaderArgs& loader,
-                           const InferenceArgs& inference_args,
-                           const ThreadingArgs& threading_args,
-                           int max_generated_tokens)
-    : inference_args(inference_args),
-      threading_args(threading_args),
-      ctx(threading_args),
+GemmaContext::GemmaContext(const GemmaArgs& args, int max_generated_tokens)
+    : args(args),
+      ctx(args.threading),
       matmul_env(ctx),
       active_conversation_name("default"),
-      model(loader, inference_args, matmul_env.ctx) {
+      model(args, matmul_env.ctx) {
   std::stringstream ss;
 
   LogDebug("Creating initial ConversationData");
   // Create the initial ConversationData object using make_shared
   active_conversation = std::make_shared<ConversationData>(
-      model.Config(), inference_args, ctx.allocator);
+      model.Config(), args.inference, ctx.allocator);
 
   LogDebug(
       "Storing initial ConversationData in conversation_cache[\"default\"]");

@@ -172,8 +165,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
   // set up runtime config
   TimingInfo timing_info = {};
   RuntimeConfig runtime_config = {.stream_token = stream_token,
-                                  .use_spinning = threading_args.spin};
-  inference_args.CopyTo(runtime_config);
+                                  .use_spinning = args.threading.spin};
+  args.inference.CopyTo(runtime_config);
   size_t prefix_end = 0;
 
   const ModelConfig& model_config = model.Config();

@@ -247,7 +240,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
                    timing_info);
 
   // prepare for next turn
-  if (!inference_args.multiturn ||
+  if (!args.inference.multiturn ||
       model_config.wrapping == PromptWrapping::PALIGEMMA) {
     // If not multiturn, or Paligemma (which handles turns differently),
     // reset the *active* conversation's position.
@@ -53,8 +53,7 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
 
 class GemmaContext {
  private:
-  GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
-               const ThreadingArgs& threading_args, int max_generated_tokens);
+  GemmaContext(const GemmaArgs& args, int max_generated_tokens);
 
  public:
  static GemmaContext* Create(const char* tokenizer_path,

@@ -81,37 +80,37 @@ class GemmaContext {
 
  // Set max generated tokens
  void SetMaxGeneratedTokens(size_t value) {
-    inference_args.max_generated_tokens = value;
+    args.inference.max_generated_tokens = value;
    LogDebug("Setting max_generated_tokens to configured value");
  }
 
  // Set multiturn flag (0 = disabled, 1 = enabled)
  void SetMultiturn(int value) {
-    inference_args.multiturn = value;
+    args.inference.multiturn = value;
    LogDebug("Setting multiturn to configured value");
  }
 
  // Set temperature for token generation
  void SetTemperature(float value) {
-    inference_args.temperature = value;
+    args.inference.temperature = value;
    LogDebug("Setting temperature to configured value");
  }
 
  // Set top_k parameter for sampling
  void SetTopK(int value) {
-    inference_args.top_k = value;
+    args.inference.top_k = value;
    LogDebug("Setting top_k to configured value");
  }
 
  // Set deterministic flag
  void SetDeterministic(bool value) {
-    inference_args.deterministic = value;
+    args.inference.deterministic = value;
    LogDebug("Setting deterministic flag to configured value");
  }
 
  // Set prefill_tbatch_size
  void SetPrefillTbatchSize(size_t value) {
-    inference_args.prefill_tbatch_size = value;
+    args.inference.prefill_tbatch_size = value;
    LogDebug("Setting prefill_tbatch_size to configured value");
  }

@@ -175,7 +174,7 @@ class GemmaContext {
      active_conversation->abs_pos = 0;
      // Replace the cache within the current ConversationData object
      active_conversation->kv_cache = std::make_unique<KVCache>(
-          model.Config(), inference_args, ctx.allocator);
+          model.Config(), args.inference, ctx.allocator);
 
      LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
    } else {

@@ -193,7 +192,7 @@ class GemmaContext {
    LogDebug("Creating new conversation");
    // Create a new ConversationData object using make_shared
    conversation_cache[name] = std::make_shared<ConversationData>(
-        model.Config(), inference_args, ctx.allocator);
+        model.Config(), args.inference, ctx.allocator);
    return true;
  }

@@ -274,8 +273,7 @@ class GemmaContext {
  std::vector<int> token_buffer;
 
  // Cached args (remain global for the context)
-  InferenceArgs inference_args;
-  ThreadingArgs threading_args;
+  GemmaArgs args;
  ThreadingContext ctx;
  MatMulEnv matmul_env;
|
|||
|
|
@ -738,17 +738,16 @@ HWY_EXPORT(GenerateSingleT);
|
|||
HWY_EXPORT(GenerateBatchT);
|
||||
HWY_EXPORT(GenerateImageTokensT);
|
||||
|
||||
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
ThreadingContext& ctx)
|
||||
: reader_(loader.weights),
|
||||
model_(reader_, loader.tokenizer, loader.wrapping),
|
||||
Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
|
||||
: reader_(args.loader.weights),
|
||||
model_(reader_, args.loader.tokenizer, args.loader.wrapping),
|
||||
weights_(model_.Config()),
|
||||
chat_template_(model_.Tokenizer(), model_.Config().model),
|
||||
inference_(inference),
|
||||
aes_ctr_engine_(inference.deterministic) {
|
||||
inference_(args.inference),
|
||||
aes_ctr_engine_(args.inference.deterministic) {
|
||||
// Negligible CPU time in the ctor body (except ReadFromBlobs).
|
||||
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
|
||||
mat_owners_, ctx);
|
||||
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
|
||||
args.inference, mat_owners_, ctx);
|
||||
// Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
|
||||
reader_.CloseFile();
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -130,11 +130,16 @@ struct TimingInfo {
 // separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
 class Gemma {
  public:
-  // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
+  // Reads weights/config/tokenizer from `BlobStore` at `args.loader.weights`.
   // `ctx` is only used to read tensors and not stored. Calls to `Generate*`
   // may reference the same, or other `ThreadingContext` via `MatMulEnv`.
+  Gemma(const GemmaArgs& args, ThreadingContext& ctx);
+
+  // Deprecated prior interface for backwards compatibility.
   Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-        ThreadingContext& ctx);
+        ThreadingContext& ctx)
+      : Gemma(GemmaArgs(loader, ThreadingArgs(), inference), ctx) {}
 
   ~Gemma();
 
   const ModelConfig& Config() const { return model_.Config(); }
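Note the backward-compatibility shim in the gemma.h hunk above: existing embedders keep compiling because the old three-argument constructor now forwards to the new one. A hedged sketch of the two equivalent call sites, assuming a LoaderArgs `loader`, an InferenceArgs `inference`, and a ThreadingContext `ctx` are already set up:

    // New interface: one GemmaArgs bundle (default ThreadingArgs here).
    gcpp::Gemma gemma_new(gcpp::GemmaArgs(loader, gcpp::ThreadingArgs(), inference), ctx);

    // Deprecated interface: forwards to the constructor above.
    gcpp::Gemma gemma_old(loader, inference, ctx);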
@@ -26,9 +26,10 @@

#include "gemma/configs.h"
#include "io/io.h"  // Path
#include "util/args.h"
#include "util/args.h"  // IWYU pragma: export
#include "util/basics.h"  // Tristate
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"  // Span
#include "hwy/base.h"  // HWY_ABORT
#include "hwy/profiler.h"

@@ -36,7 +37,9 @@
namespace gcpp {

struct LoaderArgs : public ArgsBase<LoaderArgs> {
  LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  LoaderArgs(int argc, char* argv[], ConsumedArgs& consumed) {
    InitAndParse(argc, argv, consumed);
  }
  LoaderArgs(const std::string& tokenizer_path,
             const std::string& weights_path) {
    Init();  // Init sets to defaults, so assignments must come after Init().

@@ -169,7 +172,9 @@ struct RuntimeConfig {
};

struct InferenceArgs : public ArgsBase<InferenceArgs> {
  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  InferenceArgs(int argc, char* argv[], ConsumedArgs& consumed) {
    InitAndParse(argc, argv, consumed);
  }
  InferenceArgs() { Init(); };

  bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }

@@ -275,33 +280,35 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
  }
};

struct ClientArgs : public ArgsBase<ClientArgs> {
  ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  ClientArgs() { Init(); };
// Bundles all args required to construct a `GemmaEnv` or the equivalent.
struct GemmaArgs {
  // For callers that do not parse command line args.
  GemmaArgs(const LoaderArgs& loader,
            const ThreadingArgs& threading = ThreadingArgs(),
            const InferenceArgs& inference = InferenceArgs())
      : loader(loader), threading(threading), inference(inference) {}

  std::string host;
  int port;
  std::string api_key;
  std::string model;
  std::string prompt;
  bool interactive;
  GemmaArgs(int argc, char** argv, ConsumedArgs& consumed)
      : loader(argc, argv, consumed),
        threading(argc, argv, consumed),
        inference(argc, argv, consumed) {}

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(host, "host", std::string("localhost"),
            "Server host (default: localhost)");
    visitor(port, "port", 8080,
            "Server port (default: 8080)");
    visitor(api_key, "api_key", std::string(""),
            "Use public API with key (changes host to "
            "generativelanguage.googleapis.com:443)");
    visitor(model, "model", std::string("gemma3-4b"),
            "Model name to use (default: gemma3-4b)");
    visitor(prompt, "prompt", std::string("Hello! How are you?"),
            "Prompt for generation (default: 'Hello! How are you?')");
    visitor(interactive, "interactive", false,
            "Start interactive chat mode (0 = no, 1 = yes)");
  void Help() {
    fprintf(stderr,
            "To run with pre-2025 weights, specify --tokenizer and --weights.\n"
            "With the single-file weights format, specify just --weights.\n"
            "\n*Model Loading Arguments*\n");
    loader.Help();
    fprintf(stderr, "\n*Threading Arguments*\n");
    threading.Help();
    fprintf(stderr, "\n*Inference Arguments*\n");
    inference.Help();
    fprintf(stderr, "\n");
  }

  LoaderArgs loader;
  ThreadingArgs threading;
  InferenceArgs inference;
};

}  // namespace gcpp
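
For embedders that never touch argc/argv, the first GemmaArgs constructor takes pre-built pieces. A sketch (paths are placeholders; max_lps and max_generated_tokens are the same fields the Python binding later in this diff assigns):

  gcpp::ThreadingArgs threading;
  threading.max_lps = 8;
  gcpp::InferenceArgs inference;
  inference.max_generated_tokens = 256;
  const gcpp::GemmaArgs args(gcpp::LoaderArgs("tokenizer.spm", "weights.sbs"),
                             threading, inference);
  // args.Help() prints the combined Loader/Threading/Inference help text.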
@@ -0,0 +1,74 @@
#include "gemma/gemma_args.h"

#include <stddef.h>

#include <string>
#include <vector>

#include "gtest/gtest.h"

namespace gcpp {

void FillPtrs(const std::vector<std::string>& args, std::vector<char*>& ptrs) {
  ptrs.reserve(args.size());
  for (const std::string& arg : args) {
    ptrs.push_back(const_cast<char*>(arg.data()));
  }
}

static void CheckAllConsumed(const std::vector<std::string>& args) {
  std::vector<char*> ptrs;
  FillPtrs(args, ptrs);
  const int argc = static_cast<int>(args.size());
  char** argv = const_cast<char**>(ptrs.data());

  ConsumedArgs consumed(argc, argv);
  GemmaArgs gemma_args(argc, argv, consumed);
  consumed.AbortIfUnconsumed();
}

static void CheckUnconsumed(const std::vector<std::string>& args,
                            size_t expected) {
  std::vector<char*> ptrs;
  FillPtrs(args, ptrs);
  const int argc = static_cast<int>(args.size());
  char** argv = const_cast<char**>(ptrs.data());

  ConsumedArgs consumed(argc, argv);
  GemmaArgs gemma_args(argc, argv, consumed);
  ASSERT_EQ(expected, consumed.FirstUnconsumed());
}

// Note: do not use --help here; it is not marked as consumed because it is
// special-cased in `HasHelp`.
TEST(GemmaArgsTest, AllConsumedArgs) {
  // Single arg
  CheckAllConsumed({"gemma", "--weights=x"});
  // Two args, one with =
  CheckAllConsumed({"gemma", "--weights=x", "--verbosity=1"});
  // Two args, one with extra value
  CheckAllConsumed({"gemma", "--weights=x", "--verbosity", "2"});
  // Two args with values
  CheckAllConsumed({"gemma", "--verbosity", "2", "--deterministic=true"});
}

TEST(GemmaArgsTest, UnconsumedArgs) {
  // Single unconsumed arg
  CheckUnconsumed({"gemma", "--UNDEFINED"}, 1);
  // Single unconsumed arg, no --
  CheckUnconsumed({"gemma", "UNDEFINED"}, 1);
  // Single unconsumed arg after valid arg
  CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED"}, 2);
  // Single unconsumed arg before valid arg
  CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x"}, 1);
  // Single unconsumed arg with = after valid arg
  CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED=1"}, 2);
  // Single unconsumed arg with = before valid arg
  CheckUnconsumed({"gemma", "--UNDEFINED=false", "--weights=x"}, 1);
  // Multiple unconsumed args
  CheckUnconsumed({"gemma", "--UNDEFINED", "--XXX"}, 1);
  // Multiple unconsumed args with valid arg between
  CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x", "--XXX"}, 1);
}

}  // namespace gcpp
44 gemma/run.cc

@@ -89,9 +89,11 @@ std::string GetPrompt(const InferenceArgs& inference) {
}

// The main Read-Eval-Print Loop.
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
               const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
               MatMulEnv& env) {
  PROFILER_ZONE("Gen.misc");
  const InferenceArgs& inference = args.inference;
  const int verbosity = inference.verbosity;
  size_t abs_pos = 0;                     // across turns
  size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
  size_t prompt_size = 0;

@@ -113,12 +115,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
    HWY_ASSERT(image.ReadPPM(inference.image_file.path));
    const size_t image_size = config.vit_config.image_size;
    image.Resize(image_size, image_size);
    RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                    .use_spinning = threading.spin};
    RuntimeConfig runtime_config = {.verbosity = verbosity,
                                    .use_spinning = args.threading.spin};
    double image_tokens_start = hwy::platform::Now();
    gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
                              image_tokens, env);
    if (inference.verbosity >= 1) {
    if (verbosity >= 1) {
      double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
      fprintf(stderr,
              "\n\n[ Timing info ] Image token generation took: %d ms\n",

@@ -189,7 +191,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
  TimingInfo timing_info = {.verbosity = inference.verbosity};
  RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                  .batch_stream_token = batch_stream_token,
                                  .use_spinning = threading.spin};
                                  .use_spinning = args.threading.spin};
  inference.CopyTo(runtime_config);
  std::vector<int> prompt;
  size_t prefix_end = 0;

@@ -252,14 +254,14 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
  }
}

void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
         const InferenceArgs& inference) {
void Run(const GemmaArgs& args) {
  PROFILER_ZONE("Run.misc");

  ThreadingContext ctx(threading);
  ThreadingContext ctx(args.threading);
  MatMulEnv env(ctx);
  const InferenceArgs& inference = args.inference;
  if (inference.verbosity >= 3) env.print_best = true;
  const Gemma gemma(loader, inference, ctx);
  const Gemma gemma(args, ctx);
  KVCache kv_cache(gemma.Config(), inference, ctx.allocator);

  if (inference.verbosity >= 1) {

@@ -287,13 +289,12 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
    if (inference.IsInteractive()) {
      std::cout << "\033[2J\033[1;1H"  // clear screen
                << kAsciiArtBanner << "\n\n";
      ShowConfig(loader, threading, inference, gemma.Config(),
                 gemma.WeightReadMode(), ctx);
      ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
      std::cout << "\n" << instructions << "\n";
    }
  }

  ReplGemma(threading, inference, gemma, kv_cache, env);
  ReplGemma(args, gemma, kv_cache, env);
}

}  // namespace gcpp

@@ -302,17 +303,24 @@ int main(int argc, char** argv) {
  gcpp::InternalInit();
  {
    // Negligible CPU time.
    gcpp::LoaderArgs loader(argc, argv);
    gcpp::ThreadingArgs threading(argc, argv);
    gcpp::InferenceArgs inference(argc, argv);
    gcpp::ConsumedArgs consumed(argc, argv);
    gcpp::GemmaArgs args(argc, argv, consumed);

    if (gcpp::HasHelp(argc, argv)) {
      std::cerr << gcpp::kAsciiArtBanner;
      gcpp::ShowHelp(loader, threading, inference);
      fprintf(stderr,
              "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
              "==========================================================\n\n"
              "*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
              "--weights gemma2-2b-it-sfp.sbs\n\n");
      args.Help();
      return 0;
    }

    gcpp::Run(loader, threading, inference);
    // After `HasHelp` so that we print --help even if unconsumed args remain.
    consumed.AbortIfUnconsumed();

    gcpp::Run(args);
  }
  PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
  return 0;
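
The user-visible effect of this wiring: a mistyped flag such as --weight instead of --weights no longer parses as zero flags and proceeds silently. Given the HWY_ABORT format in util/args.h below, `./gemma --weight=w.sbs` should abort with something like "Unrecognized arg 1: --weight=w.sbs" (modulo HWY_ABORT's own file/line prefix).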
@@ -23,7 +23,9 @@ namespace gcpp {
namespace {

struct WriterArgs : public ArgsBase<WriterArgs> {
  WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  WriterArgs(int argc, char* argv[], ConsumedArgs& consumed) {
    InitAndParse(argc, argv, consumed);
  }

  Path output_weights;

@@ -38,12 +40,15 @@ struct WriterArgs : public ArgsBase<WriterArgs> {
}  // namespace gcpp

int main(int argc, char** argv) {
  gcpp::WriterArgs args(argc, argv);
  if (args.output_weights.Empty()) {
  gcpp::ConsumedArgs consumed(argc, argv);
  gcpp::GemmaArgs args(argc, argv, consumed);
  gcpp::WriterArgs writer_args(argc, argv, consumed);
  if (writer_args.output_weights.Empty()) {
    HWY_ABORT("Missing --output_weights flag, a file for the model weights.");
  }
  consumed.AbortIfUnconsumed();

  gcpp::GemmaEnv env(argc, argv);
  env.GetGemma()->Save(args.output_weights, env.Env().ctx);
  gcpp::GemmaEnv env(args);
  env.GetGemma()->Save(writer_args.output_weights, env.Env().ctx);
  return 0;
}

@@ -21,10 +21,9 @@
#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "io/io.h"
#include "paligemma/paligemma_helper.h"
#include "util/allocator.h"
#include "hwy/tests/hwy_gtest.h"
#include "paligemma/paligemma_helper.h"

// This test can be run manually with the downloaded PaliGemma weights.
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.

@@ -73,7 +72,11 @@ TEST_F(PaliGemmaTest, QueryObjects) {
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);

  gcpp::GemmaEnv env(argc, argv);
  gcpp::ConsumedArgs consumed(argc, argv);
  gcpp::GemmaArgs args(argc, argv, consumed);
  consumed.AbortIfUnconsumed();

  gcpp::GemmaEnv env(args);
  gcpp::s_env = &env;

  return RUN_ALL_TESTS();

@@ -45,10 +45,7 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
// Wrapper around GemmaEnv to expose to Python.
class GemmaModel {
 public:
  GemmaModel(const gcpp::LoaderArgs& loader,
             const gcpp::ThreadingArgs& threading,
             const gcpp::InferenceArgs& inference)
      : env_(loader, threading, inference), last_prob_(0.0f) {}
  GemmaModel(const gcpp::GemmaArgs& args) : env_(args), last_prob_(0.0f) {}

  // Generates a single example, given a prompt and a callback to stream the
  // generated tokens.

@@ -254,13 +251,15 @@ PYBIND11_MODULE(gemma, mod) {
  py::class_<GemmaModel>(mod, "GemmaModel")
      .def(py::init([](const std::string& tokenizer, const std::string& weights,
                       size_t max_threads) {
        const gcpp::LoaderArgs loader(tokenizer, weights);
        gcpp::ThreadingArgs threading;
        threading.max_lps = max_threads;

        gcpp::InferenceArgs inference;
        inference.max_generated_tokens = 512;
        auto gemma =
            std::make_unique<GemmaModel>(loader, threading, inference);

        const gcpp::GemmaArgs args(gcpp::LoaderArgs(tokenizer, weights),
                                   threading, inference);
        auto gemma = std::make_unique<GemmaModel>(args);
        if (!gemma->ModelIsLoaded()) {
          throw std::invalid_argument("Could not load model.");
        }

70 util/args.h
@@ -22,6 +22,7 @@

#include <algorithm>  // std::transform
#include <string>
#include <vector>

#include "io/io.h"  // Path
#include "util/basics.h"  // Tristate

@@ -29,6 +30,56 @@

namespace gcpp {

// For checking which args were not matched/consumed. Passed to each `*Args`
// ctor that parses argc/argv to ensure that their args are tracked, without
// requiring global state.
class ConsumedArgs {
 public:
  ConsumedArgs(int argc, char** argv) : argv_(argv), consumed_(argc) {
    // We assume argc >= 1, because argv[0] is the binary name. That allows us
    // to signal "called AbortIfUnconsumed" with an empty vector.
    HWY_ASSERT(!consumed_.empty());
  }

  ~ConsumedArgs() {
    if (HWY_UNLIKELY(!consumed_.empty())) {
      HWY_ABORT("AbortIfUnconsumed was not called.");
    }
  }

  void NotifyConsumed(size_t idx) {
    HWY_ASSERT(idx < consumed_.size());
    HWY_ASSERT(consumed_[idx] == 0);
    consumed_[idx] = 1;
  }

  // Returns the index of the first unconsumed arg, or 0 if none. Also disarms
  // the dtor check that verifies this or `AbortIfUnconsumed` was called.
  size_t FirstUnconsumed() {
    // Ignore argv[0], which is the binary name.
    for (size_t i = 1; i < consumed_.size(); ++i) {
      if (HWY_UNLIKELY(consumed_[i] == 0)) {
        consumed_.clear();
        return i;
      }
    }

    consumed_.clear();
    return 0;
  }

  void AbortIfUnconsumed() {
    const size_t i = FirstUnconsumed();
    if (HWY_UNLIKELY(i != 0)) {
      HWY_ABORT("Unrecognized arg %zu: %s\n", i, argv_[i]);
    }
  }

 private:
  char** argv_;
  std::vector<uint8_t> consumed_;
};
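
Taken on its own, the contract is: mark every matched argv index via NotifyConsumed, then call FirstUnconsumed or AbortIfUnconsumed exactly once, which also disarms the destructor check. A self-contained sketch of that lifecycle:

#include <cassert>

#include "util/args.h"

int main(int, char**) {
  char arg0[] = "tool";
  char arg1[] = "--flag=1";
  char* argv[] = {arg0, arg1};
  gcpp::ConsumedArgs consumed(/*argc=*/2, argv);
  consumed.NotifyConsumed(1);               // as the parser does after a match
  assert(consumed.FirstUnconsumed() == 0);  // 0 means: nothing left over
  return 0;  // dtor check already disarmed by FirstUnconsumed
}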
// Args is a class that provides a ForEach member function which visits each of
// its member variables. ArgsBase provides functions called by Args to
// initialize values to their defaults (passed as an argument to the visitor),

@@ -93,7 +144,8 @@ class ArgsBase {
  // consider adding a hash-map to speed this up.
  class ParseVisitor {
   public:
    ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {}
    ParseVisitor(int argc, char* argv[], ConsumedArgs& consumed)
        : argc_(argc), argv_(argv), consumed_(consumed) {}

    template <typename T>
    void operator()(T& t, const char* name, const T& /*init*/,

@@ -108,6 +160,8 @@ class ArgsBase {
        if (!SetValue(argv_[i + 1], t)) {
          HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
        }
        consumed_.NotifyConsumed(i);
        consumed_.NotifyConsumed(i + 1);
        return;
      }
      if (std::string(argv_[i]).find(prefixed_eq) == 0) {

@@ -115,6 +169,7 @@ class ArgsBase {
        if (!SetValue(value, t)) {
          HWY_ABORT("Invalid value for %s, got %s\n", name, value);
        }
        consumed_.NotifyConsumed(i);
        return;
      }
    }

@@ -181,8 +236,9 @@ class ArgsBase {
      }
    }

    int argc_;
    char** argv_;
    const int argc_;
    char** const argv_;
    ConsumedArgs& consumed_;
  };  // ParseVisitor

  template <class Visitor>

@@ -211,15 +267,15 @@ class ArgsBase {
    ForEach(visitor);
  }

  void Parse(int argc, char* argv[]) {
    ParseVisitor visitor(argc, argv);
  void Parse(int argc, char* argv[], ConsumedArgs& consumed) {
    ParseVisitor visitor(argc, argv, consumed);
    ForEach(visitor);
  }

  // For convenience, enables single-line constructor.
  void InitAndParse(int argc, char* argv[]) {
  void InitAndParse(int argc, char* argv[], ConsumedArgs& consumed) {
    Init();
    Parse(argc, argv);
    Parse(argc, argv, consumed);
  }
};
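
Note the consumption accounting above: the space-separated form --flag value marks two indices, while --flag=value marks one, which is how {"gemma", "--verbosity", "2", "--deterministic=true"} passes CheckAllConsumed in the test. Any ArgsBase subclass inherits this behavior; a sketch with a hypothetical MyArgs (the visitor signature follows the ForEach pattern shown earlier):

struct MyArgs : public gcpp::ArgsBase<MyArgs> {
  MyArgs(int argc, char* argv[], gcpp::ConsumedArgs& consumed) {
    InitAndParse(argc, argv, consumed);
  }
  int verbosity;
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(verbosity, "verbosity", 1, "Verbosity level.");
  }
};

// argv: 0="tool" 1="--verbosity" 2="2" 3="--oops"
// Parsing MyArgs consumes indices 1 and 2; index 3 is never matched, so
// consumed.AbortIfUnconsumed() aborts with "Unrecognized arg 3: --oops".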
@@ -38,7 +38,9 @@ namespace gcpp {
// Optional arguments for `ThreadingContext` from the command line.
class ThreadingArgs : public ArgsBase<ThreadingArgs> {
 public:
  ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  ThreadingArgs(int argc, char* argv[], ConsumedArgs& consumed) {
    InitAndParse(argc, argv, consumed);
  }
  ThreadingArgs() { Init(); };

  // For BoundedTopology: