mirror of https://github.com/google/gemma.cpp.git
Abort if args are unrecognized, refactor argument passing
This catches typos/incorrect usage.

Refactor: group Loader/Threading/Inference into GemmaArgs. All *Args ctors now have an extra ConsumedArgs& argument.

PiperOrigin-RevId: 844690553
parent f50550f4ce
commit 0c64987a96
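The calling convention this commit introduces, condensed from the diffs below into one minimal sketch. ConsumedArgs, GemmaArgs, HasHelp, and AbortIfUnconsumed are the names used in the diff; the exact includes and the surrounding main() are illustrative, not a file from the repo:

    #include "gemma/gemma_args.h"  // GemmaArgs, ConsumedArgs
    #include "util/args.h"         // HasHelp (assumed to live here)

    int main(int argc, char** argv) {
      // Records which argv entries are recognized by any *Args parser.
      gcpp::ConsumedArgs consumed(argc, argv);
      // One bundle replacing the former LoaderArgs/ThreadingArgs/InferenceArgs trio.
      gcpp::GemmaArgs args(argc, argv, consumed);
      if (gcpp::HasHelp(argc, argv)) {
        args.Help();
        return 0;
      }
      // Aborts with a diagnostic if any flag was not consumed; this is what
      // catches typos and incorrect usage.
      consumed.AbortIfUnconsumed();

      // Downstream objects then take the bundle or its members, as in the diffs:
      //   gcpp::ThreadingContext ctx(args.threading);
      //   gcpp::Gemma gemma(args, ctx);
      //   gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator);
      return 0;
    }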
BUILD.bazel (11 changed lines)
@@ -556,12 +556,22 @@ cc_library(
         ":basics",
         ":configs",
         ":mat",
+        ":threading_context",
         "//io",
         "@highway//:hwy",
         "@highway//:profiler",
     ],
 )
 
+cc_test(
+    name = "gemma_args_test",
+    srcs = ["gemma/gemma_args_test.cc"],
+    deps = [
+        ":gemma_args",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+    ],
+)
+
 cc_library(
     name = "gemma_lib",
     srcs = [
@@ -666,7 +676,6 @@ cc_library(
         ":gemma_args",
         ":gemma_lib",
         ":matmul_env",
-        ":ops",
         ":threading_context",
         ":tokenizer",
         "@google_benchmark//:benchmark",
@@ -222,6 +222,7 @@ set(GEMMA_TEST_FILES
   compression/nuq_test.cc
   compression/sfp_test.cc
   evals/gemma_test.cc
+  gemma/gemma_args_test.cc
   gemma/flash_attention_test.cc
   gemma/tensor_info_test.cc
   io/blob_store_test.cc
@@ -101,13 +101,11 @@ cc_test(
     # for test_suite.
     tags = ["hwy_ops_test"],
     deps = [
-        ":distortion",
         ":int",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//:test_util",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
-        "@highway//:nanobenchmark",
     ],
 )
 
@@ -23,7 +23,9 @@ using json = nlohmann::json;
 
 class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
  public:
-  BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  BenchmarkArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
 
   Path summarize_text;
   Path cross_entropy;
@@ -127,9 +129,16 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
 }  // namespace gcpp
 
 int main(int argc, char** argv) {
-  gcpp::GemmaEnv env(argc, argv);
-  gcpp::BenchmarkArgs benchmark_args(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  gcpp::BenchmarkArgs benchmark_args(argc, argv, consumed);
+  if (gcpp::HasHelp(argc, argv)) {
+    args.Help();
+    return 0;
+  }
+  consumed.AbortIfUnconsumed();
 
+  gcpp::GemmaEnv env(args);
   if (!benchmark_args.summarize_text.Empty()) {
     return BenchmarkSummary(env, benchmark_args.summarize_text);
   } else if (!benchmark_args.cross_entropy.Empty()) {
@@ -36,35 +36,29 @@
 
 namespace gcpp {
 
-GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
-                   const InferenceArgs& inference)
+GemmaEnv::GemmaEnv(const GemmaArgs& args)
     : initializer_value_(gcpp::InternalInit()),
-      ctx_(threading),
+      ctx_(args.threading),
       env_(ctx_),
-      gemma_(loader, inference, ctx_) {
+      gemma_(args, ctx_) {
   const ModelConfig& config = gemma_.Config();
   // Only allocate one for starters because GenerateBatch might not be called.
-  kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
+  kv_caches_.push_back(KVCache(config, args.inference, ctx_.allocator));
 
-  if (inference.verbosity >= 2) {
-    ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
-               ctx_);
+  if (args.inference.verbosity >= 2) {
+    ShowConfig(args, config, gemma_.WeightReadMode(), ctx_);
   }
-  if (inference.verbosity >= 3) env_.print_best = true;
-  if (inference.verbosity >= 4) env_.print_config = true;
+  if (args.inference.verbosity >= 3) env_.print_best = true;
+  if (args.inference.verbosity >= 4) env_.print_config = true;
 
   runtime_config_ = {
-      .max_generated_tokens = inference.max_generated_tokens,
-      .temperature = inference.temperature,
-      .verbosity = inference.verbosity,
+      .max_generated_tokens = args.inference.max_generated_tokens,
+      .temperature = args.inference.temperature,
+      .verbosity = args.inference.verbosity,
   };
-  inference.CopyTo(runtime_config_);
+  args.inference.CopyTo(runtime_config_);
 }
 
-GemmaEnv::GemmaEnv(int argc, char** argv)
-    : GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
-               InferenceArgs(argc, argv)) {}
-
 QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
   QueryResult result;
 
@@ -234,19 +228,19 @@ static constexpr const char* CompiledConfig() {
   }
 }
 
-void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config,
+void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
                 const WeightsPtrs::Mode weight_read_mode,
                 const ThreadingContext& ctx) {
-  threading.Print(inference.verbosity);
-  loader.Print(inference.verbosity);
-  inference.Print(inference.verbosity);
-  fprintf(
-      stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n",
-      config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
-      static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode));
+  args.threading.Print(args.inference.verbosity);
+  args.loader.Print(args.inference.verbosity);
+  args.inference.Print(args.inference.verbosity);
+  fprintf(stderr,
+          "Model : %s, to_bf16 %d, mmap %d => %s\n",
+          config.Specifier().c_str(), static_cast<int>(args.loader.to_bf16),
+          static_cast<int>(args.loader.map),
+          WeightsPtrs::ToString(weight_read_mode));
 
-  if (inference.verbosity >= 2) {
+  if (args.inference.verbosity >= 2) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
     char cpu100[100] = "unknown";
@@ -259,7 +253,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
             "Instruction set : %s (%zu bits)\n"
             "Compiled config : %s, profiler %d\n"
             "Memory MiB : %4zu\n",
-            dt, cpu100, static_cast<int>(threading.bind),
+            dt, cpu100, static_cast<int>(args.threading.bind),
             ctx.topology.TopologyString(), ctx.pools.PinString(),
             CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
             ctx.cache_info.VectorBytes() * 8, CompiledConfig(),
@@ -267,22 +261,4 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
   }
 }
 
-void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
-              const InferenceArgs& inference) {
-  std::cerr
-      << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
-         "==========================================================\n\n"
-         "To run with pre-2025 weights, specify --tokenizer and --weights.\n"
-         "With the single-file weights format, specify just --weights.\n";
-  std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
-               "--weights gemma2-2b-it-sfp.sbs\n";
-  std::cerr << "\n*Model Loading Arguments*\n\n";
-  loader.Help();
-  std::cerr << "\n*Threading Arguments*\n\n";
-  threading.Help();
-  std::cerr << "\n*Inference Arguments*\n\n";
-  inference.Help();
-  std::cerr << "\n";
-}
-
 }  // namespace gcpp
 
@@ -23,7 +23,7 @@
 
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
-#include "gemma/gemma_args.h"
+#include "gemma/gemma_args.h"  // IWYU pragma: export
 #include "gemma/tokenizer.h"  // WrapAndTokenize
 #include "ops/matmul.h"
 #include "util/threading_context.h"
@@ -50,10 +50,8 @@ struct QueryResultAndMetrics {
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
-  // Calls the other constructor with *Args arguments initialized from argv.
-  GemmaEnv(int argc, char** argv);
-  GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
-           const InferenceArgs& inference);
+  explicit GemmaEnv(const GemmaArgs& args);
   MatMulEnv& Env() { return env_; }
 
   size_t MaxGeneratedTokens() const {
@@ -137,12 +135,9 @@ class GemmaEnv {
 // Logs the inference speed in tokens/sec.
 void LogSpeedStats(double time_start, size_t total_tokens);
 
-void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config,
+void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
                 WeightsPtrs::Mode weight_read_mode,
                 const ThreadingContext& ctx);
-void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
-              const InferenceArgs& inference);
 
 }  // namespace gcpp
 
@@ -98,7 +98,11 @@ BENCHMARK(BM_coding_prompt)
     ->UseRealTime();
 
 int main(int argc, char** argv) {
-  gcpp::GemmaEnv env(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
   env.SetMaxGeneratedTokens(256);
   gcpp::s_env = &env;
 
@@ -31,7 +31,9 @@ namespace gcpp {
 
 class PromptArgs : public ArgsBase<PromptArgs> {
  public:
-  PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  PromptArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
 
   Path layers_output;  // optional
   std::string prompt;
@@ -51,11 +53,15 @@ class PromptArgs : public ArgsBase<PromptArgs> {
 };
 
 int Run(int argc, char** argv) {
-  PromptArgs prompt_args(argc, argv);
+  ConsumedArgs consumed(argc, argv);
+  const GemmaArgs args(argc, argv, consumed);
+  const PromptArgs prompt_args(argc, argv, consumed);
   AbortIfInvalidArgs(prompt_args);
+  consumed.AbortIfUnconsumed();
 
   json json_output;
-  GemmaEnv env(argc, argv);
+  GemmaEnv env(args);
+
   env.MutableConfig().layers_output =
       prompt_args.layers_output.Empty()
           ? LayersOutputFunc()
@@ -146,7 +146,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
 
 int main(int argc, char** argv) {
   fprintf(stderr, "GemmaEnv setup..\n");
-  gcpp::GemmaEnv env(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
   gcpp::s_env = &env;
 
   testing::InitGoogleTest(&argc, argv);
@@ -22,7 +22,6 @@
 
 #include "evals/benchmark_helper.h"
 #include "gemma/configs.h"
-#include "io/io.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
 
@@ -42,7 +41,11 @@ class GemmaTest : public ::testing::Test {
   // Requires argc/argv, hence do not use `SetUpTestSuite`.
   static void InitEnv(int argc, char** argv) {
     HWY_ASSERT(s_env == nullptr);  // Should only be called once.
-    s_env = new GemmaEnv(argc, argv);
+    ConsumedArgs consumed(argc, argv);
+    GemmaArgs args(argc, argv, consumed);
+    consumed.AbortIfUnconsumed();
+
+    s_env = new GemmaEnv(args);
     const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
     fprintf(stderr, "Using %s\n", config.Specifier().c_str());
   }
@@ -31,7 +31,9 @@
 namespace gcpp {
 
 struct JsonArgs : public ArgsBase<JsonArgs> {
-  JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  JsonArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
 
   Path input;
 
@@ -151,10 +153,14 @@ void Run(GemmaEnv& env, JsonArgs& json) {
 int main(int argc, char** argv) {
   {
     PROFILER_ZONE("Startup.all");
-    gcpp::GemmaEnv env(argc, argv);
-    gcpp::JsonArgs json(argc, argv);
-    gcpp::AbortIfInvalidArgs(json);
-    gcpp::Run(env, json);
+    gcpp::ConsumedArgs consumed(argc, argv);
+    gcpp::GemmaArgs args(argc, argv, consumed);
+    gcpp::JsonArgs json_args(argc, argv, consumed);
+    gcpp::AbortIfInvalidArgs(json_args);
+    consumed.AbortIfUnconsumed();
+
+    gcpp::GemmaEnv env(args);
+    gcpp::Run(env, json_args);
   }
   PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
   return 0;
@@ -24,20 +24,20 @@
 #include <vector>
 
 #include "gemma/gemma.h"
-#include "gemma/gemma_args.h"  // LoaderArgs
+#include "gemma/gemma_args.h"  // GemmaArgs
 #include "gemma/tokenizer.h"
 #include "util/args.h"
 #include "util/threading_context.h"
 #include "hwy/base.h"
 
 int main(int argc, char** argv) {
-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::ThreadingArgs threading(argc, argv);
-  gcpp::InferenceArgs inference(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
   if (gcpp::HasHelp(argc, argv)) {
-    loader.Help();
+    args.Help();
     return 0;
   }
+  consumed.AbortIfUnconsumed();
 
   // Demonstrate constrained decoding by never outputting certain tokens.
   std::set<int> reject_tokens;
@@ -49,10 +49,10 @@ int main(int argc, char** argv) {
   }
 
   // Instantiate model and KV Cache
-  gcpp::ThreadingContext ctx(threading);
+  gcpp::ThreadingContext ctx(args.threading);
   gcpp::MatMulEnv env(ctx);
-  gcpp::Gemma gemma(loader, inference, ctx);
-  gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
+  gcpp::Gemma gemma(args, ctx);
+  gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator);
   size_t generated = 0;
 
   // Tokenize instructions.
@@ -23,7 +23,7 @@
 #include <vector>
 
 #include "third_party/gemma_cpp/gemma/gemma.h"
-#include "third_party/gemma_cpp/gemma/gemma_args.h"  // LoaderArgs
+#include "third_party/gemma_cpp/gemma/gemma_args.h"  // GemmaArgs
 #include "third_party/gemma_cpp/gemma/tokenizer.h"
 #include "third_party/gemma_cpp/ops/matmul.h"
 #include "third_party/gemma_cpp/util/threading_context.h"
@@ -31,18 +31,11 @@
 
 class SimplifiedGemma {
  public:
-  SimplifiedGemma(const gcpp::LoaderArgs& loader,
-                  const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
-                  const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : ctx_(threading),
+  SimplifiedGemma(const gcpp::GemmaArgs& args)
+      : ctx_(args.threading),
         env_(ctx_),
-        gemma_(loader, inference, ctx_),
-        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {}
+        gemma_(args, ctx_),
+        kv_cache_(gemma_.Config(), args.inference, ctx_.allocator) {}
 
-  SimplifiedGemma(int argc, char** argv)
-      : SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
-                        gcpp::ThreadingArgs(argc, argv),
-                        gcpp::InferenceArgs(argc, argv)) {}
-
   void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
                 float temperature = 0.7,
@@ -18,26 +18,16 @@
 #include <string>
 
 #include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
-#include "gemma/gemma_args.h"  // LoaderArgs
+#include "gemma/gemma_args.h"
 
 int main(int argc, char** argv) {
-  // Standard usage: LoaderArgs takes argc and argv as input, then parses
-  // necessary flags.
-  gcpp::LoaderArgs loader(argc, argv);
+  // Sets arguments from argc and argv. Note that you can instead pass in
+  // LoaderArgs, ThreadingArgs, and InferenceArgs directly.
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
 
-  // Optional: LoaderArgs can also take tokenizer and weights paths directly.
-  //
-  // gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
-  //                         "model_identifier");
-
-  // Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
-  // specified, default values will be used.
-  //
-  // gcpp::InferenceArgs inference(argc, argv);
-  // gcpp::ThreadingArgs threading(argc, argv);
-  // SimplifiedGemma gemma(loader, threading, inference);
-
-  SimplifiedGemma gemma(loader);
+  SimplifiedGemma gemma(args);
   std::string prompt = "Write a greeting to the world.";
   gemma.Generate(prompt, 256, 0.6);
 
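The new comment in the hunk above says LoaderArgs, ThreadingArgs, and InferenceArgs can still be passed in directly instead of being parsed from argv. A rough sketch of that path, under stated assumptions: only the members args.loader/args.threading/args.inference are confirmed by this commit; the aggregating GemmaArgs constructor is an assumption, and the LoaderArgs path constructor is taken from the comments removed above.

    // Hypothetical non-CLI setup; verify the GemmaArgs constructor signature
    // against gemma/gemma_args.h before relying on it.
    gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
                            "model_identifier");
    gcpp::ThreadingArgs threading;  // defaults
    gcpp::InferenceArgs inference;  // defaults
    gcpp::GemmaArgs args(loader, threading, inference);  // assumed ctor
    SimplifiedGemma gemma(args);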
@@ -15,18 +15,22 @@
 
 // Test client for API server
 
-#include <iostream>
-#include <string>
-#include <sstream>
+#include <stdio.h>
+
 #include <cstdlib>
 #include <cstring>
+#include <iostream>
+#include <sstream>
+#include <string>
 
 #include "httplib.h"
-#include "nlohmann/json.hpp"
 #include "gemma/gemma_args.h"
+#include "nlohmann/json.hpp"
 
 using json = nlohmann::json;
 
+namespace gcpp {
+
 // ANSI color codes
 const std::string RESET = "\033[0m";
 const std::string BOLD = "\033[1m";
@@ -37,9 +41,15 @@ const std::string YELLOW = "\033[33m";
 const std::string RED = "\033[31m";
 
 class APIClient {
  public:
-  APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b")
-      : host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) {
+  APIClient(const std::string& host, int port, const std::string& api_key = "",
+            const std::string& model = "gemma3-4b")
+      : host_(host),
+        port_(port),
+        api_key_(api_key),
+        model_(model),
+        use_https_(port == 443),
+        interactive_mode_(false) {
     if (use_https_) {
       ssl_client_ = std::make_unique<httplib::SSLClient>(host, port);
       ssl_client_->set_read_timeout(60, 0);
@@ -58,8 +68,10 @@ public:
 
     std::string endpoint;
     if (is_public_api) {
-      endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
-                        : "/v1beta/models/gemini-2.0-flash:generateContent";
+      endpoint =
+          stream
+              ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
+              : "/v1beta/models/gemini-2.0-flash:generateContent";
     } else {
       endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent"
                         : "/v1beta/models/" + model_ + ":generateContent";
@@ -67,7 +79,8 @@ public:
 
     // Only show verbose output in non-interactive mode
     if (!interactive_mode_) {
-      std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
+      std::cout << "\n"
+                << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
       std::cout << "Request: " << request.dump(2) << std::endl;
     }
 
@@ -83,18 +96,21 @@ public:
     json response = ProcessRequest(request, stream);
 
     if (response.contains("error")) {
-      std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
+      std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET
+                << std::endl;
     }
   }
 
   void TestListModels() {
-    std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
+    std::cout << "\n"
+              << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
 
     httplib::Headers headers;
     if (!api_key_.empty()) {
       headers.emplace("X-goog-api-key", api_key_);
     }
-    auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers);
+    auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers)
+                          : client_->Get("/v1beta/models", headers);
 
     if (res && res->status == 200) {
       json response = json::parse(res->body);
@@ -106,7 +122,9 @@ public:
   }
 
   void InteractiveChat() {
-    std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl;
+    std::cout << "\n"
+              << BOLD << CYAN << "💬 Interactive Chat Mode (with session)"
+              << RESET << std::endl;
     std::cout << "Type ':gemma %q' to end.\n" << std::endl;
 
     interactive_mode_ = true;
@@ -141,14 +159,16 @@ public:
 
     if (response.contains("candidates") && !response["candidates"].empty()) {
       auto& candidate = response["candidates"][0];
-      if (candidate.contains("content") && candidate["content"].contains("parts")) {
+      if (candidate.contains("content") &&
+          candidate["content"].contains("parts")) {
         for (const auto& part : candidate["content"]["parts"]) {
           if (part.contains("text")) {
             std::string assistant_response = part["text"].get<std::string>();
 
            // For streaming, the response is already displayed in real-time
            // Just add to message history for context
-            json assistant_message = {{"parts", {{{"text", assistant_response}}}}};
+            json assistant_message = {
+                {"parts", {{{"text", assistant_response}}}}};
            if (!api_key_.empty()) {
              assistant_message["role"] = "model";
            }
@@ -157,22 +177,20 @@ public:
         }
       }
     } else if (response.contains("error")) {
-      std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
+      std::cerr << RED << "❌ Error: " << response["error"]["message"]
+                << RESET << std::endl;
     }
 
     std::cout << std::endl;
   }
 }
 
 private:
-  json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) {
+  json CreateAPIRequest(const std::string& prompt,
+                        const json& messages = json::array()) {
    json request = {
-        {"generationConfig", {
-            {"temperature", 0.9},
-            {"topK", 1},
-            {"maxOutputTokens", 1024}
-        }}
-    };
+        {"generationConfig",
+         {{"temperature", 0.9}, {"topK", 1}, {"maxOutputTokens", 1024}}}};
 
    if (messages.empty()) {
      // Single prompt
@@ -189,38 +207,42 @@ private:
    return request;
  }
 
-  json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) {
+  json ProcessNonStreamingRequest(const json& request,
+                                  const std::string& endpoint) {
    httplib::Headers headers = {{"Content-Type", "application/json"}};
    if (!api_key_.empty()) {
      headers.emplace("X-goog-api-key", api_key_);
    }
 
-    auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json")
-                          : client_->Post(endpoint, headers, request.dump(), "application/json");
+    auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(),
+                                              "application/json")
+                          : client_->Post(endpoint, headers, request.dump(),
+                                          "application/json");
 
    if (res && res->status == 200) {
      json response = json::parse(res->body);
      if (!interactive_mode_) {
-        std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl;
+        std::cout << "\n"
+                  << BOLD << GREEN << "📥 Response:" << RESET << std::endl;
        std::cout << response.dump(2) << std::endl;
      }
      return response;
    } else {
-      json error_response = {
-          {"error", {
-              {"message", "Request failed"},
-              {"status", res ? res->status : -1}
-          }}
-      };
+      json error_response = {{"error",
+                              {{"message", "Request failed"},
+                               {"status", res ? res->status : -1}}}};
      if (res && !res->body.empty()) {
        error_response["error"]["details"] = res->body;
      }
-      std::cerr << RED << "❌ Request failed. Status: " << (res ? res->status : -1) << RESET << std::endl;
+      std::cerr << RED
+                << "❌ Request failed. Status: " << (res ? res->status : -1)
+                << RESET << std::endl;
      return error_response;
    }
  }
 
-  json ProcessStreamingRequest(const json& request, const std::string& endpoint) {
+  json ProcessStreamingRequest(const json& request,
+                               const std::string& endpoint) {
    std::string accumulated_response;
 
    // Use same SSE logic for both public and local APIs
 
@@ -233,71 +255,72 @@ private:
    }
    req.body = request.dump();
 
-    req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool {
+    req.content_receiver = [&accumulated_response, this](
+                               const char* data, size_t data_length,
+                               uint64_t offset, uint64_t total_length) -> bool {
      std::string chunk(data, data_length);
      std::istringstream stream(chunk);
      std::string line;
 
      while (std::getline(stream, line)) {
        if (line.substr(0, 6) == "data: ") {
          std::string event_data = line.substr(6);
 
          if (event_data == "[DONE]") {
            if (!interactive_mode_) {
-              std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl;
+              std::cout << "\n\n"
+                        << GREEN << "✅ Generation complete!" << RESET
+                        << std::endl;
            }
          } else {
            try {
              json event = json::parse(event_data);
-              if (event.contains("candidates") && !event["candidates"].empty()) {
+              if (event.contains("candidates") &&
+                  !event["candidates"].empty()) {
                auto& candidate = event["candidates"][0];
-                if (candidate.contains("content") && candidate["content"].contains("parts")) {
+                if (candidate.contains("content") &&
+                    candidate["content"].contains("parts")) {
                  for (const auto& part : candidate["content"]["parts"]) {
                    if (part.contains("text")) {
                      std::string text = part["text"].get<std::string>();
                      std::cout << text << std::flush;
                      accumulated_response += text;
                    }
                  }
                }
              }
            } catch (const json::exception& e) {
              // Skip parse errors
            }
          }
        }
      }
      return true;
    };
 
    httplib::Response res;
    httplib::Error error;
-    bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error);
+    bool success = use_https_ ? ssl_client_->send(req, res, error)
+                              : client_->send(req, res, error);
 
    if (res.status == 200 && !accumulated_response.empty()) {
      return json{
-          {"candidates", {{
-              {"content", {
-                  {"parts", {{{"text", accumulated_response}}}}
-              }}
-          }}}
-      };
+          {"candidates",
+           {{{"content", {{"parts", {{{"text", accumulated_response}}}}}}}}}};
    } else {
      json error_response = {
-          {"error", {
-              {"message", "Streaming request failed"},
-              {"status", res.status}
-          }}
-      };
+          {"error",
+           {{"message", "Streaming request failed"}, {"status", res.status}}}};
      if (!res.body.empty()) {
        error_response["error"]["details"] = res.body;
      }
-      std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl;
+      std::cerr << RED << "❌ Streaming request failed. Status: " << res.status
+                << RESET << std::endl;
      return error_response;
    }
  }
 
 private:
  std::unique_ptr<httplib::Client> client_;
  std::unique_ptr<httplib::SSLClient> ssl_client_;
  std::string host_;
@@ -308,19 +331,55 @@ private:
   bool interactive_mode_;
 };
 
+struct ClientArgs : public ArgsBase<ClientArgs> {
+  ClientArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
+  ClientArgs() { Init(); };
+
+  std::string host;
+  int port;
+  std::string api_key;
+  std::string model;
+  std::string prompt;
+  bool interactive;
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(host, "host", std::string("localhost"),
+            "Server host (default: localhost)");
+    visitor(port, "port", 8080, "Server port (default: 8080)");
+    visitor(api_key, "api_key", std::string(""),
+            "Use public API with key (changes host to "
+            "generativelanguage.googleapis.com:443)");
+    visitor(model, "model", std::string("gemma3-4b"),
+            "Model name to use (default: gemma3-4b)");
+    visitor(prompt, "prompt", std::string("Hello! How are you?"),
+            "Prompt for generation (default: 'Hello! How are you?')");
+    visitor(interactive, "interactive", false,
+            "Start interactive chat mode (0 = no, 1 = yes)");
+  }
+};
+
+}  // namespace gcpp
+
 int main(int argc, char* argv[]) {
-  gcpp::ClientArgs client_args(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::ClientArgs client_args(argc, argv, consumed);
 
   if (gcpp::HasHelp(argc, argv)) {
-    std::cout << "\nAPI Client for gemma.cpp\n";
-    std::cout << "========================\n\n";
+    fprintf(stderr,
+            "\nAPI Client for gemma.cpp\n"
+            "========================\n\n");
     client_args.Help();
-    std::cout << std::endl;
-    std::cout << "Environment Variables:" << std::endl;
-    std::cout << "  GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl;
+    fprintf(stderr,
+            "\n*Environment Variables:\n"
+            "  GOOGLE_API_KEY : Automatically use public Google API if set\n");
    return 0;
  }
 
+  consumed.AbortIfUnconsumed();
+
   // Check for GOOGLE_API_KEY environment variable
   const char* env_api_key = std::getenv("GOOGLE_API_KEY");
   if (env_api_key != nullptr && strlen(env_api_key) > 0) {
@@ -335,11 +394,12 @@ int main(int argc, char* argv[]) {
     client_args.port = 443;
   }
 
-  std::cout << BOLD << YELLOW << "🚀 Testing API Server at "
-            << client_args.host << ":" << client_args.port << RESET << std::endl;
+  std::cout << BOLD << YELLOW << "🚀 Testing API Server at " << client_args.host
+            << ":" << client_args.port << RESET << std::endl;
 
   try {
-    APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model);
+    APIClient client(client_args.host, client_args.port, client_args.api_key,
+                     client_args.model);
 
     if (client_args.interactive) {
       client.InteractiveChat();
@@ -347,11 +407,12 @@ int main(int argc, char* argv[]) {
       client.TestListModels();
       client.TestGenerateContent(client_args.prompt, true);
     }
 
   } catch (const std::exception& e) {
     std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl;
     std::cerr << "Make sure the API server is running:" << std::endl;
-    std::cerr << "  ./build/gemma_api_server --tokenizer <path> --weights <path>" << std::endl;
+    std::cerr
+        << "  ./build/gemma_api_server --tokenizer <path> --weights <path>"
+        << std::endl;
     return 1;
   }
 
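The new ClientArgs struct added above illustrates the ArgsBase pattern that every *Args type in this commit follows: flags are plain members, and a single ForEach template drives parsing, defaults, and Help() output, while the ConsumedArgs& constructor parameter marks recognized flags as consumed. A minimal sketch of a custom args struct under the new convention; MyToolArgs and its fields are illustrative, not part of the repo, and it assumes the includes already present in this file:

    struct MyToolArgs : public gcpp::ArgsBase<MyToolArgs> {
      MyToolArgs(int argc, char* argv[], gcpp::ConsumedArgs& consumed) {
        InitAndParse(argc, argv, consumed);  // Marks recognized flags as consumed.
      }

      std::string output;  // --output
      size_t repeats;      // --repeats

      template <class Visitor>
      void ForEach(const Visitor& visitor) {
        // visitor(member, flag_name, default_value, help_text), as in ClientArgs.
        visitor(output, "output", std::string("out.txt"), "Output file path");
        visitor(repeats, "repeats", size_t{1}, "Number of repetitions");
      }
    };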
@@ -15,22 +15,19 @@
 
 // HTTP API server for gemma.cpp with SSE support
 
-#include <stdio.h>
 #include <signal.h>
+#include <stdio.h>
 
-#include <iostream>
-#include <memory>
-#include <random>
-#include <string>
-#include <string_view>
-#include <vector>
-#include <thread>
 #include <atomic>
 #include <chrono>
-#include <sstream>
-#include <iomanip>
+#include <iostream>
+#include <memory>
 #include <mutex>
+#include <sstream>
+#include <string>
+#include <thread>  // NOLINT
 #include <unordered_map>
+#include <vector>
 
 // HTTP server library
 #undef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -38,16 +35,12 @@
 #include "httplib.h"
 
 // JSON library
-#include "nlohmann/json.hpp"
-
-#include "compression/types.h"
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
-#include "gemma/tokenizer.h"
 #include "ops/matmul.h"
 #include "util/args.h"
 #include "hwy/base.h"
-#include "hwy/profiler.h"
+#include "nlohmann/json.hpp"
 
 using json = nlohmann::json;
 
@@ -90,7 +83,8 @@ struct ServerState {
     std::lock_guard<std::mutex> lock(sessions_mutex);
     auto& session = sessions[session_id];
     if (!session.kv_cache) {
-      session.kv_cache = std::make_unique<KVCache>(gemma->Config(), InferenceArgs(), env->ctx.allocator);
+      session.kv_cache = std::make_unique<KVCache>(
+          gemma->Config(), InferenceArgs(), env->ctx.allocator);
     }
     session.last_access = std::chrono::steady_clock::now();
     return session;
@@ -107,7 +101,8 @@ std::string GenerateSessionId() {
   return ss.str();
 }
 
-// Wraps messages with start_of_turn markers - handles both with and without roles
+// Wraps messages with start_of_turn markers - handles both with and without
+// roles
 std::string WrapMessagesWithTurnMarkers(const json& contents) {
   std::string prompt;
 
@@ -121,12 +116,14 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) {
         std::string text = part["text"];
 
         if (role == "user") {
-          prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
+          prompt +=
+              "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
         } else if (role == "model") {
           prompt += text + "\n";
         } else if (role.empty()) {
           // Local format without roles - for now, treat as user input
-          prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
+          prompt +=
+              "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
         }
       }
     }
@@ -163,18 +160,15 @@ RuntimeConfig ParseGenerationConfig(const json& request) {
   return config;
 }
 
-// Unified response formatter - creates consistent format regardless of request type
-json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) {
+// Unified response formatter - creates consistent format regardless of request
+// type
+json CreateAPIResponse(const std::string& text,
+                       bool is_streaming_chunk = false) {
   json response = {
-      {"candidates", {{
-          {"content", {
-              {"parts", {{{"text", text}}}},
-              {"role", "model"}
-          }},
-          {"index", 0}
-      }}},
-      {"promptFeedback", {{"safetyRatings", json::array()}}}
-  };
+      {"candidates",
+       {{{"content", {{"parts", {{{"text", text}}}}, {"role", "model"}}},
+         {"index", 0}}}},
+      {"promptFeedback", {{"safetyRatings", json::array()}}}};
 
   // Only add finishReason for non-streaming chunks
   if (!is_streaming_chunk) {
@@ -185,7 +179,9 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false)
 }
 
 // Handle generateContent endpoint (non-streaming)
-void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
+void HandleGenerateContentNonStreaming(ServerState& state,
+                                       const httplib::Request& req,
+                                       httplib::Response& res) {
   try {
     json request = json::parse(req.body);
 
|
||||||
prompt = WrapMessagesWithTurnMarkers(request["contents"]);
|
prompt = WrapMessagesWithTurnMarkers(request["contents"]);
|
||||||
} else {
|
} else {
|
||||||
res.status = 400;
|
res.status = 400;
|
||||||
res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
|
res.set_content(
|
||||||
|
json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
|
||||||
|
"application/json");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@@ -209,12 +207,7 @@ void HandleGenerateContentNonStreaming(ServerState& state,
     // Set up runtime config
     RuntimeConfig runtime_config = ParseGenerationConfig(request);
 
-    // Collect full response
-    std::string full_response;
-    runtime_config.stream_token = [&full_response](int token, float) {
-      // Skip EOS token
-      return true;
-    };
+    runtime_config.stream_token = [](int token, float) { return true; };
 
     // Tokenize prompt
     std::vector<int> tokens = WrapAndTokenize(
@@ -227,7 +220,8 @@ void HandleGenerateContentNonStreaming(ServerState& state,
 
     // Temporarily redirect output to capture response
     std::stringstream output;
-    runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) {
+    runtime_config.stream_token = [&output, &state, &session, &tokens](
+                                      int token, float) {
       // Skip prompt tokens
       if (session.abs_pos < tokens.size()) {
         session.abs_pos++;
@@ -279,7 +273,9 @@ void HandleGenerateContentNonStreaming(ServerState& state,
 }
 
 // Handle streamGenerateContent endpoint with SSE)
-void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
+void HandleGenerateContentStreaming(ServerState& state,
+                                    const httplib::Request& req,
+                                    httplib::Response& res) {
   try {
     json request = json::parse(req.body);
 
@@ -293,7 +289,9 @@ void HandleGenerateContentStreaming(ServerState& state,
       prompt = WrapMessagesWithTurnMarkers(request["contents"]);
     } else {
       res.status = 400;
-      res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
+      res.set_content(
+          json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
+          "application/json");
       return;
     }
 
@ -305,88 +303,85 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
|
||||||
|
|
||||||
// Set up chunked content provider for SSE
|
// Set up chunked content provider for SSE
|
||||||
res.set_chunked_content_provider(
|
res.set_chunked_content_provider(
|
||||||
"text/event-stream",
|
"text/event-stream", [&state, request, prompt, session_id](
|
||||||
[&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) {
|
size_t offset, httplib::DataSink& sink) {
|
||||||
try {
|
try {
|
||||||
// Lock for inference
|
// Lock for inference
|
||||||
std::lock_guard<std::mutex> lock(state.inference_mutex);
|
std::lock_guard<std::mutex> lock(state.inference_mutex);
|
||||||
auto& session = state.GetOrCreateSession(session_id);
|
auto& session = state.GetOrCreateSession(session_id);
|
||||||
|
|
||||||
// Set up runtime config
|
// Set up runtime config
|
||||||
RuntimeConfig runtime_config = ParseGenerationConfig(request);
|
RuntimeConfig runtime_config = ParseGenerationConfig(request);
|
||||||
|
|
||||||
// Tokenize prompt
|
// Tokenize prompt
|
||||||
std::vector<int> tokens = WrapAndTokenize(
|
std::vector<int> tokens = WrapAndTokenize(
|
||||||
state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
|
state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
|
||||||
state.gemma->Config().wrapping, session.abs_pos, prompt);
|
state.gemma->Config().wrapping, session.abs_pos, prompt);
|
||||||
|
|
||||||
|
// Stream token callback
|
||||||
|
std::string accumulated_text;
|
||||||
|
auto stream_token = [&](int token, float) {
|
||||||
|
// Skip prompt tokens
|
||||||
|
if (session.abs_pos < tokens.size()) {
|
||||||
|
session.abs_pos++;
|
||||||
|
return true;
-      // Stream token callback
-      std::string accumulated_text;
-      auto stream_token = [&](int token, float) {
-        // Skip prompt tokens
-        if (session.abs_pos < tokens.size()) {
-          session.abs_pos++;
-          return true;
-        }
-
-        session.abs_pos++;
-
-        // Check for EOS
-        if (state.gemma->Config().IsEOS(token)) {
-          return true;
-        }
-
-        // Decode token
-        std::string token_text;
-        state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
-        accumulated_text += token_text;
-
-        // Send SSE event using unified formatter
-        json event = CreateAPIResponse(token_text, true);
-
-        std::string sse_data = "data: " + event.dump() + "\n\n";
-        sink.write(sse_data.data(), sse_data.size());
-
-        return true;
-      };
-
-      runtime_config.stream_token = stream_token;
-
-      // Run inference with KV cache
-      TimingInfo timing_info = {.verbosity = 0};
-      size_t prefix_end = 0;
-
-      state.gemma->Generate(runtime_config, tokens, session.abs_pos,
-                            prefix_end, *session.kv_cache, *state.env,
-                            timing_info);
-
-      // Send final event using unified formatter
-      json final_event = CreateAPIResponse("", false);
-      final_event["usageMetadata"] = {
-          {"promptTokenCount", tokens.size()},
-          {"candidatesTokenCount", session.abs_pos - tokens.size()},
-          {"totalTokenCount", session.abs_pos}
-      };
-
-      std::string final_sse = "data: " + final_event.dump() + "\n\n";
-      sink.write(final_sse.data(), final_sse.size());
-
-      // Send done event
-      sink.write("data: [DONE]\n\n", 15);
-
-      // Ensure all data is sent
-      sink.done();
-      return false;  // End streaming
-
-    } catch (const std::exception& e) {
-      json error_event = {{"error", {{"message", e.what()}}}};
-      std::string error_sse = "data: " + error_event.dump() + "\n\n";
-      sink.write(error_sse.data(), error_sse.size());
-      return false;
-    }
-  }
-  );
+          }
+
+          session.abs_pos++;
+
+          // Check for EOS
+          if (state.gemma->Config().IsEOS(token)) {
+            return true;
+          }
+
+          // Decode token
+          std::string token_text;
+          state.gemma->Tokenizer().Decode(std::vector<int>{token},
+                                          &token_text);
+          accumulated_text += token_text;
+
+          // Send SSE event using unified formatter
+          json event = CreateAPIResponse(token_text, true);
+
+          std::string sse_data = "data: " + event.dump() + "\n\n";
+          sink.write(sse_data.data(), sse_data.size());
+
+          return true;
+        };
+
+        runtime_config.stream_token = stream_token;
+
+        // Run inference with KV cache
+        TimingInfo timing_info = {.verbosity = 0};
+        size_t prefix_end = 0;
+
+        state.gemma->Generate(runtime_config, tokens, session.abs_pos,
+                              prefix_end, *session.kv_cache, *state.env,
+                              timing_info);
+
+        // Send final event using unified formatter
+        json final_event = CreateAPIResponse("", false);
+        final_event["usageMetadata"] = {
+            {"promptTokenCount", tokens.size()},
+            {"candidatesTokenCount", session.abs_pos - tokens.size()},
+            {"totalTokenCount", session.abs_pos}};
+
+        std::string final_sse = "data: " + final_event.dump() + "\n\n";
+        sink.write(final_sse.data(), final_sse.size());
+
+        // Send done event
+        sink.write("data: [DONE]\n\n", 15);
+
+        // Ensure all data is sent
+        sink.done();
+        return false;  // End streaming
+      } catch (const std::exception& e) {
+        json error_event = {{"error", {{"message", e.what()}}}};
+        std::string error_sse = "data: " + error_event.dump() + "\n\n";
+        sink.write(error_sse.data(), error_sse.size());
+        return false;
+      }
+    });
 } catch (const json::exception& e) {
   res.status = 400;
   res.set_content(
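The callback's return value is the handler's flow control: `Generate` keeps decoding while `stream_token` returns true and stops as soon as it returns false. A minimal sketch of that contract, assuming a `Gemma& gemma` and a `RuntimeConfig runtime_config` in scope as in the handler above (the token budget is an invented illustration, not part of this commit):

    // Sketch: print tokens as they arrive, stop after a hypothetical budget.
    size_t decoded = 0;
    const size_t kBudget = 256;  // hypothetical cutoff
    runtime_config.stream_token = [&](int token, float /*prob*/) {
      if (gemma.Config().IsEOS(token)) return true;  // skip EOS markers
      std::string piece;
      gemma.Tokenizer().Decode(std::vector<int>{token}, &piece);
      fputs(piece.c_str(), stdout);
      return ++decoded < kBudget;  // returning false ends generation early
    };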
@@ -398,20 +393,20 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
 }
 
 // Handle models list endpoint
-void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) {
+void HandleListModels(ServerState& state, const InferenceArgs& inference,
+                      const httplib::Request& req, httplib::Response& res) {
   json response = {
-      {"models", {{
-          {"name", "models/" + inference.model},
-          {"version", "001"},
-          {"displayName", inference.model},
-          {"description", inference.model + " model running locally"},
-          {"inputTokenLimit", 8192},
-          {"outputTokenLimit", 8192},
-          {"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})},
-          {"temperature", 1.0},
-          {"topK", 1}
-      }}}
-  };
+      {"models",
+       {{{"name", "models/" + inference.model},
+         {"version", "001"},
+         {"displayName", inference.model},
+         {"description", inference.model + " model running locally"},
+         {"inputTokenLimit", 8192},
+         {"outputTokenLimit", 8192},
+         {"supportedGenerationMethods",
+          json::array({"generateContent", "streamGenerateContent"})},
+         {"temperature", 1.0},
+         {"topK", 1}}}}};
 
   res.set_content(response.dump(), "application/json");
 }
 
@@ -421,39 +416,45 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const
 //   server_running = false;
 // }
 
-void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
-               const InferenceArgs& inference) {
+void RunServer(const GemmaArgs& args) {
   std::cerr << "Loading model..." << std::endl;
 
   // Initialize model
-  ThreadingContext ctx(threading);
+  ThreadingContext ctx(args.threading);
   MatMulEnv env(ctx);
 
   ServerState state;
-  state.gemma = std::make_unique<Gemma>(loader, inference, ctx);
+  state.gemma = std::make_unique<Gemma>(args, ctx);
   state.env = &env;
   state.ctx = &ctx;
 
+  const InferenceArgs& inference = args.inference;
+
   httplib::Server server;
 
   // Set up routes
-  server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) {
-    res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain");
-  });
+  server.Get(
+      "/", [&inference](const httplib::Request&, httplib::Response& res) {
+        res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" +
+                            inference.model + ":generateContent",
+                        "text/plain");
+      });
 
   // API endpoints
-  server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) {
+  server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req,
+                                                    httplib::Response& res) {
     HandleListModels(state, inference, req, res);
   });
 
   std::string model_endpoint = "/v1beta/models/" + inference.model;
-  server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) {
-    HandleGenerateContentNonStreaming(state, req, res);
-  });
+  server.Post(model_endpoint + ":generateContent",
+              [&state](const httplib::Request& req, httplib::Response& res) {
+                HandleGenerateContentNonStreaming(state, req, res);
+              });
 
-  server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) {
-    HandleGenerateContentStreaming(state, req, res);
-  });
+  server.Post(model_endpoint + ":streamGenerateContent",
+              [&state](const httplib::Request& req, httplib::Response& res) {
+                HandleGenerateContentStreaming(state, req, res);
+              });
 
   // Periodic cleanup of old sessions
   std::thread cleanup_thread([&state]() {
 
@@ -466,12 +467,15 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
   std::cerr << "Starting API server on port " << inference.port << std::endl;
   std::cerr << "Model loaded successfully" << std::endl;
   std::cerr << "Endpoints:" << std::endl;
-  std::cerr << "  POST /v1beta/models/" << inference.model << ":generateContent" << std::endl;
-  std::cerr << "  POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl;
+  std::cerr << "  POST /v1beta/models/" << inference.model << ":generateContent"
+            << std::endl;
+  std::cerr << "  POST /v1beta/models/" << inference.model
+            << ":streamGenerateContent (SSE)" << std::endl;
   std::cerr << "  GET /v1beta/models" << std::endl;
 
   if (!server.listen("0.0.0.0", inference.port)) {
-    std::cerr << "Failed to start server on port " << inference.port << std::endl;
+    std::cerr << "Failed to start server on port " << inference.port
+              << std::endl;
   }
 
   cleanup_thread.join();
 
@@ -482,35 +486,27 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
 int main(int argc, char** argv) {
   gcpp::InternalInit();
 
-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::ThreadingArgs threading(argc, argv);
-  gcpp::InferenceArgs inference(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
 
   if (gcpp::HasHelp(argc, argv)) {
-    std::cerr << "\n\nAPI server for gemma.cpp\n";
-    std::cout << "========================\n\n";
-    std::cerr << "Usage: " << argv[0] << " --weights <path> --tokenizer <path> [options]\n";
-    std::cerr << "\nOptions:\n";
-    std::cerr << "  --port PORT    Server port (default: 8080)\n";
-    std::cerr << "  --model MODEL  Model name for endpoints (default: gemma3-4b)\n";
-    std::cerr << "\n";
-    std::cerr << "\n*Model Loading Arguments*\n\n";
-    loader.Help();
-    std::cerr << "\n*Threading Arguments*\n\n";
-    threading.Help();
-    std::cerr << "\n*Inference Arguments*\n\n";
-    inference.Help();
-    std::cerr << "\n";
+    fprintf(
+        stderr,
+        "\n\nAPI server for gemma.cpp\n"
+        "========================\n\n"
+        "  --port PORT    Server port (default: 8080)\n"
+        "  --model MODEL  Model name for endpoints (default: gemma3-4b)\n");
+    args.Help();
     return 0;
   }
 
-  // Arguments are now handled by InferenceArgs
+  consumed.AbortIfUnconsumed();
 
   // // Set up signal handler
   // signal(SIGINT, gcpp::HandleShutdown);
   // signal(SIGTERM, gcpp::HandleShutdown);
 
-  gcpp::RunServer(loader, threading, inference);
+  gcpp::RunServer(args);
 
   return 0;
 }
 
@@ -73,45 +73,38 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
   ThreadingArgs threading_args;
   threading_args.spin = gcpp::Tristate::kFalse;
 
-  LoaderArgs loader(tokenizer_path, weights_path);
-  LogDebug("LoaderArgs created");
+  GemmaArgs args(LoaderArgs(tokenizer_path, weights_path), threading_args);
 
   // Initialize cached args
   LogDebug("Initializing inference args");
-  InferenceArgs inference_args;
-  inference_args.Init();
-  inference_args.max_generated_tokens = max_generated_tokens;
-  inference_args.temperature = 0.7f;
-  inference_args.top_k = 1;
-  inference_args.deterministic = false;
+  args.inference.max_generated_tokens = max_generated_tokens;
+  args.inference.temperature = 0.7f;
+  args.inference.top_k = 1;
+  args.inference.deterministic = false;
 
   ss.str("");
   ss << "Inference args initialized with max_tokens: " << max_generated_tokens
-     << ", temperature: " << inference_args.temperature
-     << ", top_k: " << inference_args.top_k << ", deterministic: "
-     << (inference_args.deterministic ? "true" : "false");
+     << ", temperature: " << args.inference.temperature
+     << ", top_k: " << args.inference.top_k << ", deterministic: "
+     << (args.inference.deterministic ? "true" : "false");
   LogDebug(ss.str().c_str());
 
-  return new GemmaContext(loader, inference_args, threading_args,
-                          max_generated_tokens);
+  return new GemmaContext(args, max_generated_tokens);
 }
 
-GemmaContext::GemmaContext(const LoaderArgs& loader,
-                           const InferenceArgs& inference_args,
-                           const ThreadingArgs& threading_args,
-                           int max_generated_tokens)
-    : inference_args(inference_args),
-      threading_args(threading_args),
-      ctx(threading_args),
+GemmaContext::GemmaContext(const GemmaArgs& args, int max_generated_tokens)
+    : args(args),
+      ctx(args.threading),
       matmul_env(ctx),
       active_conversation_name("default"),
-      model(loader, inference_args, matmul_env.ctx) {
+      model(args, matmul_env.ctx) {
   std::stringstream ss;
 
   LogDebug("Creating initial ConversationData");
   // Create the initial ConversationData object using make_shared
   active_conversation = std::make_shared<ConversationData>(
-      model.Config(), inference_args, ctx.allocator);
+      model.Config(), args.inference, ctx.allocator);
 
   LogDebug(
       "Storing initial ConversationData in conversation_cache[\"default\"]");
 
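For embedders such as `GemmaContext`, which never see argc/argv, the programmatic `GemmaArgs` constructor used above is the entire configuration surface. A condensed sketch of the same pattern outside this class, with placeholder paths:

    gcpp::ThreadingArgs threading;
    threading.spin = gcpp::Tristate::kFalse;
    gcpp::GemmaArgs args(
        gcpp::LoaderArgs("/path/to/tokenizer.spm", "/path/to/weights.sbs"),
        threading);
    args.inference.max_generated_tokens = 128;  // adjust defaults afterwards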
@@ -172,8 +165,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
   // set up runtime config
   TimingInfo timing_info = {};
   RuntimeConfig runtime_config = {.stream_token = stream_token,
-                                  .use_spinning = threading_args.spin};
-  inference_args.CopyTo(runtime_config);
+                                  .use_spinning = args.threading.spin};
+  args.inference.CopyTo(runtime_config);
   size_t prefix_end = 0;
 
   const ModelConfig& model_config = model.Config();
 
@@ -247,7 +240,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
                  timing_info);
 
   // prepare for next turn
-  if (!inference_args.multiturn ||
+  if (!args.inference.multiturn ||
       model_config.wrapping == PromptWrapping::PALIGEMMA) {
     // If not multiturn, or Paligemma (which handles turns differently),
     // reset the *active* conversation's position.
 
@@ -53,8 +53,7 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
 
 class GemmaContext {
  private:
-  GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
-               const ThreadingArgs& threading_args, int max_generated_tokens);
+  GemmaContext(const GemmaArgs& args, int max_generated_tokens);
 
  public:
   static GemmaContext* Create(const char* tokenizer_path,
 
@@ -81,37 +80,37 @@ class GemmaContext {
 
   // Set max generated tokens
   void SetMaxGeneratedTokens(size_t value) {
-    inference_args.max_generated_tokens = value;
+    args.inference.max_generated_tokens = value;
     LogDebug("Setting max_generated_tokens to configured value");
   }
 
   // Set multiturn flag (0 = disabled, 1 = enabled)
   void SetMultiturn(int value) {
-    inference_args.multiturn = value;
+    args.inference.multiturn = value;
     LogDebug("Setting multiturn to configured value");
   }
 
   // Set temperature for token generation
   void SetTemperature(float value) {
-    inference_args.temperature = value;
+    args.inference.temperature = value;
     LogDebug("Setting temperature to configured value");
   }
 
   // Set top_k parameter for sampling
   void SetTopK(int value) {
-    inference_args.top_k = value;
+    args.inference.top_k = value;
     LogDebug("Setting top_k to configured value");
   }
 
   // Set deterministic flag
   void SetDeterministic(bool value) {
-    inference_args.deterministic = value;
+    args.inference.deterministic = value;
     LogDebug("Setting deterministic flag to configured value");
   }
 
   // Set prefill_tbatch_size
   void SetPrefillTbatchSize(size_t value) {
-    inference_args.prefill_tbatch_size = value;
+    args.inference.prefill_tbatch_size = value;
     LogDebug("Setting prefill_tbatch_size to configured value");
   }
 
@@ -175,7 +174,7 @@ class GemmaContext {
     active_conversation->abs_pos = 0;
     // Replace the cache within the current ConversationData object
     active_conversation->kv_cache = std::make_unique<KVCache>(
-        model.Config(), inference_args, ctx.allocator);
+        model.Config(), args.inference, ctx.allocator);
 
     LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
   } else {
 
@@ -193,7 +192,7 @@ class GemmaContext {
     LogDebug("Creating new conversation");
     // Create a new ConversationData object using make_shared
     conversation_cache[name] = std::make_shared<ConversationData>(
-        model.Config(), inference_args, ctx.allocator);
+        model.Config(), args.inference, ctx.allocator);
     return true;
   }
 
@@ -274,8 +273,7 @@ class GemmaContext {
   std::vector<int> token_buffer;
 
   // Cached args (remain global for the context)
-  InferenceArgs inference_args;
-  ThreadingArgs threading_args;
+  GemmaArgs args;
   ThreadingContext ctx;
   MatMulEnv matmul_env;
 
@@ -738,17 +738,16 @@ HWY_EXPORT(GenerateSingleT);
 HWY_EXPORT(GenerateBatchT);
 HWY_EXPORT(GenerateImageTokensT);
 
-Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-             ThreadingContext& ctx)
-    : reader_(loader.weights),
-      model_(reader_, loader.tokenizer, loader.wrapping),
+Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
+    : reader_(args.loader.weights),
+      model_(reader_, args.loader.tokenizer, args.loader.wrapping),
       weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model),
-      inference_(inference),
-      aes_ctr_engine_(inference.deterministic) {
+      inference_(args.inference),
+      aes_ctr_engine_(args.inference.deterministic) {
   // Negligible CPU time in the ctor body (except ReadFromBlobs).
-  weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
-                                             mat_owners_, ctx);
+  weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
+                                             args.inference, mat_owners_, ctx);
   // Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
   reader_.CloseFile();
 }
 
@@ -130,11 +130,16 @@ struct TimingInfo {
 // separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
 class Gemma {
  public:
-  // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
+  // Reads weights/config/tokenizer from `BlobStore` at `args.loader.weights`.
   // `ctx` is only used to read tensors and not stored. Calls to `Generate*`
   // may reference the same, or other `ThreadingContext` via `MatMulEnv`.
+  Gemma(const GemmaArgs& args, ThreadingContext& ctx);
+
+  // Deprecated prior interface for backwards compatibility.
   Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-        ThreadingContext& ctx);
+        ThreadingContext& ctx)
+      : Gemma(GemmaArgs(loader, ThreadingArgs(), inference), ctx) {}
 
   ~Gemma();
 
   const ModelConfig& Config() const { return model_.Config(); }
 
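The delegating constructor keeps existing call sites compiling while funneling everything through `GemmaArgs`. Under this change, the two forms below should be equivalent (a sketch assuming `loader`, `inference`, and `ctx` are already in scope):

    // Deprecated interface, still works via the delegation above:
    gcpp::Gemma old_style(loader, inference, ctx);
    // New interface; ThreadingArgs() supplies the same defaults as the shim:
    gcpp::Gemma new_style(gcpp::GemmaArgs(loader, gcpp::ThreadingArgs(), inference),
                          ctx);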
@@ -25,10 +25,11 @@
 #include <string>
 
 #include "gemma/configs.h"
 #include "io/io.h"  // Path
-#include "util/args.h"
+#include "util/args.h"  // IWYU pragma: export
 #include "util/basics.h"  // Tristate
 #include "util/mat.h"
+#include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"  // HWY_ABORT
 #include "hwy/profiler.h"
 
@@ -36,7 +37,9 @@
 namespace gcpp {
 
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
-  LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  LoaderArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
   LoaderArgs(const std::string& tokenizer_path,
              const std::string& weights_path) {
     Init();  // Init sets to defaults, so assignments must come after Init().
 
@@ -169,7 +172,9 @@ struct RuntimeConfig {
 };
 
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
-  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
   InferenceArgs() { Init(); };
 
   bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }
 
@@ -275,33 +280,35 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }
 };
 
-struct ClientArgs : public ArgsBase<ClientArgs> {
-  ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
-  ClientArgs() { Init(); };
-
-  std::string host;
-  int port;
-  std::string api_key;
-  std::string model;
-  std::string prompt;
-  bool interactive;
-
-  template <class Visitor>
-  void ForEach(const Visitor& visitor) {
-    visitor(host, "host", std::string("localhost"),
-            "Server host (default: localhost)");
-    visitor(port, "port", 8080, "Server port (default: 8080)");
-    visitor(api_key, "api_key", std::string(""),
-            "Use public API with key (changes host to "
-            "generativelanguage.googleapis.com:443)");
-    visitor(model, "model", std::string("gemma3-4b"),
-            "Model name to use (default: gemma3-4b)");
-    visitor(prompt, "prompt", std::string("Hello! How are you?"),
-            "Prompt for generation (default: 'Hello! How are you?')");
-    visitor(interactive, "interactive", false,
-            "Start interactive chat mode (0 = no, 1 = yes)");
-  }
-};
+// Bundles all args required to construct a `GemmaEnv` or the equivalent.
+struct GemmaArgs {
+  // For callers that do not parse command line args.
+  GemmaArgs(const LoaderArgs& loader,
+            const ThreadingArgs& threading = ThreadingArgs(),
+            const InferenceArgs& inference = InferenceArgs())
+      : loader(loader), threading(threading), inference(inference) {}
+
+  GemmaArgs(int argc, char** argv, ConsumedArgs& consumed)
+      : loader(argc, argv, consumed),
+        threading(argc, argv, consumed),
+        inference(argc, argv, consumed) {}
+
+  void Help() {
+    fprintf(stderr,
+            "To run with pre-2025 weights, specify --tokenizer and --weights.\n"
+            "With the single-file weights format, specify just --weights.\n"
+            "\n*Model Loading Arguments*\n");
+    loader.Help();
+    fprintf(stderr, "\n*Threading Arguments*\n");
+    threading.Help();
+    fprintf(stderr, "\n*Inference Arguments*\n");
+    inference.Help();
+    fprintf(stderr, "\n");
+  }
+
+  LoaderArgs loader;
+  ThreadingArgs threading;
+  InferenceArgs inference;
+};
 
 }  // namespace gcpp
 
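Taken together, the two constructors cover both entry points: programmatic construction with defaulted threading/inference args, and command-line parsing that shares one `ConsumedArgs` tracker across all three member structs. A condensed sketch of the command-line path, mirroring the `main` functions elsewhere in this commit:

    int main(int argc, char** argv) {
      gcpp::ConsumedArgs consumed(argc, argv);
      gcpp::GemmaArgs args(argc, argv, consumed);
      if (gcpp::HasHelp(argc, argv)) {  // --help is special-cased, not consumed
        args.Help();
        return 0;
      }
      consumed.AbortIfUnconsumed();  // typo'd flags abort instead of being ignored
      // ... construct ThreadingContext/Gemma from args ...
      return 0;
    }

The new test file below exercises exactly this behavior.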
@@ -0,0 +1,74 @@
+#include "gemma/gemma_args.h"
+
+#include <stddef.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+namespace gcpp {
+
+void FillPtrs(const std::vector<std::string>& args, std::vector<char*>& ptrs) {
+  ptrs.reserve(args.size());
+  for (const std::string& arg : args) {
+    ptrs.push_back(const_cast<char*>(arg.data()));
+  }
+}
+
+static void CheckAllConsumed(const std::vector<std::string>& args) {
+  std::vector<char*> ptrs;
+  FillPtrs(args, ptrs);
+  const int argc = static_cast<int>(args.size());
+  char** argv = const_cast<char**>(ptrs.data());
+
+  ConsumedArgs consumed(argc, argv);
+  GemmaArgs gemma_args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+}
+
+static void CheckUnconsumed(const std::vector<std::string>& args,
+                            size_t expected) {
+  std::vector<char*> ptrs;
+  FillPtrs(args, ptrs);
+  const int argc = static_cast<int>(args.size());
+  char** argv = const_cast<char**>(ptrs.data());
+
+  ConsumedArgs consumed(argc, argv);
+  GemmaArgs gemma_args(argc, argv, consumed);
+  ASSERT_EQ(expected, consumed.FirstUnconsumed());
+}
+
+// Note: do not use --help because it is not actually consumed; it is
+// special-cased in `HasHelp`.
+TEST(GemmaArgsTest, AllConsumedArgs) {
+  // Single arg
+  CheckAllConsumed({"gemma", "--weights=x"});
+  // Two args, one with =
+  CheckAllConsumed({"gemma", "--weights=x", "--verbosity=1"});
+  // Two args, one with extra value
+  CheckAllConsumed({"gemma", "--weights=x", "--verbosity", "2"});
+  // Two args with values
+  CheckAllConsumed({"gemma", "--verbosity", "2", "--deterministic=true"});
+}
+
+TEST(GemmaArgsTest, UnconsumedArgs) {
+  // Single unconsumed arg
+  CheckUnconsumed({"gemma", "--UNDEFINED"}, 1);
+  // Single unconsumed arg, no --
+  CheckUnconsumed({"gemma", "UNDEFINED"}, 1);
+  // Single unconsumed arg after valid arg
+  CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED"}, 2);
+  // Single unconsumed arg before valid arg
+  CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x"}, 1);
+  // Single unconsumed arg with = after valid arg
+  CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED=1"}, 2);
+  // Single unconsumed arg with = before valid arg
+  CheckUnconsumed({"gemma", "--UNDEFINED=false", "--weights=x"}, 1);
+  // Multiple unconsumed args
+  CheckUnconsumed({"gemma", "--UNDEFINED", "--XXX"}, 1);
+  // Multiple unconsumed args with valid arg between
+  CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x", "--XXX"}, 1);
+}
+
+}  // namespace gcpp

gemma/run.cc (44 lines changed)
@@ -89,9 +89,11 @@ std::string GetPrompt(const InferenceArgs& inference) {
 }
 
 // The main Read-Eval-Print Loop.
-void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
-               const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
+void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
+               MatMulEnv& env) {
   PROFILER_ZONE("Gen.misc");
+  const InferenceArgs& inference = args.inference;
+  const int verbosity = inference.verbosity;
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
   size_t prompt_size = 0;
 
@@ -113,12 +115,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     HWY_ASSERT(image.ReadPPM(inference.image_file.path));
     const size_t image_size = config.vit_config.image_size;
     image.Resize(image_size, image_size);
-    RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
-                                    .use_spinning = threading.spin};
+    RuntimeConfig runtime_config = {.verbosity = verbosity,
+                                    .use_spinning = args.threading.spin};
     double image_tokens_start = hwy::platform::Now();
     gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
                               image_tokens, env);
-    if (inference.verbosity >= 1) {
+    if (verbosity >= 1) {
       double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
       fprintf(stderr,
               "\n\n[ Timing info ] Image token generation took: %d ms\n",
 
@@ -189,7 +191,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     TimingInfo timing_info = {.verbosity = inference.verbosity};
     RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                     .batch_stream_token = batch_stream_token,
-                                    .use_spinning = threading.spin};
+                                    .use_spinning = args.threading.spin};
     inference.CopyTo(runtime_config);
     std::vector<int> prompt;
     size_t prefix_end = 0;
 
@@ -252,14 +254,14 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   }
 }
 
-void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
-         const InferenceArgs& inference) {
+void Run(const GemmaArgs& args) {
   PROFILER_ZONE("Run.misc");
 
-  ThreadingContext ctx(threading);
+  ThreadingContext ctx(args.threading);
   MatMulEnv env(ctx);
+  const InferenceArgs& inference = args.inference;
   if (inference.verbosity >= 3) env.print_best = true;
-  const Gemma gemma(loader, inference, ctx);
+  const Gemma gemma(args, ctx);
   KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
 
   if (inference.verbosity >= 1) {
 
@@ -287,13 +289,12 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
     if (inference.IsInteractive()) {
       std::cout << "\033[2J\033[1;1H"  // clear screen
                 << kAsciiArtBanner << "\n\n";
-      ShowConfig(loader, threading, inference, gemma.Config(),
-                 gemma.WeightReadMode(), ctx);
+      ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
       std::cout << "\n" << instructions << "\n";
     }
   }
 
-  ReplGemma(threading, inference, gemma, kv_cache, env);
+  ReplGemma(args, gemma, kv_cache, env);
 }
 
 }  // namespace gcpp
 
@@ -302,17 +303,24 @@ int main(int argc, char** argv) {
   gcpp::InternalInit();
   {
     // Negligible CPU time.
-    gcpp::LoaderArgs loader(argc, argv);
-    gcpp::ThreadingArgs threading(argc, argv);
-    gcpp::InferenceArgs inference(argc, argv);
+    gcpp::ConsumedArgs consumed(argc, argv);
+    gcpp::GemmaArgs args(argc, argv, consumed);
 
     if (gcpp::HasHelp(argc, argv)) {
       std::cerr << gcpp::kAsciiArtBanner;
-      gcpp::ShowHelp(loader, threading, inference);
+      fprintf(stderr,
+              "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
+              "==========================================================\n\n"
+              "*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
+              "--weights gemma2-2b-it-sfp.sbs\n\n");
+      args.Help();
       return 0;
     }
 
-    gcpp::Run(loader, threading, inference);
+    // After `HasHelp` so that we print --help even if unconsumed args remain.
+    consumed.AbortIfUnconsumed();
+
+    gcpp::Run(args);
   }
   PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
   return 0;
 
@@ -23,7 +23,9 @@ namespace gcpp {
 namespace {
 
 struct WriterArgs : public ArgsBase<WriterArgs> {
-  WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  WriterArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
 
   Path output_weights;
 
@@ -38,12 +40,15 @@ struct WriterArgs : public ArgsBase<WriterArgs> {
 }  // namespace gcpp
 
 int main(int argc, char** argv) {
-  gcpp::WriterArgs args(argc, argv);
-  if (args.output_weights.Empty()) {
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  gcpp::WriterArgs writer_args(argc, argv, consumed);
+  if (writer_args.output_weights.Empty()) {
     HWY_ABORT("Missing --output_weights flag, a file for the model weights.");
   }
+  consumed.AbortIfUnconsumed();
 
-  gcpp::GemmaEnv env(argc, argv);
-  env.GetGemma()->Save(args.output_weights, env.Env().ctx);
+  gcpp::GemmaEnv env(args);
+  env.GetGemma()->Save(writer_args.output_weights, env.Env().ctx);
   return 0;
 }
 
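The pattern above generalizes: any number of `ArgsBase` subclasses can parse the same argv as long as they share one `ConsumedArgs`, and the final `AbortIfUnconsumed` catches flags that no struct claimed. A sketch with a hypothetical tool-specific struct (`MyToolArgs` and its `--report` flag are illustrative, not part of this commit; the visitor signature follows the `ForEach` convention of the other `*Args` structs):

    struct MyToolArgs : public gcpp::ArgsBase<MyToolArgs> {  // hypothetical
      MyToolArgs(int argc, char* argv[], gcpp::ConsumedArgs& consumed) {
        InitAndParse(argc, argv, consumed);
      }
      gcpp::Path report;
      template <class Visitor>
      void ForEach(const Visitor& visitor) {
        visitor(report, "report", gcpp::Path(), "Where to write the report.");
      }
    };

    // In main(): each struct marks the argv entries it recognizes.
    gcpp::ConsumedArgs consumed(argc, argv);
    gcpp::GemmaArgs args(argc, argv, consumed);
    MyToolArgs tool_args(argc, argv, consumed);
    consumed.AbortIfUnconsumed();  // aborts on anything neither struct consumed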
@@ -21,10 +21,9 @@
 #include "evals/benchmark_helper.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
-#include "io/io.h"
+#include "paligemma/paligemma_helper.h"
 #include "util/allocator.h"
 #include "hwy/tests/hwy_gtest.h"
-#include "paligemma/paligemma_helper.h"
 
 // This test can be run manually with the downloaded PaliGemma weights.
 // It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
 
@@ -73,7 +72,11 @@ TEST_F(PaliGemmaTest, QueryObjects) {
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
 
-  gcpp::GemmaEnv env(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
   gcpp::s_env = &env;
 
   return RUN_ALL_TESTS();
 
@@ -45,10 +45,7 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
 // Wrapper around GemmaEnv to expose to Python.
 class GemmaModel {
  public:
-  GemmaModel(const gcpp::LoaderArgs& loader,
-             const gcpp::ThreadingArgs& threading,
-             const gcpp::InferenceArgs& inference)
-      : env_(loader, threading, inference), last_prob_(0.0f) {}
+  GemmaModel(const gcpp::GemmaArgs& args) : env_(args), last_prob_(0.0f) {}
 
   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
 
@@ -254,13 +251,15 @@ PYBIND11_MODULE(gemma, mod) {
   py::class_<GemmaModel>(mod, "GemmaModel")
       .def(py::init([](const std::string& tokenizer, const std::string& weights,
                        size_t max_threads) {
-        const gcpp::LoaderArgs loader(tokenizer, weights);
         gcpp::ThreadingArgs threading;
         threading.max_lps = max_threads;
 
         gcpp::InferenceArgs inference;
         inference.max_generated_tokens = 512;
-        auto gemma =
-            std::make_unique<GemmaModel>(loader, threading, inference);
+        const gcpp::GemmaArgs args(gcpp::LoaderArgs(tokenizer, weights),
+                                   threading, inference);
+        auto gemma = std::make_unique<GemmaModel>(args);
         if (!gemma->ModelIsLoaded()) {
           throw std::invalid_argument("Could not load model.");
         }
 
util/args.h (70 lines changed)
@@ -22,6 +22,7 @@
 
 #include <algorithm>  // std::transform
 #include <string>
+#include <vector>
 
 #include "io/io.h"  // Path
 #include "util/basics.h"  // Tristate
 
@@ -29,6 +30,56 @@
 
 namespace gcpp {
 
+// For checking which args were not matched/consumed. Passed to each `*Args`
+// ctor that parses argc/argv to ensure that their args are tracked, without
+// requiring global state.
+class ConsumedArgs {
+ public:
+  ConsumedArgs(int argc, char** argv) : argv_(argv), consumed_(argc) {
+    // We assume argc >= 1, because argv[0] is the binary name. That allows us
+    // to signal "called AbortIfUnconsumed" with an empty vector.
+    HWY_ASSERT(!consumed_.empty());
+  }
+
+  ~ConsumedArgs() {
+    if (HWY_UNLIKELY(!consumed_.empty())) {
+      HWY_ABORT("AbortIfUnconsumed was not called.");
+    }
+  }
+
+  void NotifyConsumed(size_t idx) {
+    HWY_ASSERT(idx < consumed_.size());
+    HWY_ASSERT(consumed_[idx] == 0);
+    consumed_[idx] = 1;
+  }
+
+  // Returns index of first unconsumed arg, or 0 if none. Also disarms the
+  // warning in the dtor checking whether this/`AbortIfUnconsumed` were called.
+  size_t FirstUnconsumed() {
+    // Ignore argv[0], which is the binary name.
+    for (size_t i = 1; i < consumed_.size(); ++i) {
+      if (HWY_UNLIKELY(consumed_[i] == 0)) {
+        consumed_.clear();
+        return i;
+      }
+    }
+
+    consumed_.clear();
+    return 0;
+  }
+
+  void AbortIfUnconsumed() {
+    const size_t i = FirstUnconsumed();
+    if (HWY_UNLIKELY(i != 0)) {
+      HWY_ABORT("Unrecognized arg %zu: %s\n", i, argv_[i]);
+    }
+  }
+
+ private:
+  char** argv_;
+  std::vector<uint8_t> consumed_;
+};
+
 // Args is a class that provides a ForEach member function which visits each of
 // its member variables. ArgsBase provides functions called by Args to
 // initialize values to their defaults (passed as an argument to the visitor),
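One subtlety, which the `ParseVisitor` changes further down make concrete: a space-separated flag occupies two argv slots, so the parser must mark both, while the `--flag=value` form occupies one; slot 0 is the binary name and is never checked. A sketch of the resulting bookkeeping, using only the methods defined above:

    const char* argv[] = {"gemma", "--verbosity", "2", "--weights=w.sbs", "--oops"};
    gcpp::ConsumedArgs consumed(5, const_cast<char**>(argv));
    consumed.NotifyConsumed(1);  // --verbosity
    consumed.NotifyConsumed(2);  // its value lives in a separate slot
    consumed.NotifyConsumed(3);  // --weights=w.sbs is a single slot
    // Slot 4 (--oops) was never claimed, so it is reported:
    HWY_ASSERT(consumed.FirstUnconsumed() == 4);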
@@ -93,7 +144,8 @@ class ArgsBase {
   // consider adding a hash-map to speed this up.
   class ParseVisitor {
    public:
-    ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {}
+    ParseVisitor(int argc, char* argv[], ConsumedArgs& consumed)
+        : argc_(argc), argv_(argv), consumed_(consumed) {}
 
     template <typename T>
     void operator()(T& t, const char* name, const T& /*init*/,
 
@@ -108,6 +160,8 @@ class ArgsBase {
           if (!SetValue(argv_[i + 1], t)) {
            HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
           }
+          consumed_.NotifyConsumed(i);
+          consumed_.NotifyConsumed(i + 1);
           return;
         }
         if (std::string(argv_[i]).find(prefixed_eq) == 0) {
 
@@ -115,6 +169,7 @@ class ArgsBase {
           if (!SetValue(value, t)) {
            HWY_ABORT("Invalid value for %s, got %s\n", name, value);
           }
+          consumed_.NotifyConsumed(i);
           return;
         }
       }
 
@@ -181,8 +236,9 @@ class ArgsBase {
       }
     }
 
-    int argc_;
-    char** argv_;
+    const int argc_;
+    char** const argv_;
+    ConsumedArgs& consumed_;
   };  // ParseVisitor
 
   template <class Visitor>
 
@@ -211,15 +267,15 @@ class ArgsBase {
     ForEach(visitor);
   }
 
-  void Parse(int argc, char* argv[]) {
-    ParseVisitor visitor(argc, argv);
+  void Parse(int argc, char* argv[], ConsumedArgs& consumed) {
+    ParseVisitor visitor(argc, argv, consumed);
     ForEach(visitor);
   }
 
   // For convenience, enables single-line constructor.
-  void InitAndParse(int argc, char* argv[]) {
+  void InitAndParse(int argc, char* argv[], ConsumedArgs& consumed) {
     Init();
-    Parse(argc, argv);
+    Parse(argc, argv, consumed);
   }
 };
 
@@ -38,7 +38,9 @@ namespace gcpp {
 // Optional arguments for `ThreadingContext` from the command line.
 class ThreadingArgs : public ArgsBase<ThreadingArgs> {
  public:
-  ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  ThreadingArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }
   ThreadingArgs() { Init(); };
 
   // For BoundedTopology: