Compare commits

...

5 Commits

Author SHA1 Message Date
The gemma.cpp Authors 2228055bb8 Internal change.
PiperOrigin-RevId: 643330703
2024-06-14 06:53:41 -07:00
Jan Wassenberg 29c0c574e6 Integrate matmul into FFW: 4.3x prefill speedup
```
before, bf16:
27.2929 prefill tokens / sec
17.2114 tokens / sec

after, bf16
116.496 prefill tokens / sec
17.5391 tokens / sec
```

PiperOrigin-RevId: 643328437
2024-06-14 06:32:26 -07:00
Ray Smith 198326a682 Removed now redundant non-batch matmul
PiperOrigin-RevId: 643317187
2024-06-14 05:13:36 -07:00
Andrey Vlasov b17631c95f Implement a missing (bf16, f32) tiled MatMul kernel.
PiperOrigin-RevId: 643313676
2024-06-14 04:54:40 -07:00
Jan Wassenberg d3c6a45b59 Major duplicated code reduction in test/benchmarks
Helper functions to tokenize/wrap
Move LayersOutputFunc into RuntimeConfig
AcceptFunc passes the probability
Implement StringFromType using the parser, and verify results match

PiperOrigin-RevId: 643255119
2024-06-14 00:16:25 -07:00
19 changed files with 878 additions and 1029 deletions

View File

