Major reduction of duplicated code in tests/benchmarks

Add helper functions to tokenize/wrap prompts
Move LayersOutputFunc into RuntimeConfig
AcceptFunc now also passes the token's probability
Implement StringFromType using the parser's table, and verify the results match

PiperOrigin-RevId: 643255119
Jan Wassenberg 2024-06-14 00:15:36 -07:00 committed by Copybara-Service
parent c15ff9529c
commit d3c6a45b59
17 changed files with 626 additions and 814 deletions
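
For orientation, a minimal sketch of the consolidated GemmaEnv helper this change introduces. The names are taken from the diff below; the prompt text and exact includes are illustrative assumptions, not part of the commit.

#include <cstdio>
#include <string>

#include "gemma/benchmark_helper.h"  // gcpp::GemmaEnv

int main(int argc, char** argv) {
  gcpp::GemmaEnv env(argc, argv);  // parses loader/inference/app args, loads the model
  if (!env.GetModel()) return 1;   // GetModel() returns nullptr if loading failed
  env.SetMaxGeneratedTokens(256);
  std::string prompt = "What is the capital of Spain?";        // illustrative prompt
  const auto [response, num_tokens] = env.QueryModel(prompt);  // wraps, tokenizes, generates
  fprintf(stdout, "%s (%zu generated tokens)\n", response.c_str(), num_tokens);
  return 0;
}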

View File

@ -113,12 +113,8 @@ cc_library(
cc_library(
name = "cross_entropy",
srcs = [
"gemma/cross_entropy.cc",
],
hdrs = [
"gemma/cross_entropy.h",
],
srcs = ["gemma/cross_entropy.cc"],
hdrs = ["gemma/cross_entropy.h"],
deps = [
":common",
":gemma_lib",
@ -149,6 +145,25 @@ cc_library(
],
)
cc_library(
name = "benchmark_helper",
srcs = ["gemma/benchmark_helper.cc"],
hdrs = ["gemma/benchmark_helper.h"],
deps = [
":app",
":args",
":common",
":cross_entropy",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
cc_test(
name = "gemma_test",
srcs = ["gemma/gemma_test.cc"],
@ -161,11 +176,11 @@ cc_test(
deps = [
":app",
":args",
":benchmark_helper",
":common",
":cross_entropy",
":gemma_lib",
":ops",
# Placeholder for internal dep, do not remove.,
"@googletest//:gtest_main",
"//compression:io",
"@hwy//:hwy_test_util",
@ -179,6 +194,7 @@ cc_binary(
deps = [
":app",
":args",
":benchmark_helper",
":common",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
@ -213,10 +229,10 @@ cc_binary(
deps = [
":app",
":args",
":benchmark_helper",
":common",
":cross_entropy",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
"//compression:io",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
@ -230,7 +246,6 @@ cc_binary(
srcs = ["gemma/benchmarks.cc"],
deps = [
":benchmark_helper",
# Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark",
],
)
@ -243,8 +258,8 @@ cc_binary(
deps = [
":app",
":args",
":benchmark_helper",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
"//compression:io",
"@hwy//:hwy",
"@hwy//:thread_pool",
@ -257,8 +272,10 @@ cc_binary(
srcs = ["gemma/run_mmlu.cc"],
deps = [
":app",
":args",
":benchmark_helper",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
"//compression:io",
"@hwy//:hwy",
"@hwy//:profiler",
"@hwy//:thread_pool",
@ -318,25 +335,6 @@ cc_library(
],
)
cc_library(
name = "benchmark_helper",
srcs = [
"gemma/benchmark_helper.cc",
],
hdrs = [
"gemma/benchmark_helper.h",
],
deps = [
":app",
":common",
":gemma_lib",
"@benchmark//:benchmark",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
cc_test(
name = "backward_scalar_test",
size = "large",

View File

@ -23,6 +23,8 @@
#include <string>
#include <utility> // std::move
#include "hwy/base.h"
namespace gcpp {
// Forward-declare to break the circular dependency: OpenFileOrNull returns
@ -77,12 +79,30 @@ struct Path {
return path;
}
bool Empty() const { return path.empty(); }
// Returns whether the file existed when this was called.
bool Exists() const { return !!OpenFileOrNull(*this, "r"); }
std::string path;
};
static inline HWY_MAYBE_UNUSED std::string ReadFileToString(const Path& path) {
std::unique_ptr<File> file = OpenFileOrNull(path, "r");
if (!file) {
HWY_ABORT("Failed to open %s", path.path.c_str());
}
const size_t size = file->FileSize();
if (size == 0) {
HWY_ABORT("Empty file %s", path.path.c_str());
}
std::string content(size, ' ');
if (!file->Read(0, size, content.data())) {
HWY_ABORT("Failed to read %s", path.path.c_str());
}
return content;
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
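
A hedged usage sketch of the new Path::Exists and ReadFileToString helpers; the file name below is only a placeholder.

gcpp::Path path;
path.path = "/tmp/prompt.txt";  // placeholder path
if (path.Exists()) {
  // Aborts on open/read failure instead of returning an error code.
  const std::string text = gcpp::ReadFileToString(path);
}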

View File

@ -1,146 +1,83 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <utility>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/io.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "gemma/benchmark_helper.h"
#include "gemma/gemma.h" // LayersOutputFunc
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "nlohmann/json.hpp"
using json = nlohmann::json;
class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
namespace gcpp {
class PromptArgs : public ArgsBase<PromptArgs> {
public:
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
gcpp::Path layers_output;
Path layers_output; // optional
std::string prompt;
// Returns error string or nullptr if OK.
const char* Validate() const {
if (prompt.empty()) return "Must specify --prompt";
return nullptr;
}
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(layers_output.path, "layers_output", std::string(""),
visitor(layers_output, "layers_output", Path(""),
"Path to store layers output", 2);
visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
}
};
std::pair<std::string, int> QueryModel(
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
gcpp::LayersOutputT* layers_output) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
int Run(int argc, char** argv) {
PromptArgs prompt_args(argc, argv);
AbortIfInvalidArgs(prompt_args);
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), gcpp::BOS_ID);
std::string res;
size_t total_tokens = 0;
std::mt19937 gen;
gen.seed(42);
auto stream_token = [&res, &total_tokens, &model](int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
return true;
};
if (app.verbosity >= 2) {
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
<< args.temperature;
}
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
.verbosity = app.verbosity,
.gen = &gen,
.stream_token = stream_token,
};
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
layers_output);
return {res, total_tokens};
}
class OutputJsonLogger {
public:
json json_output;
gcpp::LayersOutputT layers_output_log_f =
[this](int pos, const std::string& key, const float* values,
GemmaEnv env(argc, argv);
env.MutableConfig().layers_output =
prompt_args.layers_output.Empty()
? LayersOutputFunc()
: [&json_output](int pos, const std::string& key, const float* values,
size_t values_len) {
std::vector<float> v{values, values + values_len};
json_output[std::to_string(pos)][key] = v;
};
};
/* Run this in the same way as gemma, p.ex.:
./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
2b-it-sfp.sbs --prompt "..." --layers_output [path]
*/
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs args(argc, argv); // inference
gcpp::AppArgs app(argc, argv);
PromptArgs prompt_args(argc, argv);
if (const char* error = loader.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = args.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
const bool log_layers_output = !prompt_args.layers_output.path.empty();
OutputJsonLogger json_logger;
gcpp::LayersOutputT* layers_output =
log_layers_output ? &json_logger.layers_output_log_f : nullptr;
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning workers to cores helps.
if (app.num_threads > 10) {
gcpp::PinWorkersToCores(pool);
}
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
const std::string& prompt = prompt_args.prompt;
if (prompt.empty()) {
std::cout << "Please specify --prompt" << std::endl;
return EXIT_FAILURE;
}
const auto [answer, token_count] = QueryModel(
model, args, app, kv_cache, pool, prompt, layers_output);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
if (log_layers_output) {
if (env.MutableConfig().layers_output) {
std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
if (!output_f) {
std::cout << "Opening file failed" << std::endl;
return EXIT_FAILURE;
}
output_f << json_logger.json_output.dump();
if (!output_f) {
std::cout << "Writing to file failed" << std::endl;
return EXIT_FAILURE;
}
if (!output_f) HWY_ABORT("Opening layer output file failed");
output_f << json_output.dump();
if (!output_f) HWY_ABORT("Writing to layer output file failed");
output_f.close();
}
return EXIT_SUCCESS;
return 0;
}
} // namespace gcpp
int main(int argc, char** argv) { return gcpp::Run(argc, argv); }

View File

@ -1,36 +1,36 @@
#include <stdio.h>
#include <algorithm>
#include <cstdlib> // EXIT_FAILURE
#include <fstream>
#include <iostream>
#include <ostream>
#include <random>
#include <sstream>
#include <string>
#include <utility> // std::pair
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/io.h" // Path
#include "gemma/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/cross_entropy.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/timer.h"
#include "nlohmann/json.hpp"
namespace gcpp {
using json = nlohmann::json;
class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> {
class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
public:
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
gcpp::Path goldens;
gcpp::Path summarize_text;
gcpp::Path cross_entropy;
gcpp::Path trivia_qa;
Path goldens;
Path summarize_text;
Path cross_entropy;
Path trivia_qa;
size_t max_questions;
size_t batch_tokens;
@ -53,61 +53,6 @@ class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> {
}
};
void LogSpeedStats(const double time_start, size_t total_tokens) {
const double time_end = hwy::platform::Now();
const double time_elapsed = time_end - time_start;
const double tok_sec = total_tokens / time_elapsed;
std::cout << total_tokens << " tokens in " << time_elapsed << " seconds"
<< " [" << tok_sec << " tokens / sec" << "]\n";
}
std::pair<std::string, int> QueryModel(
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), 2);
std::string res;
size_t total_tokens = 0;
std::mt19937 gen;
gen.seed(42);
const double time_start = hwy::platform::Now();
auto stream_token = [&res, &total_tokens, &time_start, &app, &model](
int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
if (app.verbosity >= 1 && total_tokens % 100 == 0) {
LogSpeedStats(time_start, total_tokens);
}
return true;
};
if (app.verbosity >= 2) {
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
<< args.temperature;
}
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
.verbosity = app.verbosity,
.gen = &gen,
.stream_token = stream_token,
};
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
/*layers_output=*/nullptr);
if (app.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens);
}
return {res, total_tokens};
}
std::vector<std::pair<std::string, std::string>> load_goldens(
const std::string& path) {
std::ifstream goldens_file(path);
@ -129,28 +74,14 @@ std::vector<std::pair<std::string, std::string>> load_goldens(
return res;
}
std::string ReadFile(const gcpp::Path& path) {
std::ifstream text_file(path.path);
if (!text_file) {
std::cout << "Could not open file: " << path.path << "\n" << std::flush;
return {};
}
std::stringstream buffer;
buffer << text_file.rdbuf();
return buffer.str();
}
int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, const std::string& golden_path) {
const std::vector<std::pair<std::string, std::string>> queries_answers =
int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
std::vector<std::pair<std::string, std::string>> queries_answers =
load_goldens(golden_path);
int correct_answers = 0;
int total_tokens = 0;
size_t correct_answers = 0;
size_t total_tokens = 0;
const double time_start = hwy::platform::Now();
for (const auto& [question, expected_answer] : queries_answers) {
const auto [answer, token_count] =
QueryModel(model, args, app, kv_cache, pool, question);
for (auto& [question, expected_answer] : queries_answers) {
const auto [answer, token_count] = env.QueryModel(question);
total_tokens += token_count;
if (answer.find(expected_answer) != std::string::npos) {
correct_answers++;
@ -172,28 +103,22 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
return EXIT_SUCCESS;
}
int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, const gcpp::Path& text) {
int BenchmarkSummary(GemmaEnv& env, const Path& text) {
std::string prompt("Here is some text to summarize:\n");
prompt.append(ReadFile(text));
prompt.append(ReadFileToString(text));
prompt.append("\nSummarize this text.\n");
const double time_start = hwy::platform::Now();
const auto [answer, token_count] =
QueryModel(model, args, app, kv_cache, pool, prompt);
const auto [answer, token_count] = env.QueryModel(prompt);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
LogSpeedStats(time_start, token_count);
return EXIT_SUCCESS;
}
int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
gcpp::InferenceArgs& args, gcpp::AppArgs& app,
hwy::ThreadPool& pool, const gcpp::Path& text,
int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t batch_tokens) {
std::string input = ReadFile(text);
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
prompt.resize(std::min<size_t>(args.max_tokens, prompt.size()));
std::string input = ReadFileToString(text);
std::vector<int> prompt = env.Tokenize(input);
prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
std::cout << "Number of input tokens: " << prompt.size() << "\n";
const double time_start = hwy::platform::Now();
float total_entropy = 0.0f;
@ -203,13 +128,12 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice,
kv_cache, app.verbosity);
KVCache kv_cache = KVCache::Create(env.ModelType());
float entropy = ComputeCrossEntropy(
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice;
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
std::string text_slice = env.StringFromTokens(prompt_slice);
total_input_len += text_slice.size();
printf("Total cross entropy: %f [cumulative: %f]\n",
entropy, total_entropy);
@ -219,23 +143,19 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
return EXIT_SUCCESS;
}
int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, const gcpp::Path& json_file,
int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
size_t max_questions) {
std::ifstream trivia_file(json_file.path);
if (!trivia_file) {
std::cout << "Could not load file: " << json_file.path << "\n"
<< std::flush;
return EXIT_FAILURE;
HWY_ABORT("Could not load file %s\n", json_file.path.c_str());
}
std::string line;
size_t correct_answers = 0;
size_t i = 0;
while (std::getline(trivia_file, line)) {
json data = json::parse(line);
const auto [answer, token_count] = QueryModel(
model, args, app, kv_cache, pool, data["question"]);
std::string q(data["question"]);
const auto [answer, token_count] = env.QueryModel(q);
std::cout << answer << "\n";
bool correct = false;
for (const std::string expected : data["answer"]["aliases"]) {
@ -256,52 +176,25 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
return EXIT_SUCCESS;
}
/* Run this in the same way as gemma, p.ex.:
./benchmark --tokenizer tokenizer.spm --model 2b-it --weights \
2b-it-sfp.sbs --goldens_dir "../goldens"
*/
} // namespace gcpp
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::GemmaEnv env(argc, argv);
gcpp::BenchmarkArgs benchmark_args(argc, argv);
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs args(argc, argv); // inference
gcpp::AppArgs app(argc, argv);
BenchmarkArgs benchmark_args(argc, argv);
if (const char* error = loader.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = args.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning workers to cores helps.
if (app.num_threads > 10) {
gcpp::PinWorkersToCores(pool);
}
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
if (!benchmark_args.goldens.path.empty()) {
if (!benchmark_args.goldens.Empty()) {
const std::string golden_path =
benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
} else if (!benchmark_args.summarize_text.path.empty()) {
return BenchmarkSummary(model, args, app, kv_cache, pool,
benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.path.empty()) {
return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
pool, benchmark_args.cross_entropy,
benchmark_args.goldens.path + "/" +
gcpp::ModelString(env.ModelType(), env.ModelTrainingType()) + ".txt";
return BenchmarkGoldens(env, golden_path);
} else if (!benchmark_args.summarize_text.Empty()) {
return BenchmarkSummary(env, benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.Empty()) {
return BenchmarkCrossEntropy(env, benchmark_args.cross_entropy,
benchmark_args.batch_tokens);
} else if (!benchmark_args.trivia_qa.path.empty()) {
return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
benchmark_args.trivia_qa,
} else if (!benchmark_args.trivia_qa.Empty()) {
return BenchmarkTriviaQA(env, benchmark_args.trivia_qa,
benchmark_args.max_questions);
}
std::cout << "No benchmark command given." << "\n" << std::flush;
return EXIT_FAILURE;
HWY_ABORT("No benchmark command given.");
}

View File

@ -14,70 +14,95 @@
// limitations under the License.
#include "gemma/benchmark_helper.h"
#include <cstdlib> // EXIT_FAILURE
#include <stdio.h>
#include <time.h>
#include <iostream>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <thread> // NOLINT
#include <utility> // std::pair
#include <vector>
#include "gemma/common.h"
// Placeholder for internal header, do not modify.
#include "compression/compress.h" // TypeName
#include "gemma/common.h" // StringFromType
#include "gemma/cross_entropy.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/timer.h"
namespace gcpp {
void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
if (inference.deterministic) {
// Nothing up my sleeve number, at least some upper bits set.
gen.seed(0x12345678);
} else {
// Depending on the library implementation, this may still be deterministic.
std::random_device rd;
gen.seed(rd());
}
}
GemmaEnv::GemmaEnv(int argc, char** argv)
: loader_(argc, argv), inference_args_(argc, argv), app_(argc, argv),
: loader_(argc, argv),
inference_args_(argc, argv),
app_(argc, argv),
pool_(app_.num_threads) {
if (const char* error = loader_.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = inference_args_.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
{
// Placeholder for internal init, do not modify.
}
// For many-core, pinning workers to cores helps.
if (app_.num_threads > 10) {
gcpp::PinWorkersToCores(pool_);
}
AbortIfInvalidArgs(inference_args_);
if (const char* err = loader_.Validate()) {
loader_.Help();
fprintf(stderr, "Skipping model load because: %s\n", err);
} else {
fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(loader_, pool_);
kv_cache_ = KVCache::Create(loader_.ModelType());
gen_.seed(42);
}
std::pair<std::string, int> GemmaEnv::QueryModel(const std::string& input) {
std::string prompt_string = input;
if (loader_.ModelTrainingType() == ModelTraining::GEMMA_IT) {
// For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + input +
"<end_of_turn>\n<start_of_turn>model\n";
}
std::vector<int> prompt;
HWY_ASSERT(model_->Tokenizer().Encode(input, &prompt));
InitGenerator(inference_args_, gen_);
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), gcpp::BOS_ID);
runtime_config_ = {
.max_tokens = inference_args_.max_tokens,
.max_generated_tokens = inference_args_.max_generated_tokens,
.temperature = inference_args_.temperature,
.verbosity = app_.verbosity,
.gen = &gen_,
};
}
std::pair<std::string, size_t> GemmaEnv::QueryModel(
const std::vector<int>& tokens) {
std::string res;
size_t total_tokens = 0;
auto accept_token = [](int) { return true; };
std::mt19937 gen;
gen.seed(42);
const double time_start = hwy::platform::Now();
auto stream_token = [&res, &total_tokens, &time_start, this](
const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(model_->Tokenizer().Decode(std::vector<int>{token},
&token_text));
HWY_ASSERT(
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
if (app_.verbosity >= 1 && total_tokens % 100 == 0) {
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
LogSpeedStats(time_start, total_tokens);
}
return true;
@ -88,24 +113,32 @@ std::pair<std::string, int> GemmaEnv::QueryModel(const std::string& input) {
<< inference_args_.temperature;
}
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = inference_args_.max_tokens,
.max_generated_tokens = inference_args_.max_generated_tokens,
.temperature = inference_args_.temperature,
.verbosity = app_.verbosity,
.gen = &gen,
.stream_token = stream_token,
.accept_token = accept_token,
};
model_->Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache_,
timing_info, /*layers_output=*/nullptr);
runtime_config_.stream_token = stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_cache_,
timing_info);
if (app_.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens);
}
return {res, total_tokens};
}
void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const {
std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt =
WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(),
/*pos=*/0, input);
return QueryModel(prompt);
}
float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
MutableKVCache(),
/*verbosity=*/0) /
static_cast<int>(input.size());
}
void LogSpeedStats(double time_start, size_t total_tokens) {
const double time_end = hwy::platform::Now();
const double time_elapsed = time_end - time_start;
const double tok_sec = total_tokens / time_elapsed;
@ -113,6 +146,53 @@ void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const {
<< " [" << tok_sec << " tokens / sec" << "]\n";
}
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
if (app.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
std::cout << "Date & Time : " << dt
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
<< "\n"
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n"
<< "Instruction set : "
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
<< hwy::VectorBytes() * 8 << " bits)" << "\n";
char cpu100[100];
if (hwy::platform::GetCpuString(cpu100)) {
std::cout << "CPU : " << cpu100 << "\n";
}
std::cout << "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : "
<< gcpp::StringFromType(loader.WeightType()) << "\n"
<< "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
}
}
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run gemma.cpp, you need to "
"specify 3 required model loading arguments:\n"
" --tokenizer\n"
" --weights\n"
" --model.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights 2b-it-sfp.sbs --model 2b-it\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n*Application Arguments*\n\n";
app.Help();
std::cerr << "\n";
}
} // namespace gcpp

View File

@ -16,11 +16,15 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
#include <stddef.h>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "hwy/base.h"
@ -28,24 +32,60 @@
namespace gcpp {
void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
// Convenience class to load a model and run inference.
class GemmaEnv {
public:
GemmaEnv(int argc, char** argv);
size_t MaxTokens() const { return inference_args_.max_tokens; }
// Sets the maximum number of output tokens to generate.
void set_max_generated_tokens(int max_tokens) {
void SetMaxGeneratedTokens(size_t max_tokens) {
inference_args_.max_generated_tokens = max_tokens;
}
std::vector<int> Tokenize(const std::string& input) const {
std::vector<int> tokens;
HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens));
return tokens;
}
std::vector<int> TokenizeAndPrependBOS(const std::string& input) const {
std::vector<int> tokens = Tokenize(input);
tokens.insert(tokens.begin(), BOS_ID);
return tokens;
}
std::string StringFromTokens(const std::vector<int>& tokens) const {
std::string string;
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
return string;
}
// Runs inference on the given input and returns the top-1 result string and
// the number of tokens that were generated.
std::pair<std::string, int> QueryModel(const std::string& input);
std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
// Adds turn structure to input, tokenizes and calls the above overload.
std::pair<std::string, size_t> QueryModel(std::string& input);
// Runs inference on the given input and returns the cross entropy, a measure
// of how well the model predicts the correct output. It is the average
// number of bits per token.
float CrossEntropy(const std::string& input);
// Returns nullptr if the model failed to load.
Gemma* GetModel() const { return model_.get(); }
Model ModelType() const { return loader_.ModelType(); }
ModelTraining ModelTrainingType() const {
return loader_.ModelTrainingType();
}
int Verbosity() const { return app_.verbosity; }
gcpp::RuntimeConfig& MutableConfig() { return runtime_config_; }
std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return kv_cache_; }
private:
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens) const;
// Arguments to the model loader: file locations, etc.
LoaderArgs loader_;
// Arguments to the inference function: max tokens, etc.
@ -60,10 +100,16 @@ class GemmaEnv {
std::unique_ptr<Gemma> model_;
// The KV cache to use for inference.
KVCache kv_cache_;
gcpp::RuntimeConfig runtime_config_;
};
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
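
A sketch of the new tokenize/decode convenience helpers, assuming a GemmaEnv `env` constructed as in the sketch near the top of this page; the input string is illustrative.

const std::vector<int> ids = env.Tokenize("Hello world");
const std::string round_trip = env.StringFromTokens(ids);       // decode back to text
const std::vector<int> prompt = env.TokenizeAndPrependBOS("Hello world");
const float bits_per_byte = env.CrossEntropy("Hello world");    // lower is better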

View File

@ -13,86 +13,60 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <ostream>
#include <random>
#include <sstream>
#include <stddef.h>
#include <stdio.h>
#include <string>
// Placeholder for internal header, do not modify.
#include "benchmark/benchmark.h"
#include "gemma/benchmark_helper.h"
void run_gemma_prompt(const std::string& prompt_string,
gcpp::GemmaEnv& env,
benchmark::State& state) {
std::mt19937 gen;
namespace gcpp {
if (prompt_string.empty()) return;
// Shared state for benchmarks - unfortunately the library does not allow
// passing context or closures. Raw pointer because the style guide forbids
// non-local static objects with dtors.
GemmaEnv* s_env = nullptr;
int token_counter = 0;
void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
size_t total_tokens = 0;
for (auto s : state) {
auto [response, n] = env.QueryModel(prompt_string);
std::cout << "response: " << response << "\n";
std::cout << "n: " << n << "\n";
token_counter += n;
std::string prompt = original_prompt; // reset from original
auto [response, n] = s_env->QueryModel(prompt);
if (s_env->Verbosity() != 0) {
fprintf(stdout, "|%s|\n", response.c_str());
}
total_tokens += n;
}
state.SetItemsProcessed(token_counter);
state.SetItemsProcessed(total_tokens);
}
// Awkward global because benchmarks don't support additional state, so it is
// either this or cast to int64_t.
gcpp::GemmaEnv* global_env = nullptr;
} // namespace gcpp
static void BM_short_prompt(benchmark::State& state) {
run_gemma_prompt("What is the capital of Spain?", *global_env,
state);
gcpp::RunPrompt("What is the capital of Spain?", state);
}
static void BM_factuality_prompt(benchmark::State& state) {
run_gemma_prompt("How does an inkjet printer work?",
*global_env, state);
gcpp::RunPrompt("How does an inkjet printer work?", state);
}
static void BM_creative_prompt(benchmark::State& state) {
run_gemma_prompt(
"Tell me a story about a magical bunny and their TRS-80.",
*global_env, state);
gcpp::RunPrompt("Tell me a story about a magical bunny and their TRS-80.",
state);
}
static void BM_coding_prompt(benchmark::State& state) {
run_gemma_prompt(
"Write a python program to generate a fibonacci sequence.",
*global_env, state);
gcpp::RunPrompt("Write a python program to generate a fibonacci sequence.",
state);
}
static void BM_long_coding_prompt(benchmark::State& state) {
std::ifstream t("benchmarks.cc", std::ios_base::in);
std::stringstream buffer;
buffer << t.rdbuf();
std::string prompt_string = buffer.str();
t.close();
run_gemma_prompt("Make improvements to the following code:\n " +
prompt_string, *global_env, state);
}
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::GemmaEnv env(argc, argv);
env.set_max_generated_tokens(128);
global_env = &env;
BENCHMARK(BM_short_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
env.set_max_generated_tokens(256);
BENCHMARK(BM_factuality_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
@ -108,11 +82,10 @@ int main(int argc, char** argv) {
->Unit(benchmark::kMillisecond)
->UseRealTime();
env.set_max_generated_tokens(1024);
BENCHMARK(BM_long_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
env.SetMaxGeneratedTokens(256);
gcpp::s_env = &env;
::benchmark::RunSpecifiedBenchmarks();
::benchmark::Shutdown();

View File

@ -28,7 +28,8 @@
namespace gcpp {
namespace {
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
"7b-it", "gr2b-it", "tiny"};
constexpr Model kModelTypes[] = {
@ -38,10 +39,7 @@ constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
ModelTraining::GEMMA_IT};
} // namespace
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
static char kErrorMessageBuffer[kNum * 8 + 1024] =
"Invalid or missing model flag, need to specify one of ";
@ -51,37 +49,58 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
}
strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, "."); // NOLINT
std::string model_type_lc = model_flag;
std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
[](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) {
if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i];
training = kModelTraining[i];
HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc);
return nullptr;
}
}
return kErrorMessageBuffer;
}
const char* ModelString(Model model, ModelTraining training) {
if (model == Model::GEMMA_TINY) return "tiny";
static_assert(static_cast<size_t>(ModelTraining::GEMMA_IT) == 0);
constexpr const char* k2B[] = {"2b-it", "2b-pt"};
constexpr const char* k7B[] = {"7b-it", "7b-pt"};
constexpr const char* kGr2B[] = {"gr2b-it", "gr2b-pt"};
if (model == Model::GEMMA_2B) return k2B[static_cast<size_t>(training)];
if (model == Model::GEMMA_7B) return k7B[static_cast<size_t>(training)];
if (model == Model::GRIFFIN_2B) return kGr2B[static_cast<size_t>(training)];
HWY_ABORT("Unknown model %d training %d\n", static_cast<int>(model),
static_cast<int>(training));
}
constexpr const char* kTypeStrings[] = {"f32", "bf16", "sfp"};
const char* StringFromType(Type type) {
return kTypeStrings[static_cast<size_t>(type)];
}
const char* ParseType(const std::string& type_string, Type& type) {
constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP};
constexpr const char* kStrings[] = {"f32", "bf16", "sfp"};
constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings);
constexpr size_t kNum = std::end(kTypeStrings) - std::begin(kTypeStrings);
static char kErrorMessageBuffer[kNum * 8 + 100] =
"Invalid or missing type, need to specify one of ";
for (size_t i = 0; i + 1 < kNum; i++) {
strcat(kErrorMessageBuffer, kStrings[i]); // NOLINT
strcat(kErrorMessageBuffer, kTypeStrings[i]); // NOLINT
strcat(kErrorMessageBuffer, ", "); // NOLINT
}
strcat(kErrorMessageBuffer, kStrings[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, kTypeStrings[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, "."); // NOLINT
std::string type_lc = type_string;
std::transform(begin(type_lc), end(type_lc), begin(type_lc),
[](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) {
if (kStrings[i] == type_lc) {
type = kTypes[i];
if (kTypeStrings[i] == type_lc) {
type = static_cast<Type>(i);
HWY_ASSERT(std::string(StringFromType(type)) == type_lc);
return nullptr;
}
}
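
A small sketch of the parse/print round trip enabled by the shared kTypeStrings table; this usage is an assumption for illustration.

gcpp::Type type;
HWY_ASSERT(gcpp::ParseType("bf16", type) == nullptr);           // nullptr means success
HWY_ASSERT(std::string(gcpp::StringFromType(type)) == "bf16");  // inverse mapping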

View File

@ -154,18 +154,12 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
const char* ParseType(const std::string& type_string, Type& type);
static inline const char* StringFromType(Type type) {
switch (type) {
case Type::kF32:
return "f32";
case Type::kBF16:
return "bf16";
case Type::kSFP:
return "sfp";
default:
return "?";
}
}
// Inverse of ParseModelTypeAndTraining.
const char* ModelString(Model model, ModelTraining training);
const char* StringFromType(Type type);
// ----------------------------------------------------------------------------
//
// __builtin_sqrt is not constexpr as of Clang 17.
#if HWY_COMPILER_GCC_ACTUAL

View File

@ -111,7 +111,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr);
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
const float scale = 1.0f / std::log(2.0f);
return cross_entropy * scale;

View File

@ -17,7 +17,6 @@
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#include <cstdio>
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
@ -208,8 +207,8 @@ bool GemmaTokenizer::Encode(const std::string& input,
}
bool GemmaTokenizer::Encode(const std::string& input,
std::vector<int>* pieces) const {
return impl_->Encode(input, pieces);
std::vector<int>* ids) const {
return impl_->Encode(input, ids);
}
// Given a sequence of ids, decodes it into a detokenized output.
@ -653,12 +652,12 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
const WeightArrayT& weights,
Activations<TConfig, kBatchSize>& activations,
KVCache& kv_cache, hwy::ThreadPool& pool,
LayersOutputT* layers_output) {
const LayersOutputFunc& layers_output) {
HWY_ASSERT(num_tokens <= kBatchSize);
if (layers_output != nullptr) {
if (layers_output) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
float token_f = tokens[token_idx];
(*layers_output)(pos + token_idx, "Tokens", &token_f, 1);
layers_output(pos + token_idx, "Tokens", &token_f, 1);
}
}
static constexpr size_t kModelDim = TConfig::kModelDim;
@ -713,12 +712,11 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
}
AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
activations.x.data(), kModelDim);
if (layers_output != nullptr) {
if (layers_output) {
std::string block_name = "blocks." + std::to_string(layer);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
(*layers_output)(pos + token_idx, block_name,
activations.x.data() + token_idx * kModelDim,
kModelDim);
layers_output(pos + token_idx, block_name,
activations.x.data() + token_idx * kModelDim, kModelDim);
}
}
}
@ -727,9 +725,9 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
activations.x.data(), kModelDim);
if (layers_output != nullptr) {
if (layers_output) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
(*layers_output)(pos + token_idx, "final_norm",
layers_output(pos + token_idx, "final_norm",
activations.x.data() + token_idx * kModelDim, kModelDim);
}
}
@ -782,8 +780,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) {
hwy::ThreadPool& pool, TimingInfo& timing_info) {
const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8);
auto& prefill_activations =
GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
@ -860,7 +857,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
pos < max_tokens && generate_pos < max_generated_tokens;
++pos, ++pos_offset, ++generate_pos) {
Transformer<kDecodeBatchSize>(&token, kDecodeBatchSize, pos, weights,
activations, kv_cache, pool, layers_output);
activations, kv_cache, pool,
runtime_config.layers_output);
float token_logit = 0.0f;
// The condition below is always true if we are doing Prefill above.
// We keep it here for clarity so that the code is correct even if Prefill
@ -953,17 +951,37 @@ Gemma::~Gemma() {
void Gemma::Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, TimingInfo& timing_info,
LayersOutputT* layers_output) {
KVCache& kv_cache, TimingInfo& timing_info) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
GEMMA_EXPORT_AND_DISPATCH(
model_type_, weight_type_, GenerateT,
(weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
kv_cache, pool_, timing_info, layers_output));
kv_cache, pool_, timing_info));
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelTraining training, size_t pos,
std::string& prompt) {
// Instruction-tuned models are trained to expect control tokens.
if (training == ModelTraining::GEMMA_IT) {
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
const std::string start = (pos == 0)
? "<start_of_turn>user\n"
: "<end_of_turn>\n<start_of_turn>user\n";
prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
}
std::vector<int> tokens;
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
// Both pre-trained and instruction-tuned require BOS as first token.
if (pos == 0) {
tokens.insert(tokens.begin(), gcpp::BOS_ID);
}
return tokens;
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -74,10 +74,17 @@ class GemmaTokenizer {
using StreamFunc = std::function<bool(int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>;
using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called with the probability distribution for the
// next token, and its return value is used as the next generated token.
using SampleFunc = std::function<int(const float*, size_t)>;
// Will be called for layers output with:
// - position in the tokens sequence
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
// - pointer to the data array
// - size of the data array
using LayersOutputFunc =
std::function<void(int, const std::string&, const float*, size_t)>;
struct RuntimeConfig {
size_t max_tokens;
@ -88,6 +95,7 @@ struct RuntimeConfig {
StreamFunc stream_token;
AcceptFunc accept_token; // if empty, accepts all tokens.
SampleFunc sample_func; // if empty, uses SampleTopK.
LayersOutputFunc layers_output; // if not empty, called after each layer.
int eos_id = EOS_ID;
};
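
A sketch of the new placement of LayersOutputFunc: it is now a RuntimeConfig field instead of a separate Generate() argument. The variables `gen`, `stream_token`, `model`, `prompt`, `kv_cache` and `timing_info` are assumed to exist as in the surrounding code; the values are illustrative.

gcpp::RuntimeConfig runtime_config = {
    .max_tokens = 3072,
    .max_generated_tokens = 256,
    .temperature = 1.0f,
    .verbosity = 0,
    .gen = &gen,
    .stream_token = stream_token,
    .layers_output = [](int pos, const std::string& name, const float* values,
                        size_t len) {
      // e.g. record the activations of layer `name` for token position `pos`
    },
};
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info);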
@ -97,14 +105,6 @@ struct TimingInfo {
double time_to_first_token = 0.0;
};
// Will be called for layers output with:
// - position in the tokens sequence
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
// - pointer to the data array
// - size of the data array
using LayersOutputT =
std::function<void(int, const std::string&, const float*, size_t)>;
class Gemma {
public:
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@ -121,12 +121,9 @@ class Gemma {
const ByteStorageT& Prefill() const { return prefill_u8_; }
const ByteStorageT& Decode() const { return decode_u8_; }
// layers_output is optional; if set - it will be called with the activations
// output after applying each layer.
void Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr);
KVCache& kv_cache, TimingInfo& timing_info);
private:
hwy::ThreadPool& pool_;
@ -141,14 +138,19 @@ class Gemma {
Type weight_type_;
};
// Adds BOS token and possibly 'turn' annotations, which depend on `training`
// and `pos`, the number of tokens decoded so far; returns the corresponding
// tokens. Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
ModelTraining training, size_t pos,
std::string& prompt);
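
An assumed usage sketch of the new helper; `model` is a loaded Gemma and the prompt text is illustrative.

std::string prompt = "Tell me a joke.";
const std::vector<int> tokens = gcpp::WrapAndTokenize(
    model.Tokenizer(), gcpp::ModelTraining::GEMMA_IT, /*pos=*/0, prompt);
// For IT models, `prompt` is rewritten in place with the turn control tokens.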
// DEPRECATED, call Gemma::Generate directly.
HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
TimingInfo& timing_info,
LayersOutputT* layers_output) {
gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info,
layers_output);
TimingInfo& timing_info) {
gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info);
}
} // namespace gcpp

View File

@ -17,89 +17,35 @@
#include <stdio.h>
#include <memory>
#include <random>
#include <string>
#include <vector>
// Placeholder for internal header, do not modify.
#include "gemma/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/cross_entropy.h"
#include "gemma/ops.h"
#include "util/app.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/test_util-inl.h"
#include "hwy/tests/hwy_gtest.h"
namespace gcpp {
namespace {
int s_argc = 0;
char** s_argv = nullptr;
// Shared state. Requires argc/argv, so construct in main and use the same raw
// pointer approach as in benchmarks.cc. Note that the style guide forbids
// non-local static variables with dtors.
GemmaEnv* s_env = nullptr;
class GemmaTest : public ::testing::Test {
protected:
static void SetUpTestSuite() {
gcpp::AppArgs app(s_argc, s_argv);
gcpp::LoaderArgs loader(s_argc, s_argv);
if (const char* err = loader.Validate()) {
fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
} else {
fprintf(stderr, "Loading model..\n");
s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
s_gemma = AllocateGemma(loader, *s_pool);
s_kv_cache = KVCache::Create(loader.ModelType());
s_model = loader.ModelType();
}
}
static void TearDownTestSuite() {
s_pool.reset();
s_gemma.reset();
}
std::string GemmaReply(const std::string& prompt_string) {
std::mt19937 gen;
gen.seed(42);
std::vector<int> prompt;
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), BOS_ID);
std::vector<int> response;
auto stream_token = [&response](int token, float) {
response.push_back(token);
return true;
};
gcpp::RuntimeConfig runtime_config = {
.max_tokens = 3072,
.max_generated_tokens = 2048,
.temperature = 1.0,
.verbosity = 0,
.gen = &gen,
.stream_token = stream_token,
};
gcpp::TimingInfo timing_info;
s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
timing_info, /*layers_output=*/nullptr);
std::string response_text;
HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text));
return response_text;
}
float GemmaCrossEntropy(const std::string& prompt_string) {
std::vector<int> prompt;
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt,
s_kv_cache,
/*verbosity=*/0) /
prompt_string.size();
std::string GemmaReply(const std::string& prompt) {
s_env->SetMaxGeneratedTokens(2048);
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 0;
// Using the turn structure worsens results.
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
auto [response, n] = s_env->QueryModel(tokens);
return response;
}
void TestQuestions(const char* kQA[][2], size_t num_questions) {
if (!s_gemma) return;
if (!s_env->GetModel()) return;
for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Question %zu\n\n", i + 1);
std::string response = GemmaReply(kQA[i][0]);
@ -107,18 +53,8 @@ class GemmaTest : public ::testing::Test {
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
}
static std::unique_ptr<hwy::ThreadPool> s_pool;
static std::unique_ptr<gcpp::Gemma> s_gemma;
static gcpp::KVCache s_kv_cache;
static gcpp::Model s_model;
};
/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_gemma;
/*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
/*static*/ gcpp::Model GemmaTest::s_model;
TEST_F(GemmaTest, Geography) {
static const char* kQA[][2] = {
{"What is the capital of Hungary?", "Budapest"},
@ -130,7 +66,7 @@ TEST_F(GemmaTest, Geography) {
TEST_F(GemmaTest, History) {
static const char* kQA[][2] = {
{"When was the Battle of Hastings?", "1066"},
{"When was the battle of Hastings?", "1066"},
};
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum);
@ -181,42 +117,39 @@ static const char kGettysburg[] = {
"people, for the people, shall not perish from the earth.\n"};
TEST_F(GemmaTest, CrossEntropySmall) {
if (!s_gemma) return;
if (!s_env->GetModel()) return;
static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe.";
float entropy = GemmaCrossEntropy(kSmall);
float entropy = s_env->CrossEntropy(kSmall);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
EXPECT_LT(entropy,
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
}
TEST_F(GemmaTest, CrossEntropyJingleBells) {
if (!s_gemma) return;
float entropy = GemmaCrossEntropy(kJingleBells);
if (!s_env->GetModel()) return;
float entropy = s_env->CrossEntropy(kJingleBells);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
EXPECT_LT(entropy,
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
}
TEST_F(GemmaTest, CrossEntropyGettysburg) {
if (!s_gemma) return;
float entropy = GemmaCrossEntropy(kGettysburg);
if (!s_env->GetModel()) return;
float entropy = s_env->CrossEntropy(kGettysburg);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
EXPECT_LT(entropy,
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
}
} // namespace
} // namespace gcpp
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env;
// For later use by SetUp.
gcpp::s_argc = argc;
gcpp::s_argv = argv;
// Probably should be called before SetUpTestSuite.
testing::InitGoogleTest(&gcpp::s_argc, argv);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -1668,12 +1668,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] &&
(!accept_token || accept_token(StaticCast<int>(i)))) {
(!accept_token || accept_token(StaticCast<int>(i), probabilities[i]))) {
continue;
}
for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] &&
(!accept_token || accept_token(StaticCast<int>(i)))) {
(!accept_token ||
accept_token(StaticCast<int>(i), probabilities[i]))) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
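
A sketch of an accept_token callback under the new signature, which now also receives the candidate token's probability; the threshold is purely illustrative.

gcpp::AcceptFunc accept_token = [](int token, float prob) {
  return prob > 1e-6f;  // e.g. skip vanishingly unlikely tokens
};
// Assign via runtime_config.accept_token; an empty AcceptFunc accepts all tokens.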

View File

@ -15,27 +15,22 @@
// Command line text interface to gemma.
#include <ctime>
#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <thread> // NOLINT
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/compress.h"
#include "gemma/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
#include "util/args.h" // HasHelp
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
#if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE
#error "Please update to version 1.2 of github.com/google/highway."
@ -57,56 +52,6 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|___/ |_| |_|
)"";
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
if (app.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
std::cout << "Date & Time : " << dt
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
<< "\n"
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n"
<< "Instruction set : "
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
<< hwy::VectorBytes() * 8 << " bits)" << "\n";
char cpu100[100];
if (hwy::platform::GetCpuString(cpu100)) {
std::cout << "CPU : " << cpu100 << "\n";
}
std::cout << "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : "
<< gcpp::StringFromType(loader.WeightType()) << "\n"
<< "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
}
}
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
std::cerr
<< kAsciiArtBanner
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run gemma.cpp, you need to "
"specify 3 required model loading arguments:\n"
" --tokenizer\n"
" --weights\n"
" --model.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights 2b-it-sfp.sbs --model 2b-it\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n*Application Arguments*\n\n";
app.Help();
std::cerr << "\n";
}
// The main Read-Eval-Print Loop.
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
@ -118,12 +63,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
int prompt_size{};
std::mt19937 gen;
if (args.deterministic) {
gen.seed(42);
} else {
std::random_device rd;
gen.seed(rd());
}
InitGenerator(args, gen);
// callback function invoked for each generated token.
auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
@ -162,7 +102,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
while (abs_pos < args.max_tokens) {
std::string prompt_string;
std::vector<int> prompt;
current_pos = 0;
{
PROFILER_ZONE("Gen.input");
@ -192,30 +131,11 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
continue;
}
if (training == ModelTraining::GEMMA_IT) {
// For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n";
if (abs_pos != 0) {
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
// continuation.
prompt_string = "<end_of_turn>\n" + prompt_string;
}
}
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
if (abs_pos == 0) {
prompt.insert(prompt.begin(), gcpp::BOS_ID);
}
const std::vector<int> prompt =
WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
prompt_size = prompt.size();
std::cerr << "\n"
<< "[ Reading prompt ] " << std::flush;
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
@ -301,17 +221,20 @@ int main(int argc, char** argv) {
gcpp::AppArgs app(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
ShowHelp(loader, inference, app);
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
return 0;
}
if (const char* error = loader.Validate()) {
ShowHelp(loader, inference, app);
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}
if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app);
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}
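Note on the tokenize/wrap change above: the inlined control-token handling in ReplGemma is replaced by the shared WrapAndTokenize helper. The sketch below reconstructs what that helper presumably does from the removed lines; the function name, location, and exact signature here are assumptions, not the committed implementation.
// Sketch only: behavior inferred from the inlined wrapping logic removed above.
// Assumes the gemma headers already included by run.cc are available.
std::vector<int> WrapAndTokenizeSketch(const gcpp::GemmaTokenizer& tokenizer,
                                       gcpp::ModelTraining training, size_t pos,
                                       std::string prompt) {
  if (training == gcpp::ModelTraining::GEMMA_IT) {
    // Instruction-tuned models expect turn-structure control tokens.
    prompt = "<start_of_turn>user\n" + prompt +
             "<end_of_turn>\n<start_of_turn>model\n";
    // Multi-turn continuation: close the previous model turn first.
    if (pos != 0) prompt = "<end_of_turn>\n" + prompt;
  }
  std::vector<int> tokens;
  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
  // Both pre-trained and instruction-tuned models need <bos> at position 0.
  if (pos == 0) tokens.insert(tokens.begin(), gcpp::BOS_ID);
  return tokens;
}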

View File

@ -13,19 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Command line text interface to gemma.
#include <stdio.h>
#include <fstream>
#include <iostream>
#include <random>
#include <set>
#include <sstream>
#include <algorithm>
#include <string>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/io.h" // Path
#include "gemma/benchmark_helper.h"
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
@ -34,164 +31,134 @@
namespace gcpp {
void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool,
const InferenceArgs& args, int verbosity,
std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
// token index within the current turn
int max_tokens = 4096;
struct JsonArgs : public ArgsBase<JsonArgs> {
JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
std::mt19937 gen;
if (args.deterministic) {
gen.seed(42);
} else {
std::random_device rd;
gen.seed(rd());
Path input;
// Returns error string or nullptr if OK.
const char* Validate() const {
if (input.Empty()) return "Must specify --input";
if (!input.Exists()) return "--input file does not exist";
return nullptr;
}
float answers = 0.0;
float correct_answers = 0.0;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(input, "input", Path(), "Full pathname of mmlu.json.");
}
};
std::ifstream fJson("/tmp/mmlu.json");
std::stringstream buffer;
buffer << fJson.rdbuf();
auto json = nlohmann::json::parse(buffer.str());
std::vector<std::string> accept_tokens = {"A", "B", "C", "D"};
std::set<int> accept_token_set{};
for (const std::string& accept_token : accept_tokens) {
std::vector<int> accept_token_ids;
HWY_ASSERT(model.Tokenizer().Encode(accept_token, &accept_token_ids));
accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
// Linear search for a few tokens is faster than std::set.
// TODO: instead of accepting for each vocab entry, filter the logits once.
class TokenSet {
public:
TokenSet(const GemmaTokenizer& tokenizer,
const std::vector<std::string>& strings) {
all_tokens_.reserve(strings.size());
for (const std::string& str : strings) {
std::vector<int> tokens;
fprintf(stderr, "%s -> ", str.c_str());
HWY_ASSERT(tokenizer.Encode(str, &tokens));
for (int token : tokens) {
fprintf(stderr, "%d, ", token);
all_tokens_.push_back(token);
}
fprintf(stderr, "\n");
}
}
for (auto sample : json["samples"]) {
int abs_pos = 0; // absolute token index over all turns
int current_pos = 0;
int prompt_size{};
bool Contains(int token) const {
return std::find(all_tokens_.begin(), all_tokens_.end(), token) !=
all_tokens_.end();
}
// cout << "prompt:" << sample["prompt"] << endl;
const std::string& prompt_string = sample["prompt"];
std::vector<int> prompt;
private:
std::vector<int> all_tokens_;
};
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
prompt_size = prompt.size();
void Run(GemmaEnv& env, JsonArgs& json) {
PROFILER_ZONE("Run.all");
const std::string& correct_answer = accept_tokens[sample["input_label"]];
float answers = 0.0f;
float correct_answers = 0.0f;
// max_tokens = prompt_size + max_tokens;
auto json_data = nlohmann::json::parse(ReadFileToString(json.input));
const std::vector<std::string> accept_strings = {
"A", "B", "C", "D", //
" A", " B", " C", " D", //
"**", "**:", ":**", "The", "Answer", "is", ":", "."};
const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
for (auto sample : json_data["samples"]) {
const int id = sample["i"];
fprintf(stderr, "Processing question %d\n", id);
const std::string& correct_answer = accept_strings[sample["input_label"]];
std::string prompt_string = sample["prompt"];
// AcceptFunc restricts the output to one of these four tokens, so make an
// effort to steer the model towards that. See
// https://huggingface.co/blog/open-llm-leaderboard-mmlu
prompt_string +=
"What is start of the line with the correct answer? "
"Do not include any justifications or explanations. Reply only with a "
"letter.";
const std::vector<int> prompt =
WrapAndTokenize(env.GetModel()->Tokenizer(), env.ModelTrainingType(),
/*pos=*/0, prompt_string);
const size_t prompt_size = prompt.size();
std::vector<int> predicted_token_ids;
predicted_token_ids.reserve(max_tokens);
auto stream_token = [&current_pos, &prompt_size, &predicted_token_ids,
&accept_token_set](int token, float proba) {
predicted_token_ids.reserve(4096);
size_t current_pos = 0;
const StreamFunc stream_token = [&current_pos, prompt_size,
&predicted_token_ids](int token,
float proba) {
PROFILER_ZONE("Stream");
++current_pos;
if (current_pos > prompt_size) {
predicted_token_ids.push_back(token);
// If the generated token is in the accepted token set, return False.
// This will stop further generation.
return accept_token_set.find(token) == accept_token_set.end();
}
return true;
};
const AcceptFunc accept_token = [&current_pos, &prompt_size,
&accept_token_set](int token) {
// i.e. we have no constraints on accepted tokens
if (accept_token_set.empty()) {
return true;
}
if (current_pos >= prompt_size) {
return accept_token_set.find(token) != accept_token_set.end();
} else {
// auto-accept early tokens
return true;
}
};
// Although " A" is a token, it is difficult to associate that with the
// correct answer. Only accepting certain tokens is risky: (A) is easily
// confused with the word "A".
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
.verbosity = verbosity,
.gen = &gen,
.max_tokens = env.MaxTokens(),
.max_generated_tokens = 30,
.temperature = 0.0f,
.verbosity = env.Verbosity(),
.gen = &env.MutableGen(),
.stream_token = stream_token,
.accept_token = accept_token,
};
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
env.MutableKVCache(), timing_info);
std::string output_string;
HWY_ASSERT(model.Tokenizer().Decode(predicted_token_ids, &output_string));
std::cout << "QuestionId: " << sample["i"] << "; "
<< "Predicted Answer: " << output_string << "; "
<< "Correct Answer: " << correct_answer << std::endl;
std::string output_string = env.StringFromTokens(predicted_token_ids);
fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
output_string.c_str());
answers += 1.0;
answers += 1.0f;
if (output_string == correct_answer) {
correct_answers += 1.0;
correct_answers += 1.0f;
}
std::cout << "Running accuracy = " << "["
<< static_cast<int>(correct_answers) << "/"
<< static_cast<int>(answers) << "]" << " = "
<< correct_answers / answers << std::endl;
fprintf(stderr, "%.0f/%.0f = %.2f%%\n", correct_answers, answers,
correct_answers / answers);
}
}
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc");
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning workers to cores helps.
if (app.num_threads > 10) {
PinWorkersToCores(pool);
}
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);
}
} // namespace gcpp
int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.misc");
// Placeholder for internal init, do not modify.
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (const char* error = loader.Validate()) {
fprintf(stderr,
"\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
"specify 3 required model loading arguments: --tokenizer, "
"--compressed_weights, "
"and --model.\n\nModel Loading Arguments\n\n");
loader.Help();
fprintf(stderr, "\nInference Arguments\n\n");
inference.Help();
fprintf(stderr, "\nApplication Arguments\n\n");
app.Help();
fprintf(stderr, "\n\n");
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(loader, inference, app);
PROFILER_ZONE("Startup.all");
gcpp::GemmaEnv env(argc, argv);
gcpp::JsonArgs json(argc, argv);
gcpp::AbortIfInvalidArgs(json);
gcpp::Run(env, json);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;
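The comment above explains why strict token filtering was dropped from the MMLU runner (single-letter answer tokens are easy to confuse with ordinary words). For completeness, here is a rough sketch of how the TokenSet defined earlier could still back a RuntimeConfig accept_token callback. It assumes the updated AcceptFunc also receives the token probability and that the code lives inside namespace gcpp next to the file above; it is illustrative, not the committed code.
// Illustrative sketch, assumed to sit inside namespace gcpp alongside TokenSet.
void GenerateConstrained(GemmaEnv& env, const std::vector<int>& prompt) {
  const TokenSet answers(env.GetModel()->Tokenizer(), {"A", "B", "C", "D"});
  const size_t prompt_size = prompt.size();
  size_t current_pos = 0;
  std::vector<int> generated;
  const StreamFunc stream_token = [&](int token, float /*proba*/) {
    if (++current_pos > prompt_size) generated.push_back(token);
    return true;  // Stopping criteria omitted in this sketch.
  };
  // Assumption: AcceptFunc now takes (token, probability).
  const AcceptFunc accept_token = [&](int token, float /*prob*/) {
    // Auto-accept prompt tokens, then restrict decoding to the answer letters.
    return current_pos < prompt_size || answers.Contains(token);
  };
  TimingInfo timing_info;
  RuntimeConfig runtime_config = {
      .max_tokens = env.MaxTokens(),
      .max_generated_tokens = 1,
      .temperature = 0.0f,
      .verbosity = env.Verbosity(),
      .gen = &env.MutableGen(),
      .stream_token = stream_token,
      .accept_token = accept_token,
  };
  env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
                           env.MutableKVCache(), timing_info);
}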

View File

@ -197,6 +197,14 @@ static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) {
return false;
}
template <class TArgs>
static inline HWY_MAYBE_UNUSED void AbortIfInvalidArgs(TArgs& args) {
if (const char* err = args.Validate()) {
args.Help();
HWY_ABORT("Problem with args: %s\n", err);
}
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_
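Usage sketch for the new AbortIfInvalidArgs helper, mirroring the JsonArgs pattern from run_mmlu.cc above. DemoArgs and its flag are hypothetical and only illustrate the Validate()/Help() contract the helper relies on.
#include "compression/io.h"  // Path
#include "util/args.h"
// Hypothetical args struct, shown only to illustrate AbortIfInvalidArgs.
struct DemoArgs : public gcpp::ArgsBase<DemoArgs> {
  DemoArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  gcpp::Path input;
  // Returns error string or nullptr if OK, as expected by AbortIfInvalidArgs.
  const char* Validate() const {
    if (input.Empty()) return "Must specify --input";
    return nullptr;
  }
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(input, "input", gcpp::Path(), "Path to an input file.");
  }
};
int main(int argc, char* argv[]) {
  DemoArgs args(argc, argv);
  gcpp::AbortIfInvalidArgs(args);  // Prints Help() and aborts on invalid args.
  return 0;
}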