Major duplicated code reduction in test/benchmarks

Helper functions to tokenize/wrap
Move LayersOutputFunc into RuntimeConfig
AcceptFunc passes the probability
Implement StringFromType using the parser, and verify results match

PiperOrigin-RevId: 643255119
This commit is contained in:
Jan Wassenberg 2024-06-14 00:15:36 -07:00 committed by Copybara-Service
parent c15ff9529c
commit d3c6a45b59
17 changed files with 626 additions and 814 deletions

View File

@ -113,12 +113,8 @@ cc_library(
cc_library( cc_library(
name = "cross_entropy", name = "cross_entropy",
srcs = [ srcs = ["gemma/cross_entropy.cc"],
"gemma/cross_entropy.cc", hdrs = ["gemma/cross_entropy.h"],
],
hdrs = [
"gemma/cross_entropy.h",
],
deps = [ deps = [
":common", ":common",
":gemma_lib", ":gemma_lib",
@ -149,6 +145,25 @@ cc_library(
], ],
) )
cc_library(
name = "benchmark_helper",
srcs = ["gemma/benchmark_helper.cc"],
hdrs = ["gemma/benchmark_helper.h"],
deps = [
":app",
":args",
":common",
":cross_entropy",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
cc_test( cc_test(
name = "gemma_test", name = "gemma_test",
srcs = ["gemma/gemma_test.cc"], srcs = ["gemma/gemma_test.cc"],
@ -161,11 +176,11 @@ cc_test(
deps = [ deps = [
":app", ":app",
":args", ":args",
":benchmark_helper",
":common", ":common",
":cross_entropy", ":cross_entropy",
":gemma_lib", ":gemma_lib",
":ops", ":ops",
# Placeholder for internal dep, do not remove.,
"@googletest//:gtest_main", "@googletest//:gtest_main",
"//compression:io", "//compression:io",
"@hwy//:hwy_test_util", "@hwy//:hwy_test_util",
@ -179,6 +194,7 @@ cc_binary(
deps = [ deps = [
":app", ":app",
":args", ":args",
":benchmark_helper",
":common", ":common",
":gemma_lib", ":gemma_lib",
# Placeholder for internal dep, do not remove., # Placeholder for internal dep, do not remove.,
@ -213,10 +229,10 @@ cc_binary(
deps = [ deps = [
":app", ":app",
":args", ":args",
":benchmark_helper",
":common", ":common",
":cross_entropy", ":cross_entropy",
":gemma_lib", ":gemma_lib",
# Placeholder for internal dep, do not remove.,
"//compression:io", "//compression:io",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:nanobenchmark", "@hwy//:nanobenchmark",
@ -230,7 +246,6 @@ cc_binary(
srcs = ["gemma/benchmarks.cc"], srcs = ["gemma/benchmarks.cc"],
deps = [ deps = [
":benchmark_helper", ":benchmark_helper",
# Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark", "@benchmark//:benchmark",
], ],
) )
@ -243,8 +258,8 @@ cc_binary(
deps = [ deps = [
":app", ":app",
":args", ":args",
":benchmark_helper",
":gemma_lib", ":gemma_lib",
# Placeholder for internal dep, do not remove.,
"//compression:io", "//compression:io",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:thread_pool", "@hwy//:thread_pool",
@ -257,8 +272,10 @@ cc_binary(
srcs = ["gemma/run_mmlu.cc"], srcs = ["gemma/run_mmlu.cc"],
deps = [ deps = [
":app", ":app",
":args",
":benchmark_helper",
":gemma_lib", ":gemma_lib",
# Placeholder for internal dep, do not remove., "//compression:io",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:profiler", "@hwy//:profiler",
"@hwy//:thread_pool", "@hwy//:thread_pool",
@ -318,25 +335,6 @@ cc_library(
], ],
) )
cc_library(
name = "benchmark_helper",
srcs = [
"gemma/benchmark_helper.cc",
],
hdrs = [
"gemma/benchmark_helper.h",
],
deps = [
":app",
":common",
":gemma_lib",
"@benchmark//:benchmark",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
cc_test( cc_test(
name = "backward_scalar_test", name = "backward_scalar_test",
size = "large", size = "large",

View File

@ -23,6 +23,8 @@
#include <string> #include <string>
#include <utility> // std::move #include <utility> // std::move
#include "hwy/base.h"
namespace gcpp { namespace gcpp {
// Forward-declare to break the circular dependency: OpenFileOrNull returns // Forward-declare to break the circular dependency: OpenFileOrNull returns
@ -77,12 +79,30 @@ struct Path {
return path; return path;
} }
bool Empty() const { return path.empty(); }
// Returns whether the file existed when this was called. // Returns whether the file existed when this was called.
bool Exists() const { return !!OpenFileOrNull(*this, "r"); } bool Exists() const { return !!OpenFileOrNull(*this, "r"); }
std::string path; std::string path;
}; };
static inline HWY_MAYBE_UNUSED std::string ReadFileToString(const Path& path) {
std::unique_ptr<File> file = OpenFileOrNull(path, "r");
if (!file) {
HWY_ABORT("Failed to open %s", path.path.c_str());
}
const size_t size = file->FileSize();
if (size == 0) {
HWY_ABORT("Empty file %s", path.path.c_str());
}
std::string content(size, ' ');
if (!file->Read(0, size, content.data())) {
HWY_ABORT("Failed to read %s", path.path.c_str());
}
return content;
}
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_ #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_

View File

@ -1,146 +1,83 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdlib>
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <random>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
// Placeholder for internal header, do not modify.
#include "compression/io.h" #include "compression/io.h"
#include "gemma/gemma.h" #include "gemma/benchmark_helper.h"
#include "util/app.h" #include "gemma/gemma.h" // LayersOutputFunc
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
using json = nlohmann::json; using json = nlohmann::json;
class PromptArgs : public gcpp::ArgsBase<PromptArgs> { namespace gcpp {
class PromptArgs : public ArgsBase<PromptArgs> {
public: public:
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
gcpp::Path layers_output; Path layers_output; // optional
std::string prompt; std::string prompt;
// Returns error string or nullptr if OK.
const char* Validate() const {
if (prompt.empty()) return "Must specify --prompt";
return nullptr;
}
template <class Visitor> template <class Visitor>
void ForEach(const Visitor& visitor) { void ForEach(const Visitor& visitor) {
visitor(layers_output.path, "layers_output", std::string(""), visitor(layers_output, "layers_output", Path(""),
"Path to store layers output", 2); "Path to store layers output", 2);
visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2); visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
} }
}; };
std::pair<std::string, int> QueryModel( int Run(int argc, char** argv) {
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
gcpp::LayersOutputT* layers_output) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), gcpp::BOS_ID);
std::string res;
size_t total_tokens = 0;
std::mt19937 gen;
gen.seed(42);
auto stream_token = [&res, &total_tokens, &model](int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
return true;
};
if (app.verbosity >= 2) {
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
<< args.temperature;
}
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
.verbosity = app.verbosity,
.gen = &gen,
.stream_token = stream_token,
};
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
layers_output);
return {res, total_tokens};
}
class OutputJsonLogger {
public:
json json_output;
gcpp::LayersOutputT layers_output_log_f =
[this](int pos, const std::string& key, const float* values,
size_t values_len) {
std::vector<float> v{values, values + values_len};
json_output[std::to_string(pos)][key] = v;
};
};
/* Run this in the same way as gemma, p.ex.:
./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
2b-it-sfp.sbs --prompt "..." --layers_output [path]
*/
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs args(argc, argv); // inference
gcpp::AppArgs app(argc, argv);
PromptArgs prompt_args(argc, argv); PromptArgs prompt_args(argc, argv);
AbortIfInvalidArgs(prompt_args);
if (const char* error = loader.Validate()) { json json_output;
HWY_ABORT("\nInvalid loader args: %s", error); GemmaEnv env(argc, argv);
} env.MutableConfig().layers_output =
if (const char* error = args.Validate()) { prompt_args.layers_output.Empty()
HWY_ABORT("\nInvalid inference args: %s", error); ? LayersOutputFunc()
} : [&json_output](int pos, const std::string& key, const float* values,
const bool log_layers_output = !prompt_args.layers_output.path.empty(); size_t values_len) {
OutputJsonLogger json_logger; std::vector<float> v{values, values + values_len};
gcpp::LayersOutputT* layers_output = json_output[std::to_string(pos)][key] = v;
log_layers_output ? &json_logger.layers_output_log_f : nullptr; };
hwy::ThreadPool pool(app.num_threads); const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
// For many-core, pinning workers to cores helps. std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
if (app.num_threads > 10) {
gcpp::PinWorkersToCores(pool);
}
gcpp::Gemma model = gcpp::CreateGemma(loader, pool); if (env.MutableConfig().layers_output) {
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
const std::string& prompt = prompt_args.prompt;
if (prompt.empty()) {
std::cout << "Please specify --prompt" << std::endl;
return EXIT_FAILURE;
}
const auto [answer, token_count] = QueryModel(
model, args, app, kv_cache, pool, prompt, layers_output);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
if (log_layers_output) {
std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out); std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
if (!output_f) { if (!output_f) HWY_ABORT("Opening layer output file failed");
std::cout << "Opening file failed" << std::endl; output_f << json_output.dump();
return EXIT_FAILURE; if (!output_f) HWY_ABORT("Writing to layer output file failed");
}
output_f << json_logger.json_output.dump();
if (!output_f) {
std::cout << "Writing to file failed" << std::endl;
return EXIT_FAILURE;
}
output_f.close(); output_f.close();
} }
return 0;
return EXIT_SUCCESS;
} }
} // namespace gcpp
int main(int argc, char** argv) { return gcpp::Run(argc, argv); }

View File

@ -1,36 +1,36 @@
#include <stdio.h>
#include <algorithm> #include <algorithm>
#include <cstdlib> // EXIT_FAILURE #include <cstdlib> // EXIT_FAILURE
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <ostream> #include <ostream>
#include <random>
#include <sstream>
#include <string> #include <string>
#include <utility> // std::pair #include <utility> // std::pair
#include <vector> #include <vector>
// Placeholder for internal header, do not modify.
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/cross_entropy.h" #include "gemma/cross_entropy.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/timer.h" #include "hwy/timer.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
namespace gcpp {
using json = nlohmann::json; using json = nlohmann::json;
class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> { class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
public: public:
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
gcpp::Path goldens; Path goldens;
gcpp::Path summarize_text; Path summarize_text;
gcpp::Path cross_entropy; Path cross_entropy;
gcpp::Path trivia_qa; Path trivia_qa;
size_t max_questions; size_t max_questions;
size_t batch_tokens; size_t batch_tokens;
@ -53,61 +53,6 @@ class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> {
} }
}; };
void LogSpeedStats(const double time_start, size_t total_tokens) {
const double time_end = hwy::platform::Now();
const double time_elapsed = time_end - time_start;
const double tok_sec = total_tokens / time_elapsed;
std::cout << total_tokens << " tokens in " << time_elapsed << " seconds"
<< " [" << tok_sec << " tokens / sec" << "]\n";
}
std::pair<std::string, int> QueryModel(
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), 2);
std::string res;
size_t total_tokens = 0;
std::mt19937 gen;
gen.seed(42);
const double time_start = hwy::platform::Now();
auto stream_token = [&res, &total_tokens, &time_start, &app, &model](
int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
if (app.verbosity >= 1 && total_tokens % 100 == 0) {
LogSpeedStats(time_start, total_tokens);
}
return true;
};
if (app.verbosity >= 2) {
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
<< args.temperature;
}
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
.verbosity = app.verbosity,
.gen = &gen,
.stream_token = stream_token,
};
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
/*layers_output=*/nullptr);
if (app.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens);
}
return {res, total_tokens};
}
std::vector<std::pair<std::string, std::string>> load_goldens( std::vector<std::pair<std::string, std::string>> load_goldens(
const std::string& path) { const std::string& path) {
std::ifstream goldens_file(path); std::ifstream goldens_file(path);
@ -129,28 +74,14 @@ std::vector<std::pair<std::string, std::string>> load_goldens(
return res; return res;
} }
std::string ReadFile(const gcpp::Path& path) { int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
std::ifstream text_file(path.path); std::vector<std::pair<std::string, std::string>> queries_answers =
if (!text_file) {
std::cout << "Could not open file: " << path.path << "\n" << std::flush;
return {};
}
std::stringstream buffer;
buffer << text_file.rdbuf();
return buffer.str();
}
int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, const std::string& golden_path) {
const std::vector<std::pair<std::string, std::string>> queries_answers =
load_goldens(golden_path); load_goldens(golden_path);
int correct_answers = 0; size_t correct_answers = 0;
int total_tokens = 0; size_t total_tokens = 0;
const double time_start = hwy::platform::Now(); const double time_start = hwy::platform::Now();
for (const auto& [question, expected_answer] : queries_answers) { for (auto& [question, expected_answer] : queries_answers) {
const auto [answer, token_count] = const auto [answer, token_count] = env.QueryModel(question);
QueryModel(model, args, app, kv_cache, pool, question);
total_tokens += token_count; total_tokens += token_count;
if (answer.find(expected_answer) != std::string::npos) { if (answer.find(expected_answer) != std::string::npos) {
correct_answers++; correct_answers++;
@ -172,28 +103,22 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args, int BenchmarkSummary(GemmaEnv& env, const Path& text) {
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, const gcpp::Path& text) {
std::string prompt("Here is some text to summarize:\n"); std::string prompt("Here is some text to summarize:\n");
prompt.append(ReadFile(text)); prompt.append(ReadFileToString(text));
prompt.append("\nSummarize this text.\n"); prompt.append("\nSummarize this text.\n");
const double time_start = hwy::platform::Now(); const double time_start = hwy::platform::Now();
const auto [answer, token_count] = const auto [answer, token_count] = env.QueryModel(prompt);
QueryModel(model, args, app, kv_cache, pool, prompt);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush; std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
LogSpeedStats(time_start, token_count); LogSpeedStats(time_start, token_count);
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type, int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
gcpp::InferenceArgs& args, gcpp::AppArgs& app,
hwy::ThreadPool& pool, const gcpp::Path& text,
size_t batch_tokens) { size_t batch_tokens) {
std::string input = ReadFile(text); std::string input = ReadFileToString(text);
std::vector<int> prompt; std::vector<int> prompt = env.Tokenize(input);
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt)); prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
prompt.resize(std::min<size_t>(args.max_tokens, prompt.size()));
std::cout << "Number of input tokens: " << prompt.size() << "\n"; std::cout << "Number of input tokens: " << prompt.size() << "\n";
const double time_start = hwy::platform::Now(); const double time_start = hwy::platform::Now();
float total_entropy = 0.0f; float total_entropy = 0.0f;
@ -203,13 +128,12 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens); size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos, std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens); prompt.begin() + pos + num_tokens);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type); KVCache kv_cache = KVCache::Create(env.ModelType());
float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice, float entropy = ComputeCrossEntropy(
kv_cache, app.verbosity); *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy; total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens); LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice; std::string text_slice = env.StringFromTokens(prompt_slice);
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
total_input_len += text_slice.size(); total_input_len += text_slice.size();
printf("Total cross entropy: %f [cumulative: %f]\n", printf("Total cross entropy: %f [cumulative: %f]\n",
entropy, total_entropy); entropy, total_entropy);
@ -219,23 +143,19 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args, int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, const gcpp::Path& json_file,
size_t max_questions) { size_t max_questions) {
std::ifstream trivia_file(json_file.path); std::ifstream trivia_file(json_file.path);
if (!trivia_file) { if (!trivia_file) {
std::cout << "Could not load file: " << json_file.path << "\n" HWY_ABORT("Could not load file %s\n", json_file.path.c_str());
<< std::flush;
return EXIT_FAILURE;
} }
std::string line; std::string line;
size_t correct_answers = 0; size_t correct_answers = 0;
size_t i = 0; size_t i = 0;
while (std::getline(trivia_file, line)) { while (std::getline(trivia_file, line)) {
json data = json::parse(line); json data = json::parse(line);
const auto [answer, token_count] = QueryModel( std::string q(data["question"]);
model, args, app, kv_cache, pool, data["question"]); const auto [answer, token_count] = env.QueryModel(q);
std::cout << answer << "\n"; std::cout << answer << "\n";
bool correct = false; bool correct = false;
for (const std::string expected : data["answer"]["aliases"]) { for (const std::string expected : data["answer"]["aliases"]) {
@ -256,52 +176,25 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
/* Run this in the same way as gemma, p.ex.: } // namespace gcpp
./benchmark --tokenizer tokenizer.spm --model 2b-it --weights \
2b-it-sfp.sbs --goldens_dir "../goldens"
*/
int main(int argc, char** argv) { int main(int argc, char** argv) {
{ gcpp::GemmaEnv env(argc, argv);
// Placeholder for internal init, do not modify. gcpp::BenchmarkArgs benchmark_args(argc, argv);
}
gcpp::LoaderArgs loader(argc, argv); if (!benchmark_args.goldens.Empty()) {
gcpp::InferenceArgs args(argc, argv); // inference
gcpp::AppArgs app(argc, argv);
BenchmarkArgs benchmark_args(argc, argv);
if (const char* error = loader.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = args.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning workers to cores helps.
if (app.num_threads > 10) {
gcpp::PinWorkersToCores(pool);
}
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
if (!benchmark_args.goldens.path.empty()) {
const std::string golden_path = const std::string golden_path =
benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt"; benchmark_args.goldens.path + "/" +
return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path); gcpp::ModelString(env.ModelType(), env.ModelTrainingType()) + ".txt";
} else if (!benchmark_args.summarize_text.path.empty()) { return BenchmarkGoldens(env, golden_path);
return BenchmarkSummary(model, args, app, kv_cache, pool, } else if (!benchmark_args.summarize_text.Empty()) {
benchmark_args.summarize_text); return BenchmarkSummary(env, benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.path.empty()) { } else if (!benchmark_args.cross_entropy.Empty()) {
return BenchmarkCrossEntropy(model, loader.ModelType(), args, app, return BenchmarkCrossEntropy(env, benchmark_args.cross_entropy,
pool, benchmark_args.cross_entropy,
benchmark_args.batch_tokens); benchmark_args.batch_tokens);
} else if (!benchmark_args.trivia_qa.path.empty()) { } else if (!benchmark_args.trivia_qa.Empty()) {
return BenchmarkTriviaQA(model, args, app, kv_cache, pool, return BenchmarkTriviaQA(env, benchmark_args.trivia_qa,
benchmark_args.trivia_qa,
benchmark_args.max_questions); benchmark_args.max_questions);
} }
std::cout << "No benchmark command given." << "\n" << std::flush; HWY_ABORT("No benchmark command given.");
return EXIT_FAILURE;
} }

View File

@ -14,70 +14,95 @@
// limitations under the License. // limitations under the License.
#include "gemma/benchmark_helper.h" #include "gemma/benchmark_helper.h"
#include <cstdlib> // EXIT_FAILURE
#include <stdio.h>
#include <time.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <ostream> #include <ostream>
#include <random> #include <random>
#include <string> #include <string>
#include <thread> // NOLINT
#include <utility> // std::pair #include <utility> // std::pair
#include <vector> #include <vector>
#include "gemma/common.h" // Placeholder for internal header, do not modify.
#include "compression/compress.h" // TypeName
#include "gemma/common.h" // StringFromType
#include "gemma/cross_entropy.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/app.h" #include "util/app.h"
#include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/timer.h" #include "hwy/timer.h"
namespace gcpp { namespace gcpp {
GemmaEnv::GemmaEnv(int argc, char** argv)
: loader_(argc, argv), inference_args_(argc, argv), app_(argc, argv), void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
pool_(app_.num_threads) { if (inference.deterministic) {
if (const char* error = loader_.Validate()) { // Nothing up my sleeve number, at least some upper bits set.
HWY_ABORT("\nInvalid loader args: %s", error); gen.seed(0x12345678);
} } else {
if (const char* error = inference_args_.Validate()) { // Depending on the library implementation, this may still be deterministic.
HWY_ABORT("\nInvalid inference args: %s", error); std::random_device rd;
} gen.seed(rd());
// For many-core, pinning workers to cores helps. }
if (app_.num_threads > 10) { }
gcpp::PinWorkersToCores(pool_);
} GemmaEnv::GemmaEnv(int argc, char** argv)
: loader_(argc, argv),
inference_args_(argc, argv),
app_(argc, argv),
pool_(app_.num_threads) {
{
// Placeholder for internal init, do not modify.
}
// For many-core, pinning workers to cores helps.
if (app_.num_threads > 10) {
gcpp::PinWorkersToCores(pool_);
}
AbortIfInvalidArgs(inference_args_);
if (const char* err = loader_.Validate()) {
loader_.Help();
fprintf(stderr, "Skipping model load because: %s\n", err);
} else {
fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(loader_, pool_); model_ = AllocateGemma(loader_, pool_);
kv_cache_ = KVCache::Create(loader_.ModelType()); kv_cache_ = KVCache::Create(loader_.ModelType());
gen_.seed(42);
} }
std::pair<std::string, int> GemmaEnv::QueryModel(const std::string& input) { InitGenerator(inference_args_, gen_);
std::string prompt_string = input;
if (loader_.ModelTrainingType() == ModelTraining::GEMMA_IT) {
// For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + input +
"<end_of_turn>\n<start_of_turn>model\n";
}
std::vector<int> prompt;
HWY_ASSERT(model_->Tokenizer().Encode(input, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token runtime_config_ = {
// if needed. .max_tokens = inference_args_.max_tokens,
prompt.insert(prompt.begin(), gcpp::BOS_ID); .max_generated_tokens = inference_args_.max_generated_tokens,
.temperature = inference_args_.temperature,
.verbosity = app_.verbosity,
.gen = &gen_,
};
}
std::pair<std::string, size_t> GemmaEnv::QueryModel(
const std::vector<int>& tokens) {
std::string res; std::string res;
size_t total_tokens = 0; size_t total_tokens = 0;
auto accept_token = [](int) { return true; };
std::mt19937 gen;
gen.seed(42);
const double time_start = hwy::platform::Now(); const double time_start = hwy::platform::Now();
auto stream_token = [&res, &total_tokens, &time_start, this]( const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
int token, float) { int token, float) {
++total_tokens; ++total_tokens;
std::string token_text; std::string token_text;
HWY_ASSERT(model_->Tokenizer().Decode(std::vector<int>{token}, HWY_ASSERT(
&token_text)); model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text; res += token_text;
if (app_.verbosity >= 1 && total_tokens % 100 == 0) { if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
LogSpeedStats(time_start, total_tokens); LogSpeedStats(time_start, total_tokens);
} }
return true; return true;
@ -88,24 +113,32 @@ std::pair<std::string, int> GemmaEnv::QueryModel(const std::string& input) {
<< inference_args_.temperature; << inference_args_.temperature;
} }
gcpp::TimingInfo timing_info; gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = { runtime_config_.stream_token = stream_token;
.max_tokens = inference_args_.max_tokens, model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_cache_,
.max_generated_tokens = inference_args_.max_generated_tokens, timing_info);
.temperature = inference_args_.temperature,
.verbosity = app_.verbosity,
.gen = &gen,
.stream_token = stream_token,
.accept_token = accept_token,
};
model_->Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache_,
timing_info, /*layers_output=*/nullptr);
if (app_.verbosity >= 1) { if (app_.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens); LogSpeedStats(time_start, total_tokens);
} }
return {res, total_tokens}; return {res, total_tokens};
} }
void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const { std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt =
WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(),
/*pos=*/0, input);
return QueryModel(prompt);
}
float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
MutableKVCache(),
/*verbosity=*/0) /
static_cast<int>(input.size());
}
void LogSpeedStats(double time_start, size_t total_tokens) {
const double time_end = hwy::platform::Now(); const double time_end = hwy::platform::Now();
const double time_elapsed = time_end - time_start; const double time_elapsed = time_end - time_start;
const double tok_sec = total_tokens / time_elapsed; const double tok_sec = total_tokens / time_elapsed;
@ -113,6 +146,53 @@ void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const {
<< " [" << tok_sec << " tokens / sec" << "]\n"; << " [" << tok_sec << " tokens / sec" << "]\n";
} }
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
if (app.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
std::cout << "Date & Time : " << dt
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
<< "\n"
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n"
<< "Instruction set : "
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
<< hwy::VectorBytes() * 8 << " bits)" << "\n";
char cpu100[100];
if (hwy::platform::GetCpuString(cpu100)) {
std::cout << "CPU : " << cpu100 << "\n";
}
std::cout << "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : "
<< gcpp::StringFromType(loader.WeightType()) << "\n"
<< "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
}
}
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run gemma.cpp, you need to "
"specify 3 required model loading arguments:\n"
" --tokenizer\n"
" --weights\n"
" --model.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights 2b-it-sfp.sbs --model 2b-it\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n*Application Arguments*\n\n";
app.Help();
std::cerr << "\n";
}
} // namespace gcpp } // namespace gcpp

View File

@ -16,11 +16,15 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
#include <stddef.h>
#include <memory> #include <memory>
#include <random> #include <random>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
#include "gemma/common.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/app.h" #include "util/app.h"
#include "hwy/base.h" #include "hwy/base.h"
@ -28,24 +32,60 @@
namespace gcpp { namespace gcpp {
void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
// Convenience class to load a model and run inference. // Convenience class to load a model and run inference.
class GemmaEnv { class GemmaEnv {
public: public:
GemmaEnv(int argc, char** argv); GemmaEnv(int argc, char** argv);
size_t MaxTokens() const { return inference_args_.max_tokens; }
// Sets the maximum number of output tokens to generate. // Sets the maximum number of output tokens to generate.
void set_max_generated_tokens(int max_tokens) { void SetMaxGeneratedTokens(size_t max_tokens) {
inference_args_.max_generated_tokens = max_tokens; inference_args_.max_generated_tokens = max_tokens;
} }
std::vector<int> Tokenize(const std::string& input) const {
std::vector<int> tokens;
HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens));
return tokens;
}
std::vector<int> TokenizeAndPrependBOS(const std::string& input) const {
std::vector<int> tokens = Tokenize(input);
tokens.insert(tokens.begin(), BOS_ID);
return tokens;
}
std::string StringFromTokens(const std::vector<int>& tokens) const {
std::string string;
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
return string;
}
// Runs inference on the given input and returns the top-1 result string and // Runs inference on the given input and returns the top-1 result string and
// the number of tokens that were generated. // the number of tokens that were generated.
std::pair<std::string, int> QueryModel(const std::string& input); std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
// Adds turn structure to input, tokenizes and calls the above overload.
std::pair<std::string, size_t> QueryModel(std::string& input);
// Runs inference on the given input and returns the cross entropy, a measure
// of how well the model predicts the correct output. It is the average
// number of bits per token.
float CrossEntropy(const std::string& input);
// Returns nullptr if the model failed to load.
Gemma* GetModel() const { return model_.get(); }
Model ModelType() const { return loader_.ModelType(); }
ModelTraining ModelTrainingType() const {
return loader_.ModelTrainingType();
}
int Verbosity() const { return app_.verbosity; }
gcpp::RuntimeConfig& MutableConfig() { return runtime_config_; }
std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return kv_cache_; }
private: private:
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens) const;
// Arguments to the model loader: file locations, etc. // Arguments to the model loader: file locations, etc.
LoaderArgs loader_; LoaderArgs loader_;
// Arguments to the inference function: max tokens, etc. // Arguments to the inference function: max tokens, etc.
@ -60,10 +100,16 @@ class GemmaEnv {
std::unique_ptr<Gemma> model_; std::unique_ptr<Gemma> model_;
// The KV cache to use for inference. // The KV cache to use for inference.
KVCache kv_cache_; KVCache kv_cache_;
gcpp::RuntimeConfig runtime_config_;
}; };
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app);
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_

View File

@ -13,108 +13,81 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <fstream> #include <stddef.h>
#include <iostream> #include <stdio.h>
#include <ostream>
#include <random>
#include <sstream>
#include <string> #include <string>
// Placeholder for internal header, do not modify.
#include "benchmark/benchmark.h" #include "benchmark/benchmark.h"
#include "gemma/benchmark_helper.h" #include "gemma/benchmark_helper.h"
void run_gemma_prompt(const std::string& prompt_string, namespace gcpp {
gcpp::GemmaEnv& env,
benchmark::State& state) {
std::mt19937 gen;
if (prompt_string.empty()) return; // Shared state for benchmarks - unfortunately the library does not allow
// passing context nor closures. Raw pointer because style guide forbids
// non-local static objects with dtors.t
GemmaEnv* s_env = nullptr;
int token_counter = 0; void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
size_t total_tokens = 0;
for (auto s : state) { for (auto s : state) {
auto [response, n] = env.QueryModel(prompt_string); std::string prompt = original_prompt; // reset from original
std::cout << "response: " << response << "\n"; auto [response, n] = s_env->QueryModel(prompt);
std::cout << "n: " << n << "\n"; if (s_env->Verbosity() != 0) {
token_counter += n; fprintf(stdout, "|%s|\n", response.c_str());
}
total_tokens += n;
} }
state.SetItemsProcessed(token_counter); state.SetItemsProcessed(total_tokens);
} }
// Awkward global because benchmarks don't support additional state, so it is } // namespace gcpp
// either this or cast to int64_t.
gcpp::GemmaEnv* global_env = nullptr;
static void BM_short_prompt(benchmark::State& state) { static void BM_short_prompt(benchmark::State& state) {
run_gemma_prompt("What is the capital of Spain?", *global_env, gcpp::RunPrompt("What is the capital of Spain?", state);
state);
} }
static void BM_factuality_prompt(benchmark::State& state) { static void BM_factuality_prompt(benchmark::State& state) {
run_gemma_prompt("How does an inkjet printer work?", gcpp::RunPrompt("How does an inkjet printer work?", state);
*global_env, state);
} }
static void BM_creative_prompt(benchmark::State& state) { static void BM_creative_prompt(benchmark::State& state) {
run_gemma_prompt( gcpp::RunPrompt("Tell me a story about a magical bunny and their TRS-80.",
"Tell me a story about a magical bunny and their TRS-80.", state);
*global_env, state);
} }
static void BM_coding_prompt(benchmark::State& state) { static void BM_coding_prompt(benchmark::State& state) {
run_gemma_prompt( gcpp::RunPrompt("Write a python program to generate a fibonacci sequence.",
"Write a python program to generate a fibonacci sequence.", state);
*global_env, state);
} }
static void BM_long_coding_prompt(benchmark::State& state) { BENCHMARK(BM_short_prompt)
std::ifstream t("benchmarks.cc", std::ios_base::in); ->Iterations(3)
std::stringstream buffer; ->Unit(benchmark::kMillisecond)
buffer << t.rdbuf(); ->UseRealTime();
std::string prompt_string = buffer.str();
t.close();
run_gemma_prompt("Make improvements to the following code:\n " + BENCHMARK(BM_factuality_prompt)
prompt_string, *global_env, state); ->Iterations(3)
} ->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_creative_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
int main(int argc, char** argv) { int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::GemmaEnv env(argc, argv); gcpp::GemmaEnv env(argc, argv);
env.SetMaxGeneratedTokens(256);
gcpp::s_env = &env;
env.set_max_generated_tokens(128); ::benchmark::RunSpecifiedBenchmarks();
global_env = &env; ::benchmark::Shutdown();
BENCHMARK(BM_short_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
env.set_max_generated_tokens(256);
BENCHMARK(BM_factuality_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_creative_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
env.set_max_generated_tokens(1024);
BENCHMARK(BM_long_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
::benchmark ::RunSpecifiedBenchmarks();
::benchmark ::Shutdown();
return 0; return 0;
} }

View File

@ -28,20 +28,18 @@
namespace gcpp { namespace gcpp {
namespace {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
"7b-it", "gr2b-it", "tiny"};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
ModelTraining::GEMMA_IT};
} // namespace
const char* ParseModelTypeAndTraining(const std::string& model_flag, const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) { Model& model, ModelTraining& training) {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
"7b-it", "gr2b-it", "tiny"};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
ModelTraining::GEMMA_IT};
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags); constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
static char kErrorMessageBuffer[kNum * 8 + 1024] = static char kErrorMessageBuffer[kNum * 8 + 1024] =
"Invalid or missing model flag, need to specify one of "; "Invalid or missing model flag, need to specify one of ";
@ -51,37 +49,58 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
} }
strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, "."); // NOLINT strcat(kErrorMessageBuffer, "."); // NOLINT
std::string model_type_lc = model_flag; std::string model_type_lc = model_flag;
std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc), std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
[](unsigned char c) { return std::tolower(c); }); [](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) { for (size_t i = 0; i < kNum; i++) {
if (kModelFlags[i] == model_type_lc) { if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i]; model = kModelTypes[i];
training = kModelTraining[i]; training = kModelTraining[i];
HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc);
return nullptr; return nullptr;
} }
} }
return kErrorMessageBuffer; return kErrorMessageBuffer;
} }
const char* ModelString(Model model, ModelTraining training) {
if (model == Model::GEMMA_TINY) return "tiny";
static_assert(static_cast<size_t>(ModelTraining::GEMMA_IT) == 0);
constexpr const char* k2B[] = {"2b-it", "2b-pt"};
constexpr const char* k7B[] = {"7b-it", "7b-pt"};
constexpr const char* kGr2B[] = {"gr2b-it", "gr2b-pt"};
if (model == Model::GEMMA_2B) return k2B[static_cast<size_t>(training)];
if (model == Model::GEMMA_7B) return k7B[static_cast<size_t>(training)];
if (model == Model::GRIFFIN_2B) return kGr2B[static_cast<size_t>(training)];
HWY_ABORT("Unknown model %d training %d\n", static_cast<int>(model),
static_cast<int>(training));
}
constexpr const char* kTypeStrings[] = {"f32", "bf16", "sfp"};
const char* StringFromType(Type type) {
return kTypeStrings[static_cast<size_t>(type)];
}
const char* ParseType(const std::string& type_string, Type& type) { const char* ParseType(const std::string& type_string, Type& type) {
constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP}; constexpr size_t kNum = std::end(kTypeStrings) - std::begin(kTypeStrings);
constexpr const char* kStrings[] = {"f32", "bf16", "sfp"};
constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings);
static char kErrorMessageBuffer[kNum * 8 + 100] = static char kErrorMessageBuffer[kNum * 8 + 100] =
"Invalid or missing type, need to specify one of "; "Invalid or missing type, need to specify one of ";
for (size_t i = 0; i + 1 < kNum; i++) { for (size_t i = 0; i + 1 < kNum; i++) {
strcat(kErrorMessageBuffer, kStrings[i]); // NOLINT strcat(kErrorMessageBuffer, kTypeStrings[i]); // NOLINT
strcat(kErrorMessageBuffer, ", "); // NOLINT strcat(kErrorMessageBuffer, ", "); // NOLINT
} }
strcat(kErrorMessageBuffer, kStrings[kNum - 1]); // NOLINT strcat(kErrorMessageBuffer, kTypeStrings[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, "."); // NOLINT strcat(kErrorMessageBuffer, "."); // NOLINT
std::string type_lc = type_string; std::string type_lc = type_string;
std::transform(begin(type_lc), end(type_lc), begin(type_lc), std::transform(begin(type_lc), end(type_lc), begin(type_lc),
[](unsigned char c) { return std::tolower(c); }); [](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) { for (size_t i = 0; i < kNum; i++) {
if (kStrings[i] == type_lc) { if (kTypeStrings[i] == type_lc) {
type = kTypes[i]; type = static_cast<Type>(i);
HWY_ASSERT(std::string(StringFromType(type)) == type_lc);
return nullptr; return nullptr;
} }
} }

View File

@ -154,18 +154,12 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training); Model& model, ModelTraining& training);
const char* ParseType(const std::string& type_string, Type& type); const char* ParseType(const std::string& type_string, Type& type);
static inline const char* StringFromType(Type type) { // Inverse of ParseModelTypeAndTraining.
switch (type) { const char* ModelString(Model model, ModelTraining training);
case Type::kF32: const char* StringFromType(Type type);
return "f32";
case Type::kBF16: // ----------------------------------------------------------------------------
return "bf16"; //
case Type::kSFP:
return "sfp";
default:
return "?";
}
}
// __builtin_sqrt is not constexpr as of Clang 17. // __builtin_sqrt is not constexpr as of Clang 17.
#if HWY_COMPILER_GCC_ACTUAL #if HWY_COMPILER_GCC_ACTUAL