@ -113,12 +113,8 @@ cc_library(
cc_library(
name = "cross_entropy",
srcs = [
"gemma/cross_entropy.cc",
],
hdrs = [
"gemma/cross_entropy.h",
],
srcs = ["gemma/cross_entropy.cc"],
hdrs = ["gemma/cross_entropy.h"],
deps = [
":common",
":gemma_lib",
@ -149,6 +145,25 @@ cc_library(
],
)
cc_library(
name = "benchmark_helper",
srcs = ["gemma/benchmark_helper.cc"],
hdrs = ["gemma/benchmark_helper.h"],
deps = [
":app",
":args",
":common",
":cross_entropy",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
cc_test(
name = "gemma_test",
srcs = ["gemma/gemma_test.cc"],
@ -161,11 +176,11 @@ cc_test(
deps = [
":app",
":args",
":benchmark_helper",
":common",
":cross_entropy",
":gemma_lib",
":ops",
# Placeholder for internal dep, do not remove.,
"@googletest//:gtest_main",
"//compression:io",
"@hwy//:hwy_test_util",
@ -179,6 +194,7 @@ cc_binary(
deps = [
":app",
":args",
":benchmark_helper",
":common",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
@ -213,10 +229,10 @@ cc_binary(
deps = [
":app",
":args",
":benchmark_helper",
":common",
":cross_entropy",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
"//compression:io",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
@ -230,7 +246,6 @@ cc_binary(
srcs = ["gemma/benchmarks.cc"],
deps = [
":benchmark_helper",
# Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark",
],
)
@ -243,8 +258,8 @@ cc_binary(
deps = [
":app",
":args",
":benchmark_helper",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
"//compression:io",
"@hwy//:hwy",
"@hwy//:thread_pool",
@ -257,8 +272,10 @@ cc_binary(
srcs = ["gemma/run_mmlu.cc"],
deps = [
":app",
":args",
":benchmark_helper",
":gemma_lib",
# Placeholder for internal dep, do not remove.,
"//compression:io",
"@hwy//:hwy",
"@hwy//:profiler",
"@hwy//:thread_pool",
@ -318,25 +335,6 @@ cc_library(
],
)
cc_library(
name = "benchmark_helper",
srcs = [
"gemma/benchmark_helper.cc",
],
hdrs = [
"gemma/benchmark_helper.h",
],
deps = [
":app",
":common",
":gemma_lib",
"@benchmark//:benchmark",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
cc_test(
name = "backward_scalar_test",
size = "large",

View File

@ -23,6 +23,8 @@
#include <string>
#include <utility> // std::move
#include "hwy/base.h"
namespace gcpp {
// Forward-declare to break the circular dependency: OpenFileOrNull returns
@ -77,12 +79,30 @@ struct Path {
return path;
}
bool Empty() const { return path.empty(); }
// Returns whether the file existed when this was called.
bool Exists() const { return !!OpenFileOrNull(*this, "r"); }
std::string path;
};
static inline HWY_MAYBE_UNUSED std::string ReadFileToString(const Path& path) {
std::unique_ptr<File> file = OpenFileOrNull(path, "r");
if (!file) {
HWY_ABORT("Failed to open %s", path.path.c_str());
}
const size_t size = file->FileSize();
if (size == 0) {
HWY_ABORT("Empty file %s", path.path.c_str());
}
std::string content(size, ' ');
if (!file->Read(0, size, content.data())) {
HWY_ABORT("Failed to read %s", path.path.c_str());
}
return content;
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_

View File

@ -1,146 +1,83 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <utility>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/io.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "gemma/benchmark_helper.h"
#include "gemma/gemma.h" // LayersOutputFunc
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "nlohmann/json.hpp"
using json = nlohmann::json;
class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
namespace gcpp {
class PromptArgs : public ArgsBase<PromptArgs> {
public:
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
gcpp::Path layers_output;
Path layers_output; // optional
std::string prompt;
// Returns error string or nullptr if OK.
const char* Validate() const {
if (prompt.empty()) return "Must specify --prompt";
return nullptr;
}
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(layers_output.path, "layers_output", std::string(""),
visitor(layers_output, "layers_output", Path(""),
"Path to store layers output", 2);
visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
}
};
std::pair<std::string, int> QueryModel(
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
gcpp::LayersOutputT* layers_output) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), gcpp::BOS_ID);
std::string res;
size_t total_tokens = 0;
std::mt19937 gen;
gen.seed(42);
auto stream_token = [&res, &total_tokens, &model](int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
return true;
};
if (app.verbosity >= 2) {
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
<< args.temperature;
}
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
.verbosity = app.verbosity,
.gen = &gen,
.stream_token = stream_token,
};
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
layers_output);
return {res, total_tokens};
}
class OutputJsonLogger {
public:
json json_output;
gcpp::LayersOutputT layers_output_log_f =
[this](int pos, const std::string& key, const float* values,
size_t values_len) {
std::vector<float> v{values, values + values_len};
json_output[std::to_string(pos)][key] = v;
};
};
/* Run this in the same way as gemma, p.ex.:
./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
2b-it-sfp.sbs --prompt "..." --layers_output [path]
*/
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs args(argc, argv); // inference
gcpp::AppArgs app(argc, argv);
int Run(int argc, char** argv) {
PromptArgs prompt_args(argc, argv);
AbortIfInvalidArgs(prompt_args);
if (const char* error = loader.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = args.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
const bool log_layers_output = !prompt_args.layers_output.path.empty();
OutputJsonLogger json_logger;
gcpp::LayersOutputT* layers_output =
log_layers_output ? &json_logger.layers_output_log_f : nullptr;
json json_output;
GemmaEnv env(argc, argv);
env.MutableConfig().layers_output =
prompt_args.layers_output.Empty()
? LayersOutputFunc()
: [&json_output](int pos, const std::string& key, const float* values,
size_t values_len) {
std::vector<float> v{values, values + values_len};
json_output[std::to_string(pos)][key] = v;
};
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning workers to cores helps.
if (app.num_threads > 10) {
gcpp::PinWorkersToCores(pool);
}
const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
const std::string& prompt = prompt_args.prompt;
if (prompt.empty()) {
std::cout << "Please specify --prompt" << std::endl;
return EXIT_FAILURE;
}
const auto [answer, token_count] = QueryModel(
model, args, app, kv_cache, pool, prompt, layers_output);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
if (log_layers_output) {
if (env.MutableConfig().layers_output) {
std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
if (!output_f) {
std::cout << "Opening file failed" << std::endl;
return EXIT_FAILURE;
}
output_f << json_logger.json_output.dump();
if (!output_f) {
std::cout << "Writing to file failed" << std::endl;
return EXIT_FAILURE;
}
if (!output_f) HWY_ABORT("Opening layer output file failed");
output_f << json_output.dump();
if (!output_f) HWY_ABORT("Writing to layer output file failed");
output_f.close();
}
return EXIT_SUCCESS;
return 0;
}
} // namespace gcpp
int main(int argc, char** argv) { return gcpp::Run(argc, argv); }

View File

@ -1,36 +1,36 @@
#include <stdio.h>
#include <algorithm>
#include <cstdlib> // EXIT_FAILURE
#include <fstream>
#include <iostream>
#include <ostream>
#include <random>
#include <sstream>
#include <string>
#include <utility> // std::pair
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/io.h" // Path
#include "gemma/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/cross_entropy.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/timer.h"
#include "nlohmann/json.hpp"
namespace gcpp {
using json = nlohmann::json;
class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> {
class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
public:
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
gcpp::Path goldens;
gcpp::Path summarize_text;
gcpp::Path cross_entropy;
gcpp::Path trivia_qa;
Path goldens;
Path summarize_text;
Path cross_entropy;
Path trivia_qa;
size_t max_questions;
size_t batch_tokens;
@ -53,61 +53,6 @@ class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> {
}
};
void LogSpeedStats(const double time_start, size_t total_tokens) {
const double time_end = hwy::platform::Now();
const double time_elapsed = time_end - time_start;
const double tok_sec = total_tokens / time_elapsed;
std::cout << total_tokens << " tokens in " << time_elapsed << " seconds"
<< " [" << tok_sec << " tokens / sec" << "]\n";
}
std::pair<std::string, int> QueryModel(
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), 2);
std::string res;
size_t total_tokens = 0;
std::mt19937 gen;
gen.seed(42);
const double time_start = hwy::platform::Now();
auto stream_token = [&res, &total_tokens, &time_start, &app, &model](
int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
if (app.verbosity >= 1 && total_tokens % 100 == 0) {
LogSpeedStats(time_start, total_tokens);
}
return true;
};
if (app.verbosity >= 2) {
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
<< args.temperature;
}
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
.verbosity = app.verbosity,
.gen = &gen,
.stream_token = stream_token,
};
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
/*layers_output=*/nullptr);
if (app.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens);
}
return {res, total_tokens};
}
std::vector<std::pair<std::string, std::string>> load_goldens(
const std::string& path) {
std::ifstream goldens_file(path);
@ -129,28 +74,14 @@ std::vector<std::pair<std::string, std::string>> load_goldens(
return res;
}
std::string ReadFile(const gcpp::Path& path) {
std::ifstream text_file(path.path);
if (!text_file) {
std::cout << "Could not open file: " << path.path << "\n" << std::flush;
return {};
}
std::stringstream buffer;
buffer << text_file.rdbuf();
return buffer.str();
}
int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, const std::string& golden_path) {
const std::vector<std::pair<std::string, std::string>> queries_answers =
int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
std::vector<std::pair<std::string, std::string>> queries_answers =
load_goldens(golden_path);
int correct_answers = 0;
int total_tokens = 0;
size_t correct_answers = 0;
size_t total_tokens = 0;
const double time_start = hwy::platform::Now();
for (const auto& [question, expected_answer] : queries_answers) {
const auto [answer, token_count] =
QueryModel(model, args, app, kv_cache, pool, question);
for (auto& [question, expected_answer] : queries_answers) {
const auto [answer, token_count] = env.QueryModel(question);
total_tokens += token_count;
if (answer.find(expected_answer) != std::string::npos) {
correct_answers++;
@ -172,28 +103,22 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
return EXIT_SUCCESS;
}
int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, const gcpp::Path& text) {
int BenchmarkSummary(GemmaEnv& env, const Path& text) {
std::string prompt("Here is some text to summarize:\n");
prompt.append(ReadFile(text));
prompt.append(ReadFileToString(text));
prompt.append("\nSummarize this text.\n");
const double time_start = hwy::platform::Now();
const auto [answer, token_count] =
QueryModel(model, args, app, kv_cache, pool, prompt);
const auto [answer, token_count] = env.QueryModel(prompt);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
LogSpeedStats(time_start, token_count);
return EXIT_SUCCESS;
}
int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
gcpp::InferenceArgs& args, gcpp::AppArgs& app,
hwy::ThreadPool& pool, const gcpp::Path& text,
int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t batch_tokens) {
std::string input = ReadFile(text);
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
prompt.resize(std::min<size_t>(args.max_tokens, prompt.size()));
std::string input = ReadFileToString(text);
std::vector<int> prompt = env.Tokenize(input);
prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
std::cout << "Number of input tokens: " << prompt.size() << "\n";
const double time_start = hwy::platform::Now();
float total_entropy = 0.0f;
@ -203,13 +128,12 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice,
kv_cache, app.verbosity);
KVCache kv_cache = KVCache::Create(env.ModelType());
float entropy = ComputeCrossEntropy(
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice;
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
std::string text_slice = env.StringFromTokens(prompt_slice);
total_input_len += text_slice.size();
printf("Total cross entropy: %f [cumulative: %f]\n",
entropy, total_entropy);
@ -219,23 +143,19 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
return EXIT_SUCCESS;
}
int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, const gcpp::Path& json_file,
int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
size_t max_questions) {
std::ifstream trivia_file(json_file.path);
if (!trivia_file) {
std::cout << "Could not load file: " << json_file.path << "\n"
<< std::flush;
return EXIT_FAILURE;
HWY_ABORT("Could not load file %s\n", json_file.path.c_str());
}
std::string line;
size_t correct_answers = 0;
size_t i = 0;
while (std::getline(trivia_file, line)) {
json data = json::parse(line);
const auto [answer, token_count] = QueryModel(
model, args, app, kv_cache, pool, data["question"]);
std::string q(data["question"]);
const auto [answer, token_count] = env.QueryModel(q);
std::cout << answer << "\n";
bool correct = false;
for (const std::string expected : data["answer"]["aliases"]) {
@ -256,52 +176,25 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
return EXIT_SUCCESS;
}
/* Run this in the same way as gemma, p.ex.:
./benchmark --tokenizer tokenizer.spm --model 2b-it --weights \
2b-it-sfp.sbs --goldens_dir "../goldens"
*/
} // namespace gcpp
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::GemmaEnv env(argc, argv);
gcpp::BenchmarkArgs benchmark_args(argc, argv);
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs args(argc, argv); // inference
gcpp::AppArgs app(argc, argv);
BenchmarkArgs benchmark_args(argc, argv);
if (const char* error = loader.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = args.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning workers to cores helps.
if (app.num_threads > 10) {
gcpp::PinWorkersToCores(pool);
}
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
if (!benchmark_args.goldens.path.empty()) {
if (!benchmark_args.goldens.Empty()) {
const std::string golden_path =
benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
} else if (!benchmark_args.summarize_text.path.empty()) {
return BenchmarkSummary(model, args, app, kv_cache, pool,
benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.path.empty()) {
return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
pool, benchmark_args.cross_entropy,
benchmark_args.goldens.path + "/" +
gcpp::ModelString(env.ModelType(), env.ModelTrainingType()) + ".txt";
return BenchmarkGoldens(env, golden_path);
} else if (!benchmark_args.summarize_text.Empty()) {
return BenchmarkSummary(env, benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.Empty()) {
return BenchmarkCrossEntropy(env, benchmark_args.cross_entropy,
benchmark_args.batch_tokens);
} else if (!benchmark_args.trivia_qa.path.empty()) {
return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
benchmark_args.trivia_qa,
} else if (!benchmark_args.trivia_qa.Empty()) {
return BenchmarkTriviaQA(env, benchmark_args.trivia_qa,
benchmark_args.max_questions);
}
std::cout << "No benchmark command given." << "\n" << std::flush;
return EXIT_FAILURE;
HWY_ABORT("No benchmark command given.");
}

View File

@ -14,70 +14,95 @@
// limitations under the License.
#include "gemma/benchmark_helper.h"
#include <cstdlib> // EXIT_FAILURE
#include <stdio.h>
#include <time.h>
#include <iostream>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <thread> // NOLINT
#include <utility> // std::pair
#include <vector>
#include "gemma/common.h"
// Placeholder for internal header, do not modify.
#include "compression/compress.h" // TypeName
#include "gemma/common.h" // StringFromType
#include "gemma/cross_entropy.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/timer.h"
namespace gcpp {
GemmaEnv::GemmaEnv(int argc, char** argv)
: loader_(argc, argv), inference_args_(argc, argv), app_(argc, argv),
pool_(app_.num_threads) {
if (const char* error = loader_.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = inference_args_.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
// For many-core, pinning workers to cores helps.
if (app_.num_threads > 10) {
gcpp::PinWorkersToCores(pool_);
}
void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
if (inference.deterministic) {
// Nothing up my sleeve number, at least some upper bits set.
gen.seed(0x12345678);
} else {
// Depending on the library implementation, this may still be deterministic.
std::random_device rd;
gen.seed(rd());
}
}
GemmaEnv::GemmaEnv(int argc, char** argv)
: loader_(argc, argv),
inference_args_(argc, argv),
app_(argc, argv),
pool_(app_.num_threads) {
{
// Placeholder for internal init, do not modify.
}
// For many-core, pinning workers to cores helps.
if (app_.num_threads > 10) {
gcpp::PinWorkersToCores(pool_);
}
AbortIfInvalidArgs(inference_args_);
if (const char* err = loader_.Validate()) {
loader_.Help();
fprintf(stderr, "Skipping model load because: %s\n", err);
} else {
fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(loader_, pool_);
kv_cache_ = KVCache::Create(loader_.ModelType());
gen_.seed(42);
}
std::pair<std::string, int> GemmaEnv::QueryModel(const std::string& input) {
std::string prompt_string = input;
if (loader_.ModelTrainingType() == ModelTraining::GEMMA_IT) {
// For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + input +
"<end_of_turn>\n<start_of_turn>model\n";
}
std::vector<int> prompt;
HWY_ASSERT(model_->Tokenizer().Encode(input, &prompt));
InitGenerator(inference_args_, gen_);
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), gcpp::BOS_ID);
runtime_config_ = {
.max_tokens = inference_args_.max_tokens,
.max_generated_tokens = inference_args_.max_generated_tokens,
.temperature = inference_args_.temperature,
.verbosity = app_.verbosity,
.gen = &gen_,
};
}
std::pair<std::string, size_t> GemmaEnv::QueryModel(
const std::vector<int>& tokens) {
std::string res;
size_t total_tokens = 0;
auto accept_token = [](int) { return true; };
std::mt19937 gen;
gen.seed(42);
const double time_start = hwy::platform::Now();
auto stream_token = [&res, &total_tokens, &time_start, this](
int token, float) {
const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(model_->Tokenizer().Decode(std::vector<int>{token},
&token_text));
HWY_ASSERT(
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
if (app_.verbosity >= 1 && total_tokens % 100 == 0) {
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
LogSpeedStats(time_start, total_tokens);
}
return true;
@ -88,24 +113,32 @@ std::pair<std::string, int> GemmaEnv::QueryModel(const std::string& input) {
<< inference_args_.temperature;
}
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = inference_args_.max_tokens,
.max_generated_tokens = inference_args_.max_generated_tokens,
.temperature = inference_args_.temperature,
.verbosity = app_.verbosity,
.gen = &gen,
.stream_token = stream_token,
.accept_token = accept_token,
};
model_->Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache_,
timing_info, /*layers_output=*/nullptr);
runtime_config_.stream_token = stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_cache_,
timing_info);
if (app_.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens);
}
return {res, total_tokens};
}
void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const {
std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt =
WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(),
/*pos=*/0, input);
return QueryModel(prompt);
}
float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
MutableKVCache(),
/*verbosity=*/0) /
static_cast<int>(input.size());
}
void LogSpeedStats(double time_start, size_t total_tokens) {
const double time_end = hwy::platform::Now();
const double time_elapsed = time_end - time_start;
const double tok_sec = total_tokens / time_elapsed;
@ -113,6 +146,53 @@ void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const {
<< " [" << tok_sec << " tokens / sec" << "]\n";
}
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
if (app.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
std::cout << "Date & Time : " << dt
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
<< "\n"
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n"
<< "Instruction set : "
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
<< hwy::VectorBytes() * 8 << " bits)" << "\n";
char cpu100[100];
if (hwy::platform::GetCpuString(cpu100)) {
std::cout << "CPU : " << cpu100 << "\n";
}
std::cout << "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : "
<< gcpp::StringFromType(loader.WeightType()) << "\n"
<< "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
}
}
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run gemma.cpp, you need to "
"specify 3 required model loading arguments:\n"
" --tokenizer\n"
" --weights\n"
" --model.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights 2b-it-sfp.sbs --model 2b-it\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n*Application Arguments*\n\n";
app.Help();
std::cerr << "\n";
}
} // namespace gcpp

View File

@ -16,11 +16,15 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
#include <stddef.h>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "hwy/base.h"
@ -28,24 +32,60 @@
namespace gcpp {
void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
// Convenience class to load a model and run inference.
class GemmaEnv {
public:
GemmaEnv(int argc, char** argv);
size_t MaxTokens() const { return inference_args_.max_tokens; }
// Sets the maximum number of output tokens to generate.
void set_max_generated_tokens(int max_tokens) {
void SetMaxGeneratedTokens(size_t max_tokens) {
inference_args_.max_generated_tokens = max_tokens;
}
std::vector<int> Tokenize(const std::string& input) const {
std::vector<int> tokens;
HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens));
return tokens;
}
std::vector<int> TokenizeAndPrependBOS(const std::string& input) const {
std::vector<int> tokens = Tokenize(input);
tokens.insert(tokens.begin(), BOS_ID);
return tokens;
}
std::string StringFromTokens(const std::vector<int>& tokens) const {
std::string string;
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
return string;
}
// Runs inference on the given input and returns the top-1 result string and
// the number of tokens that were generated.
std::pair<std::string, int> QueryModel(const std::string& input);
std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
// Adds turn structure to input, tokenizes and calls the above overload.
std::pair<std::string, size_t> QueryModel(std::string& input);
// Runs inference on the given input and returns the cross entropy, a measure
// of how well the model predicts the correct output. It is the average
// number of bits per token.
float CrossEntropy(const std::string& input);
// Returns nullptr if the model failed to load.
Gemma* GetModel() const { return model_.get(); }
Model ModelType() const { return loader_.ModelType(); }
ModelTraining ModelTrainingType() const {
return loader_.ModelTrainingType();
}
int Verbosity() const { return app_.verbosity; }
gcpp::RuntimeConfig& MutableConfig() { return runtime_config_; }
std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return kv_cache_; }
private:
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens) const;
// Arguments to the model loader: file locations, etc.
LoaderArgs loader_;
// Arguments to the inference function: max tokens, etc.
@ -60,10 +100,16 @@ class GemmaEnv {
std::unique_ptr<Gemma> model_;
// The KV cache to use for inference.
KVCache kv_cache_;
gcpp::RuntimeConfig runtime_config_;
};
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_

View File

@ -13,108 +13,81 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <ostream>
#include <random>
#include <sstream>
#include <stddef.h>
#include <stdio.h>
#include <string>
// Placeholder for internal header, do not modify.
#include "benchmark/benchmark.h"
#include "gemma/benchmark_helper.h"
void run_gemma_prompt(const std::string& prompt_string,
gcpp::GemmaEnv& env,
benchmark::State& state) {
std::mt19937 gen;
namespace gcpp {
if (prompt_string.empty()) return;
// Shared state for benchmarks - unfortunately the library does not allow
// passing context nor closures. Raw pointer because style guide forbids
// non-local static objects with dtors.t
GemmaEnv* s_env = nullptr;
int token_counter = 0;
void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
size_t total_tokens = 0;
for (auto s : state) {
auto [response, n] = env.QueryModel(prompt_string);
std::cout << "response: " << response << "\n";
std::cout << "n: " << n << "\n";
token_counter += n;
std::string prompt = original_prompt; // reset from original
auto [response, n] = s_env->QueryModel(prompt);
if (s_env->Verbosity() != 0) {
fprintf(stdout, "|%s|\n", response.c_str());
}
total_tokens += n;
}
state.SetItemsProcessed(token_counter);
state.SetItemsProcessed(total_tokens);
}
// Awkward global because benchmarks don't support additional state, so it is
// either this or cast to int64_t.
gcpp::GemmaEnv* global_env = nullptr;
} // namespace gcpp
static void BM_short_prompt(benchmark::State& state) {
run_gemma_prompt("What is the capital of Spain?", *global_env,
state);
gcpp::RunPrompt("What is the capital of Spain?", state);
}
static void BM_factuality_prompt(benchmark::State& state) {
run_gemma_prompt("How does an inkjet printer work?",
*global_env, state);
gcpp::RunPrompt("How does an inkjet printer work?", state);
}
static void BM_creative_prompt(benchmark::State& state) {
run_gemma_prompt(
"Tell me a story about a magical bunny and their TRS-80.",
*global_env, state);
gcpp::RunPrompt("Tell me a story about a magical bunny and their TRS-80.",
state);
}
static void BM_coding_prompt(benchmark::State& state) {
run_gemma_prompt(
"Write a python program to generate a fibonacci sequence.",
*global_env, state);
gcpp::RunPrompt("Write a python program to generate a fibonacci sequence.",
state);
}
static void BM_long_coding_prompt(benchmark::State& state) {
std::ifstream t("benchmarks.cc", std::ios_base::in);
std::stringstream buffer;
buffer << t.rdbuf();
std::string prompt_string = buffer.str();
t.close();
BENCHMARK(BM_short_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
run_gemma_prompt("Make improvements to the following code:\n " +
prompt_string, *global_env, state);
}
BENCHMARK(BM_factuality_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_creative_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::GemmaEnv env(argc, argv);
env.SetMaxGeneratedTokens(256);
gcpp::s_env = &env;
env.set_max_generated_tokens(128);
global_env = &env;
BENCHMARK(BM_short_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
env.set_max_generated_tokens(256);
BENCHMARK(BM_factuality_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_creative_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
env.set_max_generated_tokens(1024);
BENCHMARK(BM_long_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
::benchmark ::RunSpecifiedBenchmarks();
::benchmark ::Shutdown();
::benchmark::RunSpecifiedBenchmarks();
::benchmark::Shutdown();
return 0;
}

View File

@ -28,20 +28,18 @@
namespace gcpp {
namespace {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
"7b-it", "gr2b-it", "tiny"};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
ModelTraining::GEMMA_IT};
} // namespace
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
"7b-it", "gr2b-it", "tiny"};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
ModelTraining::GEMMA_IT};
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
static char kErrorMessageBuffer[kNum * 8 + 1024] =
"Invalid or missing model flag, need to specify one of ";
@ -51,37 +49,58 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
}
strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, "."); // NOLINT
std::string model_type_lc = model_flag;
std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
[](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) {
if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i];
training = kModelTraining[i];
HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc);
return nullptr;
}
}
return kErrorMessageBuffer;
}
const char* ModelString(Model model, ModelTraining training) {
if (model == Model::GEMMA_TINY) return "tiny";
static_assert(static_cast<size_t>(ModelTraining::GEMMA_IT) == 0);
constexpr const char* k2B[] = {"2b-it", "2b-pt"};
constexpr const char* k7B[] = {"7b-it", "7b-pt"};
constexpr const char* kGr2B[] = {"gr2b-it", "gr2b-pt"};
if (model == Model::GEMMA_2B) return k2B[static_cast<size_t>(training)];
if (model == Model::GEMMA_7B) return k7B[static_cast<size_t>(training)];
if (model == Model::GRIFFIN_2B) return kGr2B[static_cast<size_t>(training)];
HWY_ABORT("Unknown model %d training %d\n", static_cast<int>(model),
static_cast<int>(training));
}
constexpr const char* kTypeStrings[] = {"f32", "bf16", "sfp"};
const char* StringFromType(Type type) {
return kTypeStrings[static_cast<size_t>(type)];
}
const char* ParseType(const std::string& type_string, Type& type) {
constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP};
constexpr const char* kStrings[] = {"f32", "bf16", "sfp"};
constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings);
constexpr size_t kNum = std::end(kTypeStrings) - std::begin(kTypeStrings);
static char kErrorMessageBuffer[kNum * 8 + 100] =
"Invalid or missing type, need to specify one of ";
for (size_t i = 0; i + 1 < kNum; i++) {
strcat(kErrorMessageBuffer, kStrings[i]); // NOLINT
strcat(kErrorMessageBuffer, kTypeStrings[i]); // NOLINT
strcat(kErrorMessageBuffer, ", "); // NOLINT
}
strcat(kErrorMessageBuffer, kStrings[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, kTypeStrings[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, "."); // NOLINT
std::string type_lc = type_string;
std::transform(begin(type_lc), end(type_lc), begin(type_lc),
[](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) {
if (kStrings[i] == type_lc) {
type = kTypes[i];
if (kTypeStrings[i] == type_lc) {
type = static_cast<Type>(i);
HWY_ASSERT(std::string(StringFromType(type)) == type_lc);
return nullptr;
}
}

View File

@ -104,30 +104,30 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
// calls FUNC<ConfigT<TWEIGHT>> where ConfigT is chosen via MODEL enum.
#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS) \
switch (MODEL) { \
case Model::GEMMA_TINY: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GEMMA_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GEMMA_7B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GRIFFIN_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS) \
switch (MODEL) { \
case Model::GEMMA_TINY: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GEMMA_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GEMMA_7B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GRIFFIN_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
}
// Like CallForModelAndWeight, but for SIMD function templates. This is a macro
@ -154,18 +154,12 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
const char* ParseType(const std::string& type_string, Type& type);
static inline const char* StringFromType(Type type) {
switch (type) {
case Type::kF32:
return "f32";
case Type::kBF16:
return "bf16";
case Type::kSFP:
return "sfp";
default:
return "?";
}
}
// Inverse of ParseModelTypeAndTraining.
const char* ModelString(Model model, ModelTraining training);
const char* StringFromType(Type type);
// ----------------------------------------------------------------------------
//
// __builtin_sqrt is not constexpr as of Clang 17.
#if HWY_COMPILER_GCC_ACTUAL

View File

@ -136,7 +136,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 16;
static constexpr int kVocabSize = 64;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
FixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
@ -146,8 +146,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 64;
static constexpr int kFFHiddenDim = 128;
static constexpr int kModelDim = 128;
static constexpr int kFFHiddenDim = 256;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size

View File

@ -111,7 +111,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr);
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
const float scale = 1.0f / std::log(2.0f);
return cross_entropy * scale;

View File

@ -17,7 +17,6 @@
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#include <cstdio>
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
@ -89,6 +88,11 @@ struct Activations {
att_post2; // accumulation of attention outputs over heads
std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
// For FFW MatMul.
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1;
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C2;
// bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
// std::array<hwy::bfloat16_t, kBatchSize * 2 * TConfig::kFFHiddenDim>
// bf_ffw_hidden;
@ -208,8 +212,8 @@ bool GemmaTokenizer::Encode(const std::string& input,
}
bool GemmaTokenizer::Encode(const std::string& input,
std::vector<int>* pieces) const {
return impl_->Encode(input, pieces);
std::vector<int>* ids) const {
return impl_->Encode(input, ids);
}
// Given a sequence of ids, decodes it into a detokenized output.
@ -509,41 +513,70 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
float* HWY_RESTRICT even_odd = activations.even_odd.data();
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
// TODO: MatMul does not yet support adding another matrix to the result.
if constexpr (!TConfig::kFFBiases) {
PROFILER_ZONE("Gen.FFW.GatedGELU");
const hwy::bfloat16_t* HWY_RESTRICT vec =
activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
// Same matrix, first and second half of rows. Could fuse into one MatVec.
MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
TConfig::kFFBiases ?
layer_weights->ffw_gating_biases.data() + kFFHiddenDim : nullptr,
even_odd, out_mul, pool);
// Gate, will go through the nonlinearity.
MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, 0, vec,
layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
// MatMul expects col-major B, which is what we have: kModelDim consecutive
// elements in memory, repeated kFFHiddenDim times.
const auto b1 = layer_weights->gating_einsum_w.data();
constexpr size_t kColsA = kModelDim;
constexpr size_t kColsB = kFFHiddenDim;
const auto b2 = b1 + kColsA * kColsB;
auto A = activations.bf_pre_ffw_rms_out.data();
// Will go through GELU.
MatMul_4x4_Batch<kColsA, kColsB>(num_tokens, A, b1, activations.C1.data(),
pool);
// What to multiply by.
MatMul_4x4_Batch<kColsA, kColsB>(num_tokens, A, b2, activations.C2.data(),
pool);
// Gelu and multiply by gate.
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
[](DF df, VF v, VF mul)
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
}
hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
return hn::Mul(mul, Gelu(df, v));
});
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
PROFILER_ZONE("Gen.FFW\\GatedGELU");
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0,
activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(), even_odd,
activations.ffw_out.data() + batch_idx * kModelDim, pool);
MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
layer_weights->linear_w.data(),
activations.ffw_out.data(), pool);
} else {
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
const hwy::bfloat16_t* HWY_RESTRICT vec =
activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
PROFILER_ZONE("Gen.FFW.GatedGELU");
// Same matrix, first and second half of rows. Could fuse into one MatVec.
MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
TConfig::kFFBiases
? layer_weights->ffw_gating_biases.data() + kFFHiddenDim
: nullptr,
even_odd, out_mul, pool);
// Gate, will go through the nonlinearity.
MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, 0, vec,
layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
[](DF df, VF v, VF mul)
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0,
activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(), even_odd,
activations.ffw_out.data() + batch_idx * kModelDim, pool);
}
}
}
@ -649,16 +682,16 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
// Compute the transformer for a batch of input tokens. During generation,
// we usually have num_tokens == 1 (and also kBatchSize == 1).
template <size_t kBatchSize, typename WeightArrayT, class TConfig>
HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos,
const WeightArrayT& weights,
Activations<TConfig, kBatchSize>& activations,
KVCache& kv_cache, hwy::ThreadPool& pool,
LayersOutputT* layers_output) {
const LayersOutputFunc& layers_output) {
HWY_ASSERT(num_tokens <= kBatchSize);
if (layers_output != nullptr) {
if (layers_output) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
float token_f = tokens[token_idx];
(*layers_output)(pos + token_idx, "Tokens", &token_f, 1);
layers_output(pos + token_idx, "Tokens", &token_f, 1);
}
}
static constexpr size_t kModelDim = TConfig::kModelDim;
@ -713,12 +746,11 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
}
AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
activations.x.data(), kModelDim);
if (layers_output != nullptr) {
if (layers_output) {
std::string block_name = "blocks." + std::to_string(layer);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
(*layers_output)(pos + token_idx, block_name,
activations.x.data() + token_idx * kModelDim,
kModelDim);
layers_output(pos + token_idx, block_name,
activations.x.data() + token_idx * kModelDim, kModelDim);
}
}
}
@ -727,10 +759,10 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
activations.x.data(), kModelDim);
if (layers_output != nullptr) {
if (layers_output) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
(*layers_output)(pos + token_idx, "final_norm",
activations.x.data() + token_idx * kModelDim, kModelDim);
layers_output(pos + token_idx, "final_norm",
activations.x.data() + token_idx * kModelDim, kModelDim);
}
}
}
@ -782,8 +814,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) {
hwy::ThreadPool& pool, TimingInfo& timing_info) {
const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8);
auto& prefill_activations =
GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
@ -860,7 +891,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
pos < max_tokens && generate_pos < max_generated_tokens;
++pos, ++pos_offset, ++generate_pos) {
Transformer<kDecodeBatchSize>(&token, kDecodeBatchSize, pos, weights,
activations, kv_cache, pool, layers_output);
activations, kv_cache, pool,
runtime_config.layers_output);
float token_logit = 0.0f;
// The condition below is always true if we are doing Prefill above.
// We keep it here for clarity so that the code is correct even if Prefill
@ -953,17 +985,37 @@ Gemma::~Gemma() {
void Gemma::Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, TimingInfo& timing_info,
LayersOutputT* layers_output) {
KVCache& kv_cache, TimingInfo& timing_info) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
GEMMA_EXPORT_AND_DISPATCH(
model_type_, weight_type_, GenerateT,
(weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
kv_cache, pool_, timing_info, layers_output));
kv_cache, pool_, timing_info));
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelTraining training, size_t pos,
std::string& prompt) {
// Instruction-tuned models are trained to expect control tokens.
if (training == ModelTraining::GEMMA_IT) {
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
const std::string start = (pos == 0)
? "<start_of_turn>user\n"
: "<end_of_turn>\n<start_of_turn>user\n";
prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
}
std::vector<int> tokens;
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
// Both pre-trained and instruction-tuned require BOS as first token.
if (pos == 0) {
tokens.insert(tokens.begin(), gcpp::BOS_ID);
}
return tokens;
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -74,10 +74,17 @@ class GemmaTokenizer {
using StreamFunc = std::function<bool(int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>;
using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called with the probability distribution for the
// next token, and its return value is used as the next generated token.
using SampleFunc = std::function<int(const float*, size_t)>;
// Will be called for layers output with:
// - position in the tokens sequence
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
// - pointer to the data array
// - size of the data array
using LayersOutputFunc =
std::function<void(int, const std::string&, const float*, size_t)>;
struct RuntimeConfig {
size_t max_tokens;
@ -88,6 +95,7 @@ struct RuntimeConfig {
StreamFunc stream_token;
AcceptFunc accept_token; // if empty, accepts all tokens.
SampleFunc sample_func; // if empty, uses SampleTopK.
LayersOutputFunc layers_output; // if not empty, called after each layer.
int eos_id = EOS_ID;
};
@ -97,14 +105,6 @@ struct TimingInfo {
double time_to_first_token = 0.0;
};
// Will be called for layers output with:
// - position in the tokens sequence
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
// - pointer to the data array
// - size of the data array
using LayersOutputT =
std::function<void(int, const std::string&, const float*, size_t)>;
class Gemma {
public:
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@ -121,12 +121,9 @@ class Gemma {
const ByteStorageT& Prefill() const { return prefill_u8_; }
const ByteStorageT& Decode() const { return decode_u8_; }
// layers_output is optional; if set - it will be called with the activations
// output after applying each layer.
void Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr);
KVCache& kv_cache, TimingInfo& timing_info);
private:
hwy::ThreadPool& pool_;
@ -141,14 +138,19 @@ class Gemma {
Type weight_type_;
};
// Adds BOS token and possibly 'turn' annotations, which depend on `training`
// and `pos`, the number of tokens decoded so far; returns the corresponding
// tokens. Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
ModelTraining training, size_t pos,
std::string& prompt);
// DEPRECATED, call Gemma::Generate directly.
HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
TimingInfo& timing_info,
LayersOutputT* layers_output) {
gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info,
layers_output);
TimingInfo& timing_info) {
gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info);
}
} // namespace gcpp

View File

@ -17,89 +17,35 @@
#include <stdio.h>
#include <memory>
#include <random>
#include <string>
#include <vector>
// Placeholder for internal header, do not modify.
#include "gemma/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/cross_entropy.h"
#include "gemma/ops.h"
#include "util/app.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/test_util-inl.h"
#include "hwy/tests/hwy_gtest.h"
namespace gcpp {
namespace {
int s_argc = 0;
char** s_argv = nullptr;
// Shared state. Requires argc/argv, so construct in main and use the same raw
// pointer approach as in benchmarks.cc. Note that the style guide forbids
// non-local static variables with dtors.
GemmaEnv* s_env = nullptr;
class GemmaTest : public ::testing::Test {
protected:
static void SetUpTestSuite() {
gcpp::AppArgs app(s_argc, s_argv);
gcpp::LoaderArgs loader(s_argc, s_argv);
if (const char* err = loader.Validate()) {
fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
} else {
fprintf(stderr, "Loading model..\n");
s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
s_gemma = AllocateGemma(loader, *s_pool);
s_kv_cache = KVCache::Create(loader.ModelType());
s_model = loader.ModelType();
}
}
static void TearDownTestSuite() {
s_pool.reset();
s_gemma.reset();
}
std::string GemmaReply(const std::string& prompt_string) {
std::mt19937 gen;
gen.seed(42);
std::vector<int> prompt;
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), BOS_ID);
std::vector<int> response;
auto stream_token = [&response](int token, float) {
response.push_back(token);
return true;
};
gcpp::RuntimeConfig runtime_config = {
.max_tokens = 3072,
.max_generated_tokens = 2048,
.temperature = 1.0,
.verbosity = 0,
.gen = &gen,
.stream_token = stream_token,
};
gcpp::TimingInfo timing_info;
s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
timing_info, /*layers_output=*/nullptr);
std::string response_text;
HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text));
return response_text;
}
float GemmaCrossEntropy(const std::string& prompt_string) {
std::vector<int> prompt;
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt,
s_kv_cache,
/*verbosity=*/0) /
prompt_string.size();
std::string GemmaReply(const std::string& prompt) {
s_env->SetMaxGeneratedTokens(2048);
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 0;
// Using the turn structure worsens results.
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
auto [response, n] = s_env->QueryModel(tokens);
return response;
}
void TestQuestions(const char* kQA[][2], size_t num_questions) {
if (!s_gemma) return;
if (!s_env->GetModel()) return;
for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Question %zu\n\n", i + 1);
std::string response = GemmaReply(kQA[i][0]);
@ -107,18 +53,8 @@ class GemmaTest : public ::testing::Test {
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
}
static std::unique_ptr<hwy::ThreadPool> s_pool;
static std::unique_ptr<gcpp::Gemma> s_gemma;
static gcpp::KVCache s_kv_cache;
static gcpp::Model s_model;
};
/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_gemma;
/*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
/*static*/ gcpp::Model GemmaTest::s_model;
TEST_F(GemmaTest, Geography) {
static const char* kQA[][2] = {
{"What is the capital of Hungary?", "Budapest"},
@ -130,7 +66,7 @@ TEST_F(GemmaTest, Geography) {
TEST_F(GemmaTest, History) {
static const char* kQA[][2] = {
{"When was the Battle of Hastings?", "1066"},
{"When was the battle of Hastings?", "1066"},
};
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum);
@ -181,42 +117,39 @@ static const char kGettysburg[] = {
"people, for the people, shall not perish from the earth.\n"};
TEST_F(GemmaTest, CrossEntropySmall) {
if (!s_gemma) return;
if (!s_env->GetModel()) return;
static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe.";
float entropy = GemmaCrossEntropy(kSmall);
float entropy = s_env->CrossEntropy(kSmall);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
EXPECT_LT(entropy,
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
}
TEST_F(GemmaTest, CrossEntropyJingleBells) {
if (!s_gemma) return;
float entropy = GemmaCrossEntropy(kJingleBells);
if (!s_env->GetModel()) return;
float entropy = s_env->CrossEntropy(kJingleBells);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
EXPECT_LT(entropy,
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
}
TEST_F(GemmaTest, CrossEntropyGettysburg) {
if (!s_gemma) return;
float entropy = GemmaCrossEntropy(kGettysburg);
if (!s_env->GetModel()) return;
float entropy = s_env->CrossEntropy(kGettysburg);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
EXPECT_LT(entropy,
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
}
} // namespace
} // namespace gcpp
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env;
// For later use by SetUp.
gcpp::s_argc = argc;
gcpp::s_argv = argv;
// Probably should be called before SetUpTestSuite.
testing::InitGoogleTest(&gcpp::s_argc, argv);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -23,6 +23,7 @@
#include <stdio.h>
#include <array>
#include <cmath>
#include <random>
#include <type_traits> // std::enable_if_t
@ -70,6 +71,29 @@ StaticCast(From from) noexcept {
return static_cast<To>(from);
}
// For testing.
template <typename MatT>
void AssertClose(const MatT* HWY_RESTRICT expected,
const MatT* HWY_RESTRICT actual, size_t num) {
for (size_t idx = 0; idx < num; idx++) {
const double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
const double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
const double magnitude = std::abs(expected_value);
const double tolerance =
256.0 * hwy::ConvertScalarTo<double>(hwy::Epsilon<MatT>()) *
HWY_MAX(magnitude, 1.0);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
expected_value, idx, actual_value);
HWY_ASSERT(0);
}
}
}
template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one
@ -362,11 +386,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
c23, c30, c31, c32, c33, tile_c, stride_c);
}
// Same as above, but with mixed Mat types: (f32, sfp)).
// Same as above, but with mixed Mat types: (f32, compressed).
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
HWY_IF_F32(MatTA)>
HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const SfpStream* HWY_RESTRICT B,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C, const size_t idx_tile,
const size_t xtiles, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
@ -406,7 +430,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
CompressTraits<SfpStream>::Decompress(
CompressTraits<MatTB>::Decompress(
d,
/*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
kRegRows * kColsA_RowsB);
@ -455,11 +479,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
c23, c30, c31, c32, c33, tile_c, stride_c);
}
// Same as above, but with mixed Mat types: (bf16, sfp)).
// Same as above, but with mixed Mat types: (bf16, compressed)).
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
HWY_IF_BF16(MatTA)>
HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const SfpStream* HWY_RESTRICT B,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C, const size_t idx_tile,
const size_t xtiles, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
@ -504,7 +528,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
CompressTraits<SfpStream>::Decompress(
CompressTraits<MatTB>::Decompress(
d32,
/*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
kRegRows * kColsA_RowsB);
@ -656,70 +680,103 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
c22, c23, c30, c31, c32, c33, tile_c, stride_c);
}
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
// This function loops over all tiles (static scheduling). TODO(janwas): we can
// possibly remove this if ThreadPool(0) is as efficient as the loop.
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatT>
void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
MatT* HWY_RESTRICT C) {
const hn::ScalableTag<MatT> d;
const size_t N = hn::Lanes(d); // column step size
// Same as above, but with mixed Mat types: (bf16, f32).
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB)>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C, const size_t idx_tile,
const size_t xtiles, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4; // in vectors
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
static_assert(kRowsAC % kRegRows == 0);
static_assert(kColsBC % kRegCols == 0);
HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
constexpr size_t kTilesY = kRowsAC / kRegRows;
constexpr size_t kTilesX = kColsBC / kRegCols;
constexpr size_t kTiles = kTilesX * kTilesY;
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
constexpr size_t kStrideA = kColsA_RowsB;
constexpr size_t kStrideB = kColsA_RowsB; // B is column-major
constexpr size_t kStrideC = kColsBC;
const hn::ScalableTag<float> d32;
using VF = hn::Vec<decltype(d32)>;
// TODO: Using half-vectors for now, it might be faster to
// PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
// accordingly.
const hn::Rebind<MatTA, decltype(d32)> d16;
HWY_DASSERT(Lanes(d16) == Lanes(d32));
const size_t N = Lanes(d16);
VF c00 = hn::Zero(d32);
VF c01 = hn::Zero(d32);
VF c02 = hn::Zero(d32);
VF c03 = hn::Zero(d32);
VF c10 = hn::Zero(d32);
VF c11 = hn::Zero(d32);
VF c12 = hn::Zero(d32);
VF c13 = hn::Zero(d32);
VF c20 = hn::Zero(d32);
VF c21 = hn::Zero(d32);
VF c22 = hn::Zero(d32);
VF c23 = hn::Zero(d32);
VF c30 = hn::Zero(d32);
VF c31 = hn::Zero(d32);
VF c32 = hn::Zero(d32);
VF c33 = hn::Zero(d32);
const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
// Loop over columns of A and columns of the transposed B, in steps of N.
// Accumulates into the c## vectors.
HWY_UNROLL(1)
for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
// Promote bf16 to f32
const VF b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab);
const VF b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab);
const VF b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab);
const VF b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab);
const VF a0 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab));
c00 = hn::MulAdd(a0, b0, c00);
c01 = hn::MulAdd(a0, b1, c01);
c02 = hn::MulAdd(a0, b2, c02);
c03 = hn::MulAdd(a0, b3, c03);
if (kNumRows == 1) continue;
const VF a1 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
c10 = hn::MulAdd(a1, b0, c10);
c11 = hn::MulAdd(a1, b1, c11);
c12 = hn::MulAdd(a1, b2, c12);
c13 = hn::MulAdd(a1, b3, c13);
if (kNumRows == 2) continue;
const VF a2 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
c20 = hn::MulAdd(a2, b0, c20);
c21 = hn::MulAdd(a2, b1, c21);
c22 = hn::MulAdd(a2, b2, c22);
c23 = hn::MulAdd(a2, b3, c23);
if (kNumRows == 3) continue;
const VF a3 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
c30 = hn::MulAdd(a3, b0, c30);
c31 = hn::MulAdd(a3, b1, c31);
c32 = hn::MulAdd(a3, b2, c32);
c33 = hn::MulAdd(a3, b3, c33);
}
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
StoreHorizontalSums<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
stride_c);
}
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
// This function processes tiles in parallel with a work-stealing thread pool.
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
typename MatTB, typename OutT>
HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
const MatTB* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
hwy::ThreadPool& pool) {
// Process reg-sized tiles of C in parallel. We currently write C directly,
// which touches more memory than fits in L3. TODO: add another level of loops
// so that we finish one L3-sized piece of C at a time.
const hn::ScalableTag<MatTA> d;
const size_t N = Lanes(d);
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4; // in vectors
static_assert(kRowsAC % kRegRows == 0);
static_assert(kColsBC % kRegCols == 0);
HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
const size_t kTilesY = kRowsAC / kRegRows;
const size_t kTilesX = kColsBC / kRegCols;
const size_t kTiles = kTilesX * kTilesY;
constexpr size_t kStrideA = kColsA_RowsB;
constexpr size_t kStrideB = kColsA_RowsB;
constexpr size_t kStrideC = kColsBC;
pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
// Computes the finished product of one 4x4N tile and writes to C.
GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
});
}
// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
// NOTE that batch_size is the number of rows of A and C.
@ -802,7 +859,8 @@ HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kN, size_t kK, typename MatTA, typename MatTB>
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE_GT(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b,
float* HWY_RESTRICT out) {
@ -817,15 +875,18 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
}
}
template <size_t kN, size_t kK, typename MatTA>
// The above overload can handle combinations of f32 and bf16, but this one
// is required for MatTB = {SFP, NUQ}.
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const SfpStream* HWY_RESTRICT b_sfp_stream,
const MatTB* HWY_RESTRICT b_compr,
float* HWY_RESTRICT out) {
const hn::ScalableTag<float> d;
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
CompressTraits<SfpStream>::Decompress(d,
/*in_capacity=*/0, b_sfp_stream, 0,
b.get(), kK * kN);
CompressTraits<MatTB>::Decompress(d,
/*in_capacity=*/0, b_compr, 0, b.get(),
kK * kN);
MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), out);
}
@ -1668,12 +1729,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] &&
(!accept_token || accept_token(StaticCast<int>(i)))) {
(!accept_token || accept_token(StaticCast<int>(i), probabilities[i]))) {
continue;
}
for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] &&
(!accept_token || accept_token(StaticCast<int>(i)))) {
(!accept_token ||
accept_token(StaticCast<int>(i), probabilities[i]))) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];