View File

@ -111,7 +111,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
}; };
TimingInfo timing_info; TimingInfo timing_info;
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr); gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
const float scale = 1.0f / std::log(2.0f); const float scale = 1.0f / std::log(2.0f);
return cross_entropy * scale; return cross_entropy * scale;

View File

@ -17,7 +17,6 @@
// Compiles this file for multiple architectures via "foreach_target.h", to // Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'. // which we pass the filename via macro 'argument'.
#include <cstdio>
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT #define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
@ -208,8 +207,8 @@ bool GemmaTokenizer::Encode(const std::string& input,
} }
bool GemmaTokenizer::Encode(const std::string& input, bool GemmaTokenizer::Encode(const std::string& input,
std::vector<int>* pieces) const { std::vector<int>* ids) const {
return impl_->Encode(input, pieces); return impl_->Encode(input, ids);
} }
// Given a sequence of ids, decodes it into a detokenized output. // Given a sequence of ids, decodes it into a detokenized output.
@ -649,16 +648,16 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
// Compute the transformer for a batch of input tokens. During generation, // Compute the transformer for a batch of input tokens. During generation,
// we usually have num_tokens == 1 (and also kBatchSize == 1). // we usually have num_tokens == 1 (and also kBatchSize == 1).
template <size_t kBatchSize, typename WeightArrayT, class TConfig> template <size_t kBatchSize, typename WeightArrayT, class TConfig>
HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos, HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos,
const WeightArrayT& weights, const WeightArrayT& weights,
Activations<TConfig, kBatchSize>& activations, Activations<TConfig, kBatchSize>& activations,
KVCache& kv_cache, hwy::ThreadPool& pool, KVCache& kv_cache, hwy::ThreadPool& pool,
LayersOutputT* layers_output) { const LayersOutputFunc& layers_output) {
HWY_ASSERT(num_tokens <= kBatchSize); HWY_ASSERT(num_tokens <= kBatchSize);
if (layers_output != nullptr) { if (layers_output) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
float token_f = tokens[token_idx]; float token_f = tokens[token_idx];
(*layers_output)(pos + token_idx, "Tokens", &token_f, 1); layers_output(pos + token_idx, "Tokens", &token_f, 1);
} }
} }
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kModelDim = TConfig::kModelDim;
@ -713,12 +712,11 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
} }
AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(), AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
activations.x.data(), kModelDim); activations.x.data(), kModelDim);
if (layers_output != nullptr) { if (layers_output) {
std::string block_name = "blocks." + std::to_string(layer); std::string block_name = "blocks." + std::to_string(layer);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
(*layers_output)(pos + token_idx, block_name, layers_output(pos + token_idx, block_name,
activations.x.data() + token_idx * kModelDim, activations.x.data() + token_idx * kModelDim, kModelDim);
kModelDim);
} }
} }
} }
@ -727,10 +725,10 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(), RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
activations.x.data(), kModelDim); activations.x.data(), kModelDim);
if (layers_output != nullptr) { if (layers_output) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
(*layers_output)(pos + token_idx, "final_norm", layers_output(pos + token_idx, "final_norm",
activations.x.data() + token_idx * kModelDim, kModelDim); activations.x.data() + token_idx * kModelDim, kModelDim);
} }
} }
} }
@ -782,8 +780,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const ByteStorageT& decode_u8, const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache, const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info, hwy::ThreadPool& pool, TimingInfo& timing_info) {
LayersOutputT* layers_output) {
const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8); const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8);
auto& prefill_activations = auto& prefill_activations =
GetActivations<TConfig, kPrefillBatchSize>(prefill_u8); GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
@ -860,7 +857,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
pos < max_tokens && generate_pos < max_generated_tokens; pos < max_tokens && generate_pos < max_generated_tokens;
++pos, ++pos_offset, ++generate_pos) { ++pos, ++pos_offset, ++generate_pos) {
Transformer<kDecodeBatchSize>(&token, kDecodeBatchSize, pos, weights, Transformer<kDecodeBatchSize>(&token, kDecodeBatchSize, pos, weights,
activations, kv_cache, pool, layers_output); activations, kv_cache, pool,
runtime_config.layers_output);
float token_logit = 0.0f; float token_logit = 0.0f;
// The condition below is always true if we are doing Prefill above. // The condition below is always true if we are doing Prefill above.
// We keep it here for clarity so that the code is correct even if Prefill // We keep it here for clarity so that the code is correct even if Prefill
@ -953,17 +951,37 @@ Gemma::~Gemma() {
void Gemma::Generate(const RuntimeConfig& runtime_config, void Gemma::Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, TimingInfo& timing_info, KVCache& kv_cache, TimingInfo& timing_info) {
LayersOutputT* layers_output) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
GEMMA_EXPORT_AND_DISPATCH( GEMMA_EXPORT_AND_DISPATCH(
model_type_, weight_type_, GenerateT, model_type_, weight_type_, GenerateT,
(weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos, (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
kv_cache, pool_, timing_info, layers_output)); kv_cache, pool_, timing_info));
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
} }
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelTraining training, size_t pos,
std::string& prompt) {
// Instruction-tuned models are trained to expect control tokens.
if (training == ModelTraining::GEMMA_IT) {
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
const std::string start = (pos == 0)
? "<start_of_turn>user\n"
: "<end_of_turn>\n<start_of_turn>user\n";
prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
}
std::vector<int> tokens;
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
// Both pre-trained and instruction-tuned require BOS as first token.
if (pos == 0) {
tokens.insert(tokens.begin(), gcpp::BOS_ID);
}
return tokens;
}
} // namespace gcpp } // namespace gcpp
#endif // HWY_ONCE #endif // HWY_ONCE