View File

@ -506,77 +506,10 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
}
}
template <typename MatT>
void AssertClose(const MatT* HWY_RESTRICT expected,
const MatT* HWY_RESTRICT actual, size_t num) {
for (size_t idx = 0; idx < num; idx++) {
const double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
const double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
const double magnitude = std::abs(expected_value);
const double tolerance =
64.0 * hwy::ConvertScalarTo<double>(hwy::Epsilon<MatT>()) *
HWY_MAX(magnitude, 1.0);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
expected_value, idx, actual_value);
HWY_ASSERT(0);
}
}
}
template <size_t kM, size_t kN, size_t kK, typename MatTA,
typename MatTB = MatTA>
void TestTiledMatMul() {
hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
GenerateMatHeap<MatTA, kM, kN>(0, pool);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
GenerateMatHeap<MatTB, kN, kK>(0, pool);
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
GenerateZeroMatHeap<float, kM, kK>(pool);
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow_batch =
GenerateZeroMatHeap<float, kM, kK>(pool);
MatMulSlow<kM, kN, kK>(a->data(), b->data(), c_slow->data());
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow_batch->data());
AssertClose(c_slow->data(), c_slow_batch->data(), kM * kK);
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
MatMul_4x4<kM, kN, kK>(a->data(), b_trans->data(), c.get(), pool);
AssertClose(c_slow->data(), c.get(), kM * kK);
}
void TestAllTiledMatMul() {
// medium-sized square test
TestTiledMatMul<512, 512, 512, float>();
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
TestTiledMatMul<512, 512, 512, float, SfpStream>();
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
// minimal non-square test
TestTiledMatMul<4, 128, 4, float>();
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
TestTiledMatMul<32, 128, 32, float, SfpStream>();
TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
// large-scale test
// TODO(philculliton): investigate rounding issues with large matrices.
// Causes test timeout.
// TestTiledMatMul<512, 24576, 3072, float>();
}
template <size_t kM, size_t kN, size_t kK, typename MatTA,
typename MatTB = MatTA>
void TestTiledBatchMatMul() {
fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu", kM, kN, kK);
hwy::ThreadPool pool(3);
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
GenerateMatHeap<MatTA, kM, kN>(0, pool);
@ -600,35 +533,46 @@ void TestAllTiledBatchMatMul() {
TestTiledBatchMatMul<512, 512, 512, float>();
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t>();
TestTiledBatchMatMul<512, 512, 512, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<512, 512, 512, float, SfpStream>();
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
// minimal non-square test
TestTiledBatchMatMul<35, 128, 4, float>();
TestTiledBatchMatMul<34, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<33, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<35, 128, 32, float>();
TestTiledBatchMatMul<34, 128, 32, hwy::bfloat16_t>();
TestTiledBatchMatMul<33, 128, 32, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<33, 128, 32, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<31, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<29, 128, 32, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<4, 128, 4, float>();
TestTiledBatchMatMul<4, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<4, 128, 32, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<3, 128, 4, float>();
TestTiledBatchMatMul<3, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<3, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 8, float>();
TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 8, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<4, 128, 8, float, SfpStream>();
TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<3, 128, 32, float>();
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t>();
TestTiledBatchMatMul<3, 128, 32, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<3, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<2, 128, 4, float>();
TestTiledBatchMatMul<2, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<2, 128, 32, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<1, 128, 4, float>();
TestTiledBatchMatMul<1, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<1, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 16, float>();
TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 16, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<2, 128, 16, float, SfpStream>();
TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<1, 128, 32, float>();
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t>();
TestTiledBatchMatMul<1, 128, 32, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
// large-scale test
// TODO(philculliton): investigate rounding issues with large matrices.
// Causes test timeout.
// TestTiledBatchMatMul<512, 24576, 3072, float>();
}
void TestMatVecAdd() {
@ -730,7 +674,6 @@ HWY_AFTER_NAMESPACE();
namespace gcpp {
HWY_BEFORE_TEST(OpsTest);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
@ -738,7 +681,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);

View File

@ -15,27 +15,22 @@
// Command line text interface to gemma.
#include <ctime>
#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <thread> // NOLINT
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/compress.h"
#include "gemma/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
#include "util/args.h" // HasHelp
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
#if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE
#error "Please update to version 1.2 of github.com/google/highway."
@ -57,56 +52,6 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|___/ |_| |_|
)"";
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
if (app.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
std::cout << "Date & Time : " << dt
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
<< "\n"
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n"
<< "Instruction set : "
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
<< hwy::VectorBytes() * 8 << " bits)" << "\n";
char cpu100[100];
if (hwy::platform::GetCpuString(cpu100)) {
std::cout << "CPU : " << cpu100 << "\n";
}
std::cout << "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : "
<< gcpp::StringFromType(loader.WeightType()) << "\n"
<< "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
}
}
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
std::cerr
<< kAsciiArtBanner
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run gemma.cpp, you need to "
"specify 3 required model loading arguments:\n"
" --tokenizer\n"
" --weights\n"
" --model.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights 2b-it-sfp.sbs --model 2b-it\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n*Application Arguments*\n\n";
app.Help();
std::cerr << "\n";
}
// The main Read-Eval-Print Loop.
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
@ -118,12 +63,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
int prompt_size{};
std::mt19937 gen;
if (args.deterministic) {
gen.seed(42);
} else {
std::random_device rd;
gen.seed(rd());
}
InitGenerator(args, gen);
// callback function invoked for each generated token.
auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
@ -162,7 +102,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
while (abs_pos < args.max_tokens) {
std::string prompt_string;
std::vector<int> prompt;
current_pos = 0;
{
PROFILER_ZONE("Gen.input");
@ -192,30 +131,11 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
continue;
}
if (training == ModelTraining::GEMMA_IT) {
// For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n";
if (abs_pos != 0) {
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
// continuation.
prompt_string = "<end_of_turn>\n" + prompt_string;
}
}
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
if (abs_pos == 0) {
prompt.insert(prompt.begin(), gcpp::BOS_ID);
}
const std::vector<int> prompt =
WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
prompt_size = prompt.size();
std::cerr << "\n"
<< "[ Reading prompt ] " << std::flush;
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
@ -301,17 +221,20 @@ int main(int argc, char** argv) {
gcpp::AppArgs app(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
ShowHelp(loader, inference, app);
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
return 0;
}
if (const char* error = loader.Validate()) {
ShowHelp(loader, inference, app);
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}
if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app);
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}

View File

@ -13,19 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Command line text interface to gemma.
#include <stdio.h>
#include <fstream>
#include <iostream>
#include <random>
#include <set>
#include <sstream>
#include <algorithm>
#include <string>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/io.h" // Path
#include "gemma/benchmark_helper.h"
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
@ -34,164 +31,134 @@
namespace gcpp {
void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool,
const InferenceArgs& args, int verbosity,
std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
// token index within the current turn
int max_tokens = 4096;
struct JsonArgs : public ArgsBase<JsonArgs> {
JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
std::mt19937 gen;
if (args.deterministic) {
gen.seed(42);
} else {
std::random_device rd;
gen.seed(rd());
Path input;
// Returns error string or nullptr if OK.
const char* Validate() const {
if (input.Empty()) return "Must specify --input";
if (!input.Exists()) return "--input file does not exist";
return nullptr;
}
float answers = 0.0;
float correct_answers = 0.0;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(input, "input", Path(), "Full pathname of mmlu.json.");
};
};
std::ifstream fJson("/tmp/mmlu.json");
std::stringstream buffer;
buffer << fJson.rdbuf();
auto json = nlohmann::json::parse(buffer.str());
std::vector<std::string> accept_tokens = {"A", "B", "C", "D"};
std::set<int> accept_token_set{};
for (const std::string& accept_token : accept_tokens) {
std::vector<int> accept_token_ids;
HWY_ASSERT(model.Tokenizer().Encode(accept_token, &accept_token_ids));
accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
// Linear search for a few tokens is faster than std::set.
// TODO: instead of accepting for each vocab entry, filter the logits once.
class TokenSet {
public:
TokenSet(const GemmaTokenizer& tokenizer,
const std::vector<std::string>& strings) {
all_tokens_.reserve(strings.size());
for (const std::string& str : strings) {
std::vector<int> tokens;
fprintf(stderr, "%s -> ", str.c_str());
HWY_ASSERT(tokenizer.Encode(str, &tokens));
for (int token : tokens) {
fprintf(stderr, "%d, ", token);
all_tokens_.push_back(token);
}
fprintf(stderr, "\n");
}
}
for (auto sample : json["samples"]) {
int abs_pos = 0; // absolute token index over all turns
int current_pos = 0;
int prompt_size{};
bool Contains(int token) const {
return std::find(all_tokens_.begin(), all_tokens_.end(), token) !=
all_tokens_.end();
}
// cout << "prompt:" << sample["prompt"] << endl;
const std::string& prompt_string = sample["prompt"];
std::vector<int> prompt;
private:
std::vector<int> all_tokens_;
};
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
prompt_size = prompt.size();
void Run(GemmaEnv& env, JsonArgs& json) {
PROFILER_ZONE("Run.all");
const std::string& correct_answer = accept_tokens[sample["input_label"]];
float answers = 0.0f;
float correct_answers = 0.0f;
// max_tokens = prompt_size + max_tokens;
auto json_data = nlohmann::json::parse(ReadFileToString(json.input));
const std::vector<std::string> accept_strings = {
"A", "B", "C", "D", //
" A", " B", " C", " D", //
"**", "**:", ":**", "The", "Answer", "is", ":", "."};
const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
for (auto sample : json_data["samples"]) {
const int id = sample["i"];
fprintf(stderr, "Processing question %d\n", id);
const std::string& correct_answer = accept_strings[sample["input_label"]];
std::string prompt_string = sample["prompt"];
// AcceptFunc restricts the output to one of these four tokens, so make an
// effort to steer the model towards that. See
// https://huggingface.co/blog/open-llm-leaderboard-mmlu
prompt_string +=
"What is start of the line with the correct answer? "
"Do not include any justifications or explanations. Reply only with a "
"letter.";
const std::vector<int> prompt =
WrapAndTokenize(env.GetModel()->Tokenizer(), env.ModelTrainingType(),
/*pos=*/0, prompt_string);
const size_t prompt_size = prompt.size();
std::vector<int> predicted_token_ids;
predicted_token_ids.reserve(max_tokens);
auto stream_token = [&current_pos, &prompt_size, &predicted_token_ids,
&accept_token_set](int token, float proba) {
predicted_token_ids.reserve(4096);
size_t current_pos = 0;
const StreamFunc stream_token = [&current_pos, prompt_size,
&predicted_token_ids](int token,
float proba) {
PROFILER_ZONE("Stream");
++current_pos;
if (current_pos > prompt_size) {
predicted_token_ids.push_back(token);
// If the generated token is in the accepted token set, return False.
// This will stop further generation.
return accept_token_set.find(token) == accept_token_set.end();
}
return true;
};
const AcceptFunc accept_token = [&current_pos, &prompt_size,
&accept_token_set](int token) {
// i.e. we have no constraints on accepted tokens
if (accept_token_set.empty()) {
return true;
}
if (current_pos >= prompt_size) {
return accept_token_set.find(token) != accept_token_set.end();
} else {
// auto-accept early tokens
return true;
}
};
// Although " A" is a token, it is difficult to associate that with the
// correct answer. Only accepting certain tokens is risky: (A) is easily
// confused with the word "A".
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
.verbosity = verbosity,
.gen = &gen,
.max_tokens = env.MaxTokens(),
.max_generated_tokens = 30,
.temperature = 0.0f,
.verbosity = env.Verbosity(),
.gen = &env.MutableGen(),
.stream_token = stream_token,
.accept_token = accept_token,
};
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
env.MutableKVCache(), timing_info);
std::string output_string;
HWY_ASSERT(model.Tokenizer().Decode(predicted_token_ids, &output_string));
std::cout << "QuestionId: " << sample["i"] << "; "
<< "Predicted Answer: " << output_string << "; "
<< "Correct Answer: " << correct_answer << std::endl;
std::string output_string = env.StringFromTokens(predicted_token_ids);
fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
output_string.c_str());
answers += 1.0;
answers += 1.0f;
if (output_string == correct_answer) {
correct_answers += 1.0;
correct_answers += 1.0f;
}
std::cout << "Running accuracy = " << "["
<< static_cast<int>(correct_answers) << "/"
<< static_cast<int>(answers) << "]" << " = "
<< correct_answers / answers << std::endl;
fprintf(stderr, "%.0f/%.0f = %.2f%%\n", correct_answers, answers,
correct_answers / answers);
}
}
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc");
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning workers to cores helps.
if (app.num_threads > 10) {
PinWorkersToCores(pool);
}
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);
}
} // namespace gcpp
int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.misc");
// Placeholder for internal init, do not modify.
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (const char* error = loader.Validate()) {
fprintf(stderr,
"\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
"specify 3 required model loading arguments: --tokenizer, "
"--compressed_weights, "
"and --model.\n\nModel Loading Arguments\n\n");
loader.Help();
fprintf(stderr, "\nInference Arguments\n\n");
inference.Help();
fprintf(stderr, "\nApplication Arguments\n\n");
app.Help();
fprintf(stderr, "\n\n");
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(loader, inference, app);
PROFILER_ZONE("Startup.all");
gcpp::GemmaEnv env(argc, argv);
gcpp::JsonArgs json(argc, argv);
gcpp::AbortIfInvalidArgs(json);
gcpp::Run(env, json);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;

View File

@ -197,6 +197,14 @@ static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) {
return false;
}
template <class TArgs>
static inline HWY_MAYBE_UNUSED void AbortIfInvalidArgs(TArgs& args) {
if (const char* err = args.Validate()) {
args.Help();
HWY_ABORT("Problem with args: %s\n", err);
}
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_