View File

@ -74,10 +74,17 @@ class GemmaTokenizer {
using StreamFunc = std::function<bool(int, float)>; using StreamFunc = std::function<bool(int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for // If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate. // tokens you don't want to generate and true for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>; using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called with the probability distribution for the // If not empty, SampleFunc is called with the probability distribution for the
// next token, and its return value is used as the next generated token. // next token, and its return value is used as the next generated token.
using SampleFunc = std::function<int(const float*, size_t)>; using SampleFunc = std::function<int(const float*, size_t)>;
// Will be called for layers output with:
// - position in the tokens sequence
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
// - pointer to the data array
// - size of the data array
using LayersOutputFunc =
std::function<void(int, const std::string&, const float*, size_t)>;
struct RuntimeConfig { struct RuntimeConfig {
size_t max_tokens; size_t max_tokens;
@ -88,6 +95,7 @@ struct RuntimeConfig {
StreamFunc stream_token; StreamFunc stream_token;
AcceptFunc accept_token; // if empty, accepts all tokens. AcceptFunc accept_token; // if empty, accepts all tokens.
SampleFunc sample_func; // if empty, uses SampleTopK. SampleFunc sample_func; // if empty, uses SampleTopK.
LayersOutputFunc layers_output; // if not empty, called after each layer.
int eos_id = EOS_ID; int eos_id = EOS_ID;
}; };
@ -97,14 +105,6 @@ struct TimingInfo {
double time_to_first_token = 0.0; double time_to_first_token = 0.0;
}; };
// Will be called for layers output with:
// - position in the tokens sequence
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
// - pointer to the data array
// - size of the data array
using LayersOutputT =
std::function<void(int, const std::string&, const float*, size_t)>;
class Gemma { class Gemma {
public: public:
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@ -121,12 +121,9 @@ class Gemma {
const ByteStorageT& Prefill() const { return prefill_u8_; } const ByteStorageT& Prefill() const { return prefill_u8_; }
const ByteStorageT& Decode() const { return decode_u8_; } const ByteStorageT& Decode() const { return decode_u8_; }
// layers_output is optional; if set - it will be called with the activations
// output after applying each layer.
void Generate(const RuntimeConfig& runtime_config, void Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, TimingInfo& timing_info, KVCache& kv_cache, TimingInfo& timing_info);
LayersOutputT* layers_output = nullptr);
private: private:
hwy::ThreadPool& pool_; hwy::ThreadPool& pool_;
@ -141,14 +138,19 @@ class Gemma {
Type weight_type_; Type weight_type_;
}; };
// Adds BOS token and possibly 'turn' annotations, which depend on `training`
// and `pos`, the number of tokens decoded so far; returns the corresponding
// tokens. Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
ModelTraining training, size_t pos,
std::string& prompt);
// DEPRECATED, call Gemma::Generate directly. // DEPRECATED, call Gemma::Generate directly.
HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config, HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& /*pool*/, KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
TimingInfo& timing_info, TimingInfo& timing_info) {
LayersOutputT* layers_output) { gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info);
gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info,
layers_output);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -17,89 +17,35 @@
#include <stdio.h> #include <stdio.h>
#include <memory>
#include <random>
#include <string> #include <string>
#include <vector> #include <vector>
// Placeholder for internal header, do not modify. #include "gemma/benchmark_helper.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/cross_entropy.h" #include "hwy/tests/hwy_gtest.h"
#include "gemma/ops.h"
#include "util/app.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/test_util-inl.h"
namespace gcpp { namespace gcpp {
namespace { namespace {
int s_argc = 0; // Shared state. Requires argc/argv, so construct in main and use the same raw
char** s_argv = nullptr; // pointer approach as in benchmarks.cc. Note that the style guide forbids
// non-local static variables with dtors.
GemmaEnv* s_env = nullptr;
class GemmaTest : public ::testing::Test { class GemmaTest : public ::testing::Test {
protected: protected:
static void SetUpTestSuite() { std::string GemmaReply(const std::string& prompt) {
gcpp::AppArgs app(s_argc, s_argv); s_env->SetMaxGeneratedTokens(2048);
gcpp::LoaderArgs loader(s_argc, s_argv); s_env->MutableConfig().temperature = 0.0f; // deterministic
if (const char* err = loader.Validate()) { s_env->MutableConfig().verbosity = 0;
fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n"); // Using the turn structure worsens results.
} else { const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
fprintf(stderr, "Loading model..\n"); auto [response, n] = s_env->QueryModel(tokens);
s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads); return response;
s_gemma = AllocateGemma(loader, *s_pool);
s_kv_cache = KVCache::Create(loader.ModelType());
s_model = loader.ModelType();
}
}
static void TearDownTestSuite() {
s_pool.reset();
s_gemma.reset();
}
std::string GemmaReply(const std::string& prompt_string) {
std::mt19937 gen;
gen.seed(42);
std::vector<int> prompt;
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), BOS_ID);
std::vector<int> response;
auto stream_token = [&response](int token, float) {
response.push_back(token);
return true;
};
gcpp::RuntimeConfig runtime_config = {
.max_tokens = 3072,
.max_generated_tokens = 2048,
.temperature = 1.0,
.verbosity = 0,
.gen = &gen,
.stream_token = stream_token,
};
gcpp::TimingInfo timing_info;
s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
timing_info, /*layers_output=*/nullptr);
std::string response_text;
HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text));
return response_text;
}
float GemmaCrossEntropy(const std::string& prompt_string) {
std::vector<int> prompt;
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt,
s_kv_cache,
/*verbosity=*/0) /
prompt_string.size();
} }
void TestQuestions(const char* kQA[][2], size_t num_questions) { void TestQuestions(const char* kQA[][2], size_t num_questions) {
if (!s_gemma) return; if (!s_env->GetModel()) return;
for (size_t i = 0; i < num_questions; ++i) { for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Question %zu\n\n", i + 1); fprintf(stderr, "Question %zu\n\n", i + 1);
std::string response = GemmaReply(kQA[i][0]); std::string response = GemmaReply(kQA[i][0]);
@ -107,18 +53,8 @@ class GemmaTest : public ::testing::Test {
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
} }
} }
static std::unique_ptr<hwy::ThreadPool> s_pool;
static std::unique_ptr<gcpp::Gemma> s_gemma;
static gcpp::KVCache s_kv_cache;
static gcpp::Model s_model;
}; };
/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_gemma;
/*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
/*static*/ gcpp::Model GemmaTest::s_model;
TEST_F(GemmaTest, Geography) { TEST_F(GemmaTest, Geography) {
static const char* kQA[][2] = { static const char* kQA[][2] = {
{"What is the capital of Hungary?", "Budapest"}, {"What is the capital of Hungary?", "Budapest"},
@ -130,7 +66,7 @@ TEST_F(GemmaTest, Geography) {
TEST_F(GemmaTest, History) { TEST_F(GemmaTest, History) {
static const char* kQA[][2] = { static const char* kQA[][2] = {
{"When was the Battle of Hastings?", "1066"}, {"When was the battle of Hastings?", "1066"},
}; };
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum); TestQuestions(kQA, kNum);
@ -181,42 +117,39 @@ static const char kGettysburg[] = {
"people, for the people, shall not perish from the earth.\n"}; "people, for the people, shall not perish from the earth.\n"};
TEST_F(GemmaTest, CrossEntropySmall) { TEST_F(GemmaTest, CrossEntropySmall) {
if (!s_gemma) return; if (!s_env->GetModel()) return;
static const char kSmall[] = static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe."; "The capital of Hungary is Budapest which is located in Europe.";
float entropy = GemmaCrossEntropy(kSmall); float entropy = s_env->CrossEntropy(kSmall);
fprintf(stderr, "per-byte entropy: %f\n", entropy); fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f); EXPECT_LT(entropy,
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
} }
TEST_F(GemmaTest, CrossEntropyJingleBells) { TEST_F(GemmaTest, CrossEntropyJingleBells) {
if (!s_gemma) return; if (!s_env->GetModel()) return;
float entropy = GemmaCrossEntropy(kJingleBells); float entropy = s_env->CrossEntropy(kJingleBells);
fprintf(stderr, "per-byte entropy: %f\n", entropy); fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f); EXPECT_LT(entropy,
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
} }
TEST_F(GemmaTest, CrossEntropyGettysburg) { TEST_F(GemmaTest, CrossEntropyGettysburg) {
if (!s_gemma) return; if (!s_env->GetModel()) return;
float entropy = GemmaCrossEntropy(kGettysburg); float entropy = s_env->CrossEntropy(kGettysburg);
fprintf(stderr, "per-byte entropy: %f\n", entropy); fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f); EXPECT_LT(entropy,
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
} }
} // namespace } // namespace
} // namespace gcpp } // namespace gcpp
int main(int argc, char** argv) { int main(int argc, char** argv) {
{ gcpp::GemmaEnv env(argc, argv);
// Placeholder for internal init, do not modify. gcpp::s_env = &env;
}
// For later use by SetUp. testing::InitGoogleTest(&argc, argv);
gcpp::s_argc = argc;
gcpp::s_argv = argv;
// Probably should be called before SetUpTestSuite.
testing::InitGoogleTest(&gcpp::s_argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }

View File

@ -1668,12 +1668,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<int, k> indices{}; std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) { for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && if (probabilities[i] < top_k[k - 1] &&
(!accept_token || accept_token(StaticCast<int>(i)))) { (!accept_token || accept_token(StaticCast<int>(i), probabilities[i]))) {
continue; continue;
} }
for (size_t j = 0; j < k; ++j) { for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && if (probabilities[i] > top_k[j] &&
(!accept_token || accept_token(StaticCast<int>(i)))) { (!accept_token ||
accept_token(StaticCast<int>(i), probabilities[i]))) {
// shift elements by 1, insert the new value, move on to next value // shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) { for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1]; top_k[idx] = top_k[idx - 1];

View File

@ -15,27 +15,22 @@
// Command line text interface to gemma. // Command line text interface to gemma.
#include <ctime>
#include <iostream> #include <iostream>
#include <random> #include <random>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <thread> // NOLINT
#include <vector> #include <vector>
// Placeholder for internal header, do not modify. // Placeholder for internal header, do not modify.
#include "compression/compress.h" #include "gemma/benchmark_helper.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h" // Gemma #include "gemma/gemma.h" // Gemma
#include "util/app.h" #include "util/app.h"
#include "util/args.h" // HasHelp #include "util/args.h" // HasHelp
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/timer.h"
#if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE #if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE
#error "Please update to version 1.2 of github.com/google/highway." #error "Please update to version 1.2 of github.com/google/highway."
@ -57,56 +52,6 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|___/ |_| |_| |___/ |_| |_|
)""; )"";
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
if (app.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
std::cout << "Date & Time : " << dt
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
<< "\n"
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n"
<< "Instruction set : "
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
<< hwy::VectorBytes() * 8 << " bits)" << "\n";
char cpu100[100];
if (hwy::platform::GetCpuString(cpu100)) {
std::cout << "CPU : " << cpu100 << "\n";
}
std::cout << "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : "
<< gcpp::StringFromType(loader.WeightType()) << "\n"
<< "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
}
}
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
std::cerr
<< kAsciiArtBanner
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run gemma.cpp, you need to "
"specify 3 required model loading arguments:\n"
" --tokenizer\n"
" --weights\n"
" --model.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights 2b-it-sfp.sbs --model 2b-it\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n*Application Arguments*\n\n";
app.Help();
std::cerr << "\n";
}
// The main Read-Eval-Print Loop. // The main Read-Eval-Print Loop.
void ReplGemma(gcpp::Gemma& model, ModelTraining training, void ReplGemma(gcpp::Gemma& model, ModelTraining training,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
@ -118,12 +63,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
int prompt_size{}; int prompt_size{};
std::mt19937 gen; std::mt19937 gen;
if (args.deterministic) { InitGenerator(args, gen);
gen.seed(42);
} else {
std::random_device rd;
gen.seed(rd());
}
// callback function invoked for each generated token. // callback function invoked for each generated token.
auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size, auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
@ -162,7 +102,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
while (abs_pos < args.max_tokens) { while (abs_pos < args.max_tokens) {
std::string prompt_string; std::string prompt_string;
std::vector<int> prompt;
current_pos = 0; current_pos = 0;
{ {
PROFILER_ZONE("Gen.input"); PROFILER_ZONE("Gen.input");
@ -192,30 +131,11 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
continue; continue;
} }
if (training == ModelTraining::GEMMA_IT) { const std::vector<int> prompt =
// For instruction-tuned models: add control tokens. WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
prompt_string = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n";
if (abs_pos != 0) {
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
// continuation.
prompt_string = "<end_of_turn>\n" + prompt_string;
}
}
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
if (abs_pos == 0) {
prompt.insert(prompt.begin(), gcpp::BOS_ID);
}
prompt_size = prompt.size(); prompt_size = prompt.size();
std::cerr << "\n" std::cerr << "\n"
<< "[ Reading prompt ] " << std::flush; << "[ Reading prompt ] " << std::flush;
if constexpr (kVerboseLogTokens) { if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) { for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
@ -301,17 +221,20 @@ int main(int argc, char** argv) {
gcpp::AppArgs app(argc, argv); gcpp::AppArgs app(argc, argv);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
ShowHelp(loader, inference, app); std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
return 0; return 0;
} }
if (const char* error = loader.Validate()) { if (const char* error = loader.Validate()) {
ShowHelp(loader, inference, app); std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error); HWY_ABORT("\nInvalid args: %s", error);
} }
if (const char* error = inference.Validate()) { if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app); std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error); HWY_ABORT("\nInvalid args: %s", error);
} }

View File

@ -13,19 +13,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Command line text interface to gemma. #include <stdio.h>
#include <fstream> #include <algorithm>
#include <iostream>
#include <random>
#include <set>
#include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
// Placeholder for internal header, do not modify. #include "compression/io.h" // Path
#include "gemma/benchmark_helper.h"
#include "gemma/gemma.h" // Gemma #include "gemma/gemma.h" // Gemma
#include "util/app.h" #include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -34,164 +31,134 @@
namespace gcpp { namespace gcpp {
void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, struct JsonArgs : public ArgsBase<JsonArgs> {
hwy::ThreadPool& pool, JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
const InferenceArgs& args, int verbosity,
std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
// token index within the current turn
int max_tokens = 4096;
std::mt19937 gen; Path input;
if (args.deterministic) {
gen.seed(42); // Returns error string or nullptr if OK.
} else { const char* Validate() const {
std::random_device rd; if (input.Empty()) return "Must specify --input";
gen.seed(rd()); if (!input.Exists()) return "--input file does not exist";
return nullptr;
} }
float answers = 0.0; template <class Visitor>
float correct_answers = 0.0; void ForEach(const Visitor& visitor) {
visitor(input, "input", Path(), "Full pathname of mmlu.json.");
};
};
std::ifstream fJson("/tmp/mmlu.json"); // Linear search for a few tokens is faster than std::set.
std::stringstream buffer; // TODO: instead of accepting for each vocab entry, filter the logits once.
buffer << fJson.rdbuf(); class TokenSet {
auto json = nlohmann::json::parse(buffer.str()); public:
TokenSet(const GemmaTokenizer& tokenizer,
std::vector<std::string> accept_tokens = {"A", "B", "C", "D"}; const std::vector<std::string>& strings) {
std::set<int> accept_token_set{}; all_tokens_.reserve(strings.size());
for (const std::string& accept_token : accept_tokens) { for (const std::string& str : strings) {
std::vector<int> accept_token_ids; std::vector<int> tokens;
HWY_ASSERT(model.Tokenizer().Encode(accept_token, &accept_token_ids)); fprintf(stderr, "%s -> ", str.c_str());
accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end()); HWY_ASSERT(tokenizer.Encode(str, &tokens));
for (int token : tokens) {
fprintf(stderr, "%d, ", token);
all_tokens_.push_back(token);
}
fprintf(stderr, "\n");
}
} }
for (auto sample : json["samples"]) { bool Contains(int token) const {
int abs_pos = 0; // absolute token index over all turns return std::find(all_tokens_.begin(), all_tokens_.end(), token) !=
int current_pos = 0; all_tokens_.end();
int prompt_size{}; }
// cout << "prompt:" << sample["prompt"] << endl; private:
const std::string& prompt_string = sample["prompt"]; std::vector<int> all_tokens_;
std::vector<int> prompt; };
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt)); void Run(GemmaEnv& env, JsonArgs& json) {
prompt_size = prompt.size(); PROFILER_ZONE("Run.all");
const std::string& correct_answer = accept_tokens[sample["input_label"]]; float answers = 0.0f;
float correct_answers = 0.0f;
// max_tokens = prompt_size + max_tokens; auto json_data = nlohmann::json::parse(ReadFileToString(json.input));
const std::vector<std::string> accept_strings = {
"A", "B", "C", "D", //
" A", " B", " C", " D", //
"**", "**:", ":**", "The", "Answer", "is", ":", "."};
const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
for (auto sample : json_data["samples"]) {
const int id = sample["i"];
fprintf(stderr, "Processing question %d\n", id);
const std::string& correct_answer = accept_strings[sample["input_label"]];
std::string prompt_string = sample["prompt"];
// AcceptFunc restricts the output to one of these four tokens, so make an
// effort to steer the model towards that. See
// https://huggingface.co/blog/open-llm-leaderboard-mmlu
prompt_string +=
"What is start of the line with the correct answer? "
"Do not include any justifications or explanations. Reply only with a "
"letter.";
const std::vector<int> prompt =
WrapAndTokenize(env.GetModel()->Tokenizer(), env.ModelTrainingType(),
/*pos=*/0, prompt_string);
const size_t prompt_size = prompt.size();
std::vector<int> predicted_token_ids; std::vector<int> predicted_token_ids;
predicted_token_ids.reserve(max_tokens); predicted_token_ids.reserve(4096);
auto stream_token = [&current_pos, &prompt_size, &predicted_token_ids, size_t current_pos = 0;
&accept_token_set](int token, float proba) { const StreamFunc stream_token = [&current_pos, prompt_size,
&predicted_token_ids](int token,
float proba) {
PROFILER_ZONE("Stream");
++current_pos; ++current_pos;
if (current_pos > prompt_size) { if (current_pos > prompt_size) {
predicted_token_ids.push_back(token); predicted_token_ids.push_back(token);
// If the generated token is in the accepted token set, return False.
// This will stop further generation.
return accept_token_set.find(token) == accept_token_set.end();
} }
return true; return true;
}; };
const AcceptFunc accept_token = [&current_pos, &prompt_size, // Although " A" is a token, it is difficult to associate that with the
&accept_token_set](int token) { // correct answer. Only accepting certain tokens is risky: (A) is easily
// i.e. we have no constraints on accepted tokens // confused with the word "A".
if (accept_token_set.empty()) {
return true;
}
if (current_pos >= prompt_size) {
return accept_token_set.find(token) != accept_token_set.end();
} else {
// auto-accept early tokens
return true;
}
};
gcpp::TimingInfo timing_info; gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = { gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens, .max_tokens = env.MaxTokens(),
.max_generated_tokens = args.max_generated_tokens, .max_generated_tokens = 30,
.temperature = args.temperature, .temperature = 0.0f,
.verbosity = verbosity, .verbosity = env.Verbosity(),
.gen = &gen, .gen = &env.MutableGen(),
.stream_token = stream_token, .stream_token = stream_token,
.accept_token = accept_token,
}; };
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info); env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
env.MutableKVCache(), timing_info);
std::string output_string; std::string output_string = env.StringFromTokens(predicted_token_ids);
HWY_ASSERT(model.Tokenizer().Decode(predicted_token_ids, &output_string)); fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
std::cout << "QuestionId: " << sample["i"] << "; " output_string.c_str());
<< "Predicted Answer: " << output_string << "; "
<< "Correct Answer: " << correct_answer << std::endl;
answers += 1.0; answers += 1.0f;
if (output_string == correct_answer) { if (output_string == correct_answer) {
correct_answers += 1.0; correct_answers += 1.0f;
} }
std::cout << "Running accuracy = " << "[" fprintf(stderr, "%.0f/%.0f = %.2f%%\n", correct_answers, answers,
<< static_cast<int>(correct_answers) << "/" correct_answers / answers);
<< static_cast<int>(answers) << "]" << " = "
<< correct_answers / answers << std::endl;
} }
} }
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc");
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning workers to cores helps.
if (app.num_threads > 10) {
PinWorkersToCores(pool);
}
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);
}
} // namespace gcpp } // namespace gcpp
int main(int argc, char** argv) { int main(int argc, char** argv) {
{ {
PROFILER_ZONE("Startup.misc"); PROFILER_ZONE("Startup.all");
gcpp::GemmaEnv env(argc, argv);
// Placeholder for internal init, do not modify. gcpp::JsonArgs json(argc, argv);
gcpp::AbortIfInvalidArgs(json);
gcpp::LoaderArgs loader(argc, argv); gcpp::Run(env, json);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (const char* error = loader.Validate()) {
fprintf(stderr,
"\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
"specify 3 required model loading arguments: --tokenizer, "
"--compressed_weights, "
"and --model.\n\nModel Loading Arguments\n\n");
loader.Help();
fprintf(stderr, "\nInference Arguments\n\n");
inference.Help();
fprintf(stderr, "\nApplication Arguments\n\n");
app.Help();
fprintf(stderr, "\n\n");
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(loader, inference, app);
} }
PROFILER_PRINT_RESULTS(); // Must call outside the zone above. PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0; return 0;

View File

@ -197,6 +197,14 @@ static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) {
return false; return false;
} }
template <class TArgs>
static inline HWY_MAYBE_UNUSED void AbortIfInvalidArgs(TArgs& args) {
if (const char* err = args.Validate()) {
args.Help();
HWY_ABORT("Problem with args: %s\n", err);
}
}
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_