mirror of https://github.com/google/gemma.cpp.git
Add a benchmark and additional tests.
Also add a script to help running sanitizer builds, and do some cleanup. Co-authored-by: Andrey Mikhaylov <amik@google.com> Co-authored-by: Eugene Kliuchnikov <eustas@google.com> Co-authored-by: Sami Boukortt <sboukortt@google.com> Co-authored-by: Zoltan Szabadka <szabadka@google.com>
This commit is contained in:
parent
325ef06cf9
commit
5862d1f995
|
|
@ -90,6 +90,7 @@ Checks: "-*,\
|
||||||
-concurrency-mt-unsafe,\
|
-concurrency-mt-unsafe,\
|
||||||
-cppcoreguidelines-avoid-c-arrays,\
|
-cppcoreguidelines-avoid-c-arrays,\
|
||||||
-cppcoreguidelines-avoid-const-or-ref-data-members,\
|
-cppcoreguidelines-avoid-const-or-ref-data-members,\
|
||||||
|
-cppcoreguidelines-avoid-do-while,\
|
||||||
-cppcoreguidelines-avoid-goto,\
|
-cppcoreguidelines-avoid-goto,\
|
||||||
-cppcoreguidelines-avoid-magic-numbers,\
|
-cppcoreguidelines-avoid-magic-numbers,\
|
||||||
-cppcoreguidelines-avoid-non-const-global-variables,\
|
-cppcoreguidelines-avoid-non-const-global-variables,\
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,11 @@
|
||||||
name: build
|
name: build
|
||||||
|
|
||||||
# Trigger on push, pull request, or via manual dispatch.
|
# Trigger on push, pull request, or via manual dispatch.
|
||||||
on: [push, pull_request, workflow_dispatch]
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, reopened, labeled, unlabeled, synchronize]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
.cache/
|
||||||
|
bazel-*/
|
||||||
|
build-*/
|
||||||
|
python/*/__pycache__
|
||||||
32
BUILD.bazel
32
BUILD.bazel
|
|
@ -86,6 +86,19 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "gemma_test",
|
||||||
|
srcs = ["gemma_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":args",
|
||||||
|
":gemma_lib",
|
||||||
|
":ops",
|
||||||
|
"@googletest//:gtest_main",
|
||||||
|
"@hwy//:hwy_test_util",
|
||||||
|
"@hwy//:thread_pool",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "app",
|
name = "app",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
|
@ -132,3 +145,22 @@ cc_binary(
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "benchmark",
|
||||||
|
srcs = [
|
||||||
|
"benchmark.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":app",
|
||||||
|
":args",
|
||||||
|
":gemma_lib",
|
||||||
|
# "//base",
|
||||||
|
"//compression:compress",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:nanobenchmark",
|
||||||
|
"@hwy//:profiler",
|
||||||
|
"@hwy//:thread_pool",
|
||||||
|
"@nlohmann_json//:json",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,9 @@ FetchContent_MakeAvailable(highway)
|
||||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c EXCLUDE_FROM_ALL)
|
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c EXCLUDE_FROM_ALL)
|
||||||
FetchContent_MakeAvailable(sentencepiece)
|
FetchContent_MakeAvailable(sentencepiece)
|
||||||
|
|
||||||
|
FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03 EXCLUDE_FROM_ALL)
|
||||||
|
FetchContent_MakeAvailable(json)
|
||||||
|
|
||||||
set(SOURCES
|
set(SOURCES
|
||||||
gemma.cc
|
gemma.cc
|
||||||
compression/blob_store.cc
|
compression/blob_store.cc
|
||||||
|
|
@ -60,17 +63,7 @@ if (WEIGHT_TYPE)
|
||||||
add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE})
|
add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Executable Target
|
|
||||||
|
|
||||||
add_executable(gemma run.cc)
|
|
||||||
target_sources(gemma PRIVATE ${SOURCES})
|
|
||||||
set_property(TARGET gemma PROPERTY CXX_STANDARD 17)
|
|
||||||
target_link_libraries(gemma hwy hwy_contrib sentencepiece)
|
|
||||||
target_include_directories(gemma PRIVATE ./)
|
|
||||||
FetchContent_GetProperties(sentencepiece)
|
FetchContent_GetProperties(sentencepiece)
|
||||||
target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
|
||||||
target_compile_definitions(gemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
|
||||||
target_compile_options(gemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
|
||||||
|
|
||||||
## Library Target
|
## Library Target
|
||||||
|
|
||||||
|
|
@ -84,11 +77,21 @@ target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
|
||||||
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
||||||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||||
|
|
||||||
|
# Executable Target
|
||||||
|
|
||||||
|
add_executable(gemma run.cc)
|
||||||
|
target_link_libraries(gemma libgemma hwy hwy_contrib)
|
||||||
|
|
||||||
|
add_executable(benchmark benchmark.cc)
|
||||||
|
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||||
|
|
||||||
|
## Tests
|
||||||
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
|
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
|
||||||
if (GEMMA_ENABLE_TESTS)
|
if (GEMMA_ENABLE_TESTS)
|
||||||
|
|
||||||
set(GEMMA_TEST_FILES
|
set(GEMMA_TEST_FILES
|
||||||
ops_test.cc
|
ops_test.cc
|
||||||
|
gemma_test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,13 @@ http_archive(
|
||||||
strip_prefix = "highway-1.1.0",
|
strip_prefix = "highway-1.1.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_archive(
|
||||||
|
name = "nlohmann_json",
|
||||||
|
urls = ["https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip"],
|
||||||
|
integrity = "sha256-BAIrBdgG61/3MCPCgLaGl9Erk+G3JnoLIqGjnsdXgGk=",
|
||||||
|
strip_prefix = "json-3.11.3",
|
||||||
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "com_google_sentencepiece",
|
name = "com_google_sentencepiece",
|
||||||
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
|
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,293 @@
|
||||||
|
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <ostream>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
#include "nlohmann/json.hpp"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
|
||||||
|
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
|
// Command-line arguments that select which benchmark to run and how.
//
// Each member is registered in ForEach(); the constructor parses argv through
// ArgsBase::InitAndParse. A path member left at its empty default means the
// corresponding benchmark was not requested.
class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> {
 public:
  BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  gcpp::Path goldens;         // --goldens_dir
  gcpp::Path summarize_text;  // --summarize_text
  gcpp::Path cross_entropy;   // --cross_entropy
  gcpp::Path trivia_qa;       // --trivia_qa
  size_t max_questions;       // --max_questions
  size_t batch_tokens;        // --batch_tokens

  // Presents every flag to `visitor` as (storage, name, default, help, 2).
  // NOTE(review): the meaning of the trailing `2` is defined by ArgsBase, which
  // is not visible here — presumably a help verbosity level; confirm.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(goldens.path, "goldens_dir", std::string(""),
            "Directory containing golden files", 2);
    visitor(summarize_text.path, "summarize_text", std::string(""),
            "Path to text file to summarize", 2);
    visitor(cross_entropy.path, "cross_entropy", std::string(""),
            "Path to text file to compute cross entropy on", 2);
    visitor(trivia_qa.path, "trivia_qa", std::string(""),
            "Path to json file containing TriviaQA entries", 2);
    // The help text previously referred to a non-existent "--trivial_qa" flag.
    visitor(max_questions, "max_questions", static_cast<size_t>(20),
            "Maximum number of questions to ask from --trivia_qa input", 2);
    visitor(batch_tokens, "batch_tokens", static_cast<size_t>(0),
            "If not zero, break prompt into batches of this size and compute "
            "cross entropy on them independently.",
            2);
  }
};
|
||||||
|
|
||||||
|
void LogSpeedStats(const double time_start, size_t total_tokens) {
|
||||||
|
const double time_end = hwy::platform::Now();
|
||||||
|
const double time_elapsed = time_end - time_start;
|
||||||
|
const double tok_sec = total_tokens / time_elapsed;
|
||||||
|
std::cout << total_tokens << " tokens in " << time_elapsed << " seconds"
|
||||||
|
<< " [" << tok_sec << " tokens / sec" << "]\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs one generation pass over `input` and collects the streamed text.
//
// The input is tokenized, token id 2 (BOS) is prepended, and GenerateGemma is
// invoked from position 0 with a generator seeded with 42. Every token handed
// to the stream callback is decoded and appended to the result.
//
// Returns {concatenated decoded text, number of streamed tokens}.
std::pair<std::string, int> QueryModel(
    gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
    const std::string& input) {
  std::vector<int> tokens;
  HWY_ASSERT(model.Tokenizer()->Encode(input, &tokens));
  // Both pre-trained and instruction-tuned models get a leading "<bos>" token.
  tokens.insert(tokens.begin(), 2);

  std::string generated;
  size_t num_streamed = 0;

  // Accept every candidate token.
  auto accept_any = [](int) { return true; };

  std::mt19937 rng;
  rng.seed(42);

  const double started_at = hwy::platform::Now();
  auto on_token = [&generated, &num_streamed, &started_at, &app,
                   tokenizer = model.Tokenizer()](int token, float) {
    ++num_streamed;
    std::string piece;
    HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &piece));
    generated += piece;
    // Periodic progress report, every 100 streamed tokens.
    const bool report = app.verbosity >= 1 && num_streamed % 100 == 0;
    if (report) {
      LogSpeedStats(started_at, num_streamed);
    }
    return true;
  };

  if (app.verbosity >= 2) {
    std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
              << args.temperature;
  }
  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                args.temperature, tokens, /*abs_pos=*/0, kv_cache, pool,
                inner_pool, on_token, accept_any, rng, app.verbosity);
  if (app.verbosity >= 1) {
    LogSpeedStats(started_at, num_streamed);
  }
  return {generated, num_streamed};
}
|
||||||
|
|
||||||
|
// Loads (query, expected answer) pairs from the text file at `path`.
//
// The file is read in groups of four lines: separator, query, separator,
// answer. Separator contents are ignored. Reading stops at the first
// incomplete group. Returns an empty vector (after printing a message) when
// the file cannot be opened.
std::vector<std::pair<std::string, std::string>> load_goldens(
    const std::string& path) {
  std::vector<std::pair<std::string, std::string>> pairs;

  std::ifstream in(path);
  if (in.fail()) {
    std::cout << "Could not load goldens file: " << path << "\n" << std::flush;
    return pairs;
  }

  // record[0] / record[2] are separators, record[1] the query, record[3] the
  // answer.
  std::string record[4];
  for (;;) {
    bool complete = true;
    for (std::string& field : record) {
      if (!std::getline(in, field)) {
        complete = false;
        break;
      }
    }
    if (!complete) break;
    pairs.emplace_back(record[1], record[3]);
  }
  return pairs;
}
|
||||||
|
|
||||||
|
// Returns the entire contents of the file at `path.path`, or an empty string
// (after printing a message) if the file cannot be opened.
std::string ReadFile(const gcpp::Path& path) {
  std::ifstream in(path.path);
  if (in.fail()) {
    std::cout << "Could not open file: " << path.path << "\n" << std::flush;
    return std::string();
  }
  // Drain the underlying stream buffer in one go.
  std::ostringstream contents;
  contents << in.rdbuf();
  return contents.str();
}
|
||||||
|
|
||||||
|
// Asks the model every query in the goldens file and checks that each output
// contains the expected answer as a substring.
//
// Prints every mismatch, overall speed statistics and the number of correct
// answers. Returns EXIT_SUCCESS only if at least one golden was loaded and all
// of them matched; EXIT_FAILURE otherwise.
int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                     gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
                     const std::string& golden_path) {
  const std::vector<std::pair<std::string, std::string>> queries_answers =
      load_goldens(golden_path);
  // An unreadable or empty goldens file must not be reported as a pass
  // ("0 out of 0" previously returned EXIT_SUCCESS).
  if (queries_answers.empty()) {
    std::cout << "No golden queries loaded from: " << golden_path << "\n"
              << std::flush;
    return EXIT_FAILURE;
  }

  // size_t so the final comparison with size() is not signed/unsigned.
  size_t correct_answers = 0;
  int total_tokens = 0;
  const double time_start = hwy::platform::Now();
  for (const auto& [question, expected_answer] : queries_answers) {
    const auto [answer, token_count] =
        QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
    total_tokens += token_count;
    if (answer.find(expected_answer) != std::string::npos) {
      ++correct_answers;
    } else {
      std::cout << "Wrong!\n";
      std::cout << "Input: " << question << "\n";
      std::cout << "Expected: " << expected_answer << "\n";
      std::cout << "Output: " << answer << "\n\n" << std::flush;
    }
  }
  LogSpeedStats(time_start, total_tokens);

  std::cout << "Correct: " << correct_answers << " out of "
            << queries_answers.size() << "\n"
            << std::flush;
  return correct_answers == queries_answers.size() ? EXIT_SUCCESS
                                                   : EXIT_FAILURE;
}
|
||||||
|
|
||||||
|
// Asks the model to summarize the contents of the file `text` and prints the
// result followed by speed statistics. Always returns EXIT_SUCCESS.
int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                     gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
                     const gcpp::Path& text) {
  std::string prompt("Here is some text to summarize:\n");
  prompt.append(ReadFile(text));
  prompt.append("\nSummarize this text.\n");
  const double time_start = hwy::platform::Now();
  const auto [answer, token_count] =
      QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
  // The first prompt.size() bytes of the streamed text are dropped before
  // printing (presumably the echoed prompt — confirm against GenerateGemma).
  // std::string::substr throws std::out_of_range when the offset exceeds the
  // string length, so a shorter-than-prompt answer is printed whole instead.
  if (answer.size() >= prompt.size()) {
    std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
  } else {
    std::cout << answer << "\n" << std::flush;
  }
  LogSpeedStats(time_start, token_count);
  return EXIT_SUCCESS;
}
|
||||||
|
|
||||||
|
// Computes the cross entropy of the model on the file `text`.
//
// The tokenized input is truncated to args.max_tokens and processed in slices
// of `batch_tokens` tokens (the whole prompt when `batch_tokens` is 0). Each
// slice uses a freshly created KV cache. After each slice the per-byte cross
// entropy of that slice and the cumulative per-byte value are printed.
// Always returns EXIT_SUCCESS.
int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
                          gcpp::InferenceArgs& args, gcpp::AppArgs& app,
                          hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
                          const gcpp::Path& text, size_t batch_tokens) {
  std::string input = ReadFile(text);
  std::vector<int> prompt;
  HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
  prompt.resize(std::min<size_t>(args.max_tokens, prompt.size()));
  std::cout << "Number of input tokens: " << prompt.size() << "\n";

  const double time_start = hwy::platform::Now();
  float entropy_sum = 0.0f;
  size_t bytes_seen = 0;
  // A batch size of zero means "everything in a single slice".
  const size_t step = (batch_tokens == 0) ? prompt.size() : batch_tokens;

  size_t pos = 0;
  while (pos < prompt.size()) {
    const size_t num_tokens = std::min<size_t>(prompt.size() - pos, step);
    const auto slice_begin = prompt.begin() + pos;
    std::vector<int> prompt_slice(slice_begin, slice_begin + num_tokens);

    auto kv_cache = CreateKVCache(model_type);
    const float entropy =
        ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
                            inner_pool, app.verbosity);
    entropy_sum += entropy;
    LogSpeedStats(time_start, pos + num_tokens);

    // Normalize by the byte length of the decoded slice.
    std::string text_slice;
    HWY_ASSERT(model.Tokenizer()->Decode(prompt_slice, &text_slice));
    bytes_seen += text_slice.size();
    printf("Cross entropy per byte: %f [cumulative: %f]\n",
           entropy / text_slice.size(), entropy_sum / bytes_seen);

    pos += step;
  }
  return EXIT_SUCCESS;
}
|
||||||
|
|
||||||
|
// Runs up to `max_questions` TriviaQA questions through the model.
//
// `json_file` holds one JSON object per line; each object is read for a
// "question" string and an "answer"."aliases" array of strings. An answer
// counts as correct when the model output contains any alias as a substring.
// Prints each output with CORRECT/WRONG and a final tally.
//
// Returns EXIT_FAILURE if the file cannot be opened, EXIT_SUCCESS otherwise.
int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
                      hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
                      const gcpp::Path& json_file, size_t max_questions) {
  std::ifstream trivia_file(json_file.path);
  if (!trivia_file) {
    std::cout << "Could not load file: " << json_file.path << "\n"
              << std::flush;
    return EXIT_FAILURE;
  }
  std::string line;
  size_t correct_answers = 0;
  size_t i = 0;
  // The limit is checked before reading so that max_questions == 0 asks
  // nothing (previously one question was always asked).
  while (i < max_questions && std::getline(trivia_file, line)) {
    // Blank lines are not JSON documents; parsing them would throw.
    if (line.empty()) continue;

    json data = json::parse(line);
    const std::string question = data["question"].get<std::string>();
    const auto [answer, token_count] =
        QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
    std::cout << answer << "\n";

    bool correct = false;
    for (const auto& alias : data["answer"]["aliases"]) {
      const std::string expected = alias.get<std::string>();
      if (answer.find(expected) != std::string::npos) {
        correct = true;
        break;
      }
    }
    if (correct) {
      ++correct_answers;
      std::cout << "CORRECT\n\n";
    } else {
      std::cout << "WRONG\n\n";
    }
    ++i;
  }
  printf("Correct answers: %zu / %zu\n", correct_answers, i);
  return EXIT_SUCCESS;
}
|
||||||
|
|
||||||
|
/* Run this in the same way as gemma, p.ex.:
|
||||||
|
./benchmark --tokenizer tokenizer.spm --model 2b-it --weights \
|
||||||
|
2b-it-sfp.sbs --goldens_dir "../goldens"
|
||||||
|
*/
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
gcpp::LoaderArgs loader(argc, argv);
|
||||||
|
gcpp::InferenceArgs args(argc, argv); // inference
|
||||||
|
gcpp::AppArgs app(argc, argv);
|
||||||
|
BenchmarkArgs benchmark_args(argc, argv);
|
||||||
|
|
||||||
|
hwy::ThreadPool inner_pool(0);
|
||||||
|
hwy::ThreadPool pool(app.num_threads);
|
||||||
|
// For many-core, pinning threads to cores helps.
|
||||||
|
if (app.num_threads > 10) {
|
||||||
|
gcpp::PinThreadToCore(app.num_threads - 1); // Main thread
|
||||||
|
|
||||||
|
pool.Run(0, pool.NumThreads(), [](uint64_t /*task*/, size_t thread) {
|
||||||
|
gcpp::PinThreadToCore(thread);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
|
||||||
|
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||||
|
|
||||||
|
if (!benchmark_args.goldens.path.empty()) {
|
||||||
|
const std::string golden_path =
|
||||||
|
benchmark_args.goldens.path + "/" + loader.model_type + ".txt";
|
||||||
|
return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
|
||||||
|
golden_path);
|
||||||
|
} else if (!benchmark_args.summarize_text.path.empty()) {
|
||||||
|
return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
|
||||||
|
benchmark_args.summarize_text);
|
||||||
|
} else if (!benchmark_args.cross_entropy.path.empty()) {
|
||||||
|
return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
|
||||||
|
inner_pool, pool, benchmark_args.cross_entropy,
|
||||||
|
benchmark_args.batch_tokens);
|
||||||
|
} else if (!benchmark_args.trivia_qa.path.empty()) {
|
||||||
|
return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
|
||||||
|
benchmark_args.trivia_qa,
|
||||||
|
benchmark_args.max_questions);
|
||||||
|
}
|
||||||
|
std::cout << "No benchmark command given." << "\n" << std::flush;
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,299 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
# Copyright 2024 Google LLC
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Directory containing this script. Both command substitutions are quoted so a
# checkout path containing spaces is not word-split.
MYDIR="$(dirname "$(realpath "$0")")"
BUILD_DIR="${BUILD_DIR:-${MYDIR}/build}"

# Every setting below can be overridden from the environment; the value after
# ":-" is the default.
CMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE:-RelWithDebInfo}"
CMAKE_C_COMPILER="${CMAKE_C_COMPILER:-clang-14}"
CMAKE_CXX_COMPILER="${CMAKE_CXX_COMPILER:-clang++-14}"
# Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS
CMAKE_FLAGS="${CMAKE_FLAGS:-}"
CMAKE_C_FLAGS="${CMAKE_C_FLAGS:-} ${CMAKE_FLAGS}"
CMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS:-} ${CMAKE_FLAGS}"
CMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS:-}"
CMAKE_MODULE_LINKER_FLAGS="${CMAKE_MODULE_LINKER_FLAGS:-}"
CMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS:-}"
|
||||||
|
|
||||||
|
# Local flags passed to sanitizers.
# Individual UndefinedBehaviorSanitizer checks; this list is appended to the
# compiler flags of the asan, msan and tsan configurations below.
UBSAN_FLAGS=(
  -fsanitize=alignment
  -fsanitize=bool
  -fsanitize=bounds
  -fsanitize=builtin
  -fsanitize=enum
  -fsanitize=float-cast-overflow
  -fsanitize=float-divide-by-zero
  -fsanitize=integer-divide-by-zero
  -fsanitize=null
  -fsanitize=object-size
  -fsanitize=pointer-overflow
  -fsanitize=return
  -fsanitize=returns-nonnull-attribute
  -fsanitize=shift-base
  -fsanitize=shift-exponent
  -fsanitize=unreachable
  -fsanitize=vla-bound

  # Recovery is disabled for the "undefined" group and then re-enabled for
  # alignment only.
  -fno-sanitize-recover=undefined
  -fsanitize-recover=alignment
)
|
||||||
|
|
||||||
|
CLANG_VERSION="${CLANG_VERSION:-}"
# Detect the clang version suffix and store it in CLANG_VERSION. For example,
# "6.0" for clang 6 or "7" for clang 7.
#
# Does nothing if CLANG_VERSION is already set. The version is parsed from the
# first line of `${CMAKE_C_COMPILER:-clang} --version`, after dropping a
# leading "Debian " or "Ubuntu " vendor prefix. For emcc, CLANG_VERSION is left
# empty. Returns 1 (with a message on stderr) for an unrecognized version line.
detect_clang_version() {
  if [[ -n "${CLANG_VERSION}" ]]; then
    return 0
  fi
  # Declared separately from the assignment so `local` does not mask the exit
  # status of the command substitution.
  local clang_version
  clang_version=$("${CMAKE_C_COMPILER:-clang}" --version | head -n1)
  clang_version=${clang_version#"Debian "}
  clang_version=${clang_version#"Ubuntu "}
  case "${clang_version}" in
    "clang version 6."*)
      CLANG_VERSION="6.0"
      ;;
    "clang version "*)
      # Any other clang version uses just the major version number.
      local suffix="${clang_version#clang version }"
      CLANG_VERSION="${suffix%%.*}"
      ;;
    "emcc"*)
      # We can't use asan or msan in the emcc case.
      ;;
    *)
      echo "Unknown clang version: ${clang_version}" >&2
      return 1
  esac
}
|
||||||
|
|
||||||
|
# Temporary files cleanup hooks.
# Paths appended to CLEANUP_FILES are removed recursively by cleanup().
CLEANUP_FILES=()
cleanup() {
  # Nothing registered: nothing to delete.
  [[ ${#CLEANUP_FILES[@]} -eq 0 ]] && return 0
  rm -fr "${CLEANUP_FILES[@]}"
}
|
||||||
|
|
||||||
|
# Exit hook. Receives the script's exit status as $1 (currently only stored,
# not acted upon) and removes everything registered in CLEANUP_FILES.
on_exit() {
  local retcode="${1}"
  # Always cleanup the CLEANUP_FILES.
  cleanup
}

# Run on_exit on interrupt, termination and normal exit. Command tracing is
# switched off first, with the `set +x` trace line itself discarded.
trap 'retcode=$?; { set +x; } 2>/dev/null; on_exit ${retcode}' INT TERM EXIT
|
||||||
|
|
||||||
|
|
||||||
|
# Install libc++ libraries compiled with msan in the msan_prefix for the current
# compiler version.
#
# Environment:
#   LLVM_ROOT  Optional existing llvm-project checkout. When unset, the source
#              tarball for the detected clang version is downloaded from GitHub
#              into a temporary directory.
# The install prefix is ${HOME}/.msan/${CLANG_VERSION}; any previous contents
# of that directory are deleted first.
# NOTE(review): no step checks the exit status of the previous one, so a failed
# download still removes the existing prefix — confirm whether that is intended.
cmd_msan_install() {
  # Scratch directory, registered for removal by cleanup().
  local tmpdir=$(mktemp -d)
  CLEANUP_FILES+=("${tmpdir}")
  # Detect the llvm to install:
  detect_clang_version
  # Allow overriding the LLVM checkout.
  local llvm_root="${LLVM_ROOT:-}"
  if [ -z "${llvm_root}" ]; then
    # Default tag is "<version>.0.0"; clang 6.0 and 7 are mapped to their
    # x.0.1 tags instead.
    local llvm_tag="llvmorg-${CLANG_VERSION}.0.0"
    case "${CLANG_VERSION}" in
      "6.0")
        llvm_tag="llvmorg-6.0.1"
        ;;
      "7")
        llvm_tag="llvmorg-7.0.1"
        ;;
    esac
    local llvm_targz="${tmpdir}/${llvm_tag}.tar.gz"
    curl -L --show-error -o "${llvm_targz}" \
      "https://github.com/llvm/llvm-project/archive/${llvm_tag}.tar.gz"
    tar -C "${tmpdir}" -zxf "${llvm_targz}"
    llvm_root="${tmpdir}/llvm-project-${llvm_tag}"
  fi

  local msan_prefix="${HOME}/.msan/${CLANG_VERSION}"
  rm -rf "${msan_prefix}"

  # Extra per-project cmake arguments, keyed by llvm sub-project name.
  declare -A CMAKE_EXTRAS
  CMAKE_EXTRAS[libcxx]="\
    -DLIBCXX_CXX_ABI=libstdc++ \
    -DLIBCXX_INSTALL_EXPERIMENTAL_LIBRARY=ON \
    -DLIBCXX_INCLUDE_BENCHMARKS=OFF"

  # Only libcxx is built; the loop form allows adding further projects.
  for project in libcxx; do
    local proj_build="${tmpdir}/build-${project}"
    local proj_dir="${llvm_root}/${project}"
    mkdir -p "${proj_build}"
    # The last argument is deliberately unquoted so the extras string is split
    # into separate cmake arguments.
    cmake -B"${proj_build}" -H"${proj_dir}" \
      -G Ninja \
      -DCMAKE_BUILD_TYPE=Release \
      -DLLVM_USE_SANITIZER=Memory \
      -DLLVM_PATH="${llvm_root}/llvm" \
      -DLLVM_CONFIG_PATH="$(which llvm-config llvm-config-7 llvm-config-6.0 | \
                            head -n1)" \
      -DCMAKE_C_COMPILER="${CMAKE_C_COMPILER}" \
      -DCMAKE_CXX_COMPILER="${CMAKE_CXX_COMPILER}" \
      -DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}" \
      -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}" \
      -DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}" \
      -DCMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS}" \
      -DCMAKE_INSTALL_PREFIX="${msan_prefix}" \
      ${CMAKE_EXTRAS[${project}]}
    cmake --build "${proj_build}"
    ninja -C "${proj_build}" install
  done
}
|
||||||
|
|
||||||
|
# Configure a MemorySanitizer build.
#
# Uses the msan-instrumented libc++ under ${HOME}/.msan/${CLANG_VERSION},
# installing it first via cmd_msan_install when needed. The msan and UBSan
# flags are appended to the global CMAKE_* flag variables before
# cmake_configure is called; "$@" is forwarded to cmake.
cmd_msan() {
  detect_clang_version
  local msan_prefix="${HOME}/.msan/${CLANG_VERSION}"
  if [[ ! -d "${msan_prefix}" || -e "${msan_prefix}/lib/libc++abi.a" ]]; then
    # Install msan libraries for this version if needed or if an older version
    # with libc++abi was installed.
    cmd_msan_install
  fi

  local msan_c_flags=(
    -fsanitize=memory
    -fno-omit-frame-pointer

    -g
    -DMEMORY_SANITIZER

    # Force gtest to not use the cxxabi.
    -DGTEST_HAS_CXXABI_H_=0

    -fsanitize-memory-track-origins
  )

  local msan_cxx_flags=(
    "${msan_c_flags[@]}"

    # Some C++ sources don't use the std at all, so the -stdlib=libc++ is unused
    # in those cases. Ignore the warning.
    -Wno-unused-command-line-argument
    -stdlib=libc++

    # We include the libc++ from the msan directory instead, so we don't want
    # the std includes.
    -nostdinc++
    -cxx-isystem"${msan_prefix}/include/c++/v1"
  )

  # Link against, and embed an rpath to, the instrumented libc++.
  local msan_linker_flags=(
    -L"${msan_prefix}"/lib
    -Wl,-rpath -Wl,"${msan_prefix}"/lib/
  )

  CMAKE_C_FLAGS+=" ${msan_c_flags[@]} ${UBSAN_FLAGS[@]}"
  CMAKE_CXX_FLAGS+=" ${msan_cxx_flags[@]} ${UBSAN_FLAGS[@]}"
  CMAKE_EXE_LINKER_FLAGS+=" ${msan_linker_flags[@]}"
  CMAKE_MODULE_LINKER_FLAGS+=" ${msan_linker_flags[@]}"
  CMAKE_SHARED_LINKER_FLAGS+=" ${msan_linker_flags[@]}"
  # NOTE(review): "${msan_linker_flags[@]}" expands to one word per array
  # element, so only the first element is attached to
  # -DCMAKE_REQUIRED_LINK_OPTIONS= and the remaining two are passed to cmake as
  # separate arguments — confirm this is intended.
  cmake_configure "$@" \
    -DCMAKE_CROSSCOMPILING=1 -DRUN_HAVE_STD_REGEX=0 -DRUN_HAVE_POSIX_REGEX=0 \
    -DCMAKE_REQUIRED_LINK_OPTIONS="${msan_linker_flags[@]}"
}
|
||||||
|
|
||||||
|
# Runs the cmake configure step for the source tree in MYDIR, writing to
# BUILD_DIR with the Ninja generator. The global CMAKE_* variables are passed
# through, tests are enabled, and any arguments given to this function are
# appended after the built-in ones.
cmake_configure() {
  local cmake_args=()
  cmake_args+=(-B"${BUILD_DIR}" -H"${MYDIR}")
  cmake_args+=(-DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}")
  cmake_args+=(-G Ninja)
  cmake_args+=(-DCMAKE_C_COMPILER="${CMAKE_C_COMPILER}")
  cmake_args+=(-DCMAKE_CXX_COMPILER="${CMAKE_CXX_COMPILER}")
  cmake_args+=(-DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}")
  cmake_args+=(-DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}")
  cmake_args+=(-DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}")
  cmake_args+=(-DCMAKE_MODULE_LINKER_FLAGS="${CMAKE_MODULE_LINKER_FLAGS}")
  cmake_args+=(-DCMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS}")
  cmake_args+=(-DGEMMA_ENABLE_TESTS=ON)

  cmake "${cmake_args[@]}" "$@"
}
|
||||||
|
|
||||||
|
# "opt" command: forces the RelWithDebInfo build type and runs the cmake
# configure step, forwarding all arguments.
cmd_opt() {
  CMAKE_BUILD_TYPE=RelWithDebInfo
  cmake_configure "${@}"
}
|
||||||
|
|
||||||
|
# "asan" command: appends the AddressSanitizer and UBSan flags to the C and C++
# compiler flags, then runs the cmake configure step.
cmd_asan() {
  # Same flag set for both languages.
  local asan_flags="-g -DADDRESS_SANITIZER -fsanitize=address ${UBSAN_FLAGS[*]}"
  CMAKE_C_FLAGS+=" ${asan_flags}"
  CMAKE_CXX_FLAGS+=" ${asan_flags}"
  cmake_configure "$@"
}
|
||||||
|
|
||||||
|
# "tsan" command: appends the ThreadSanitizer and UBSan flags to the C and C++
# compiler flags, forces RelWithDebInfo, then runs the cmake configure step.
cmd_tsan() {
  SANITIZER="tsan"
  # Same flag set for both languages.
  local tsan_flags="-g -DTHREAD_SANITIZER ${UBSAN_FLAGS[*]} -fsanitize=thread"
  CMAKE_C_FLAGS+=" ${tsan_flags}"
  CMAKE_CXX_FLAGS+=" ${tsan_flags}"

  CMAKE_BUILD_TYPE="RelWithDebInfo"
  cmake_configure "$@"
}
|
||||||
|
|
||||||
|
# Entry point: dispatches "$0 CMD [ARGS...]" to the function cmd_CMD, passing
# the remaining arguments through. With no command, or an unknown one, prints a
# message to stderr and exits with status 1.
main() {
  local cmd="${1:-}"
  if [[ -z "${cmd}" ]]; then
    # Unquoted heredoc: $0 and ${MYDIR} are expanded. The default BUILD_DIR was
    # previously written as "$$repo/build", where $$ expands to the shell PID.
    cat >&2 <<EOF
Use: $0 CMD

Where cmd is one of:
 opt       Build and test a Release with symbols build.
 asan      Build and test an ASan (AddressSanitizer) build.
 msan      Build and test an MSan (MemorySanitizer) build. Needs to have msan
           c++ libs installed with msan_install first.
 msan_install Install the libc++ libraries required to build in msan mode. This
              needs to be done once.
 tsan      Build and test a TSan (ThreadSanitizer) build.

You can pass some optional environment variables as well:
 - BUILD_DIR: The output build directory (by default "${MYDIR}/build")
 - CMAKE_FLAGS: Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS.

These optional environment variables are forwarded to the cmake call as
parameters:
 - CMAKE_BUILD_TYPE
 - CMAKE_C_FLAGS
 - CMAKE_CXX_FLAGS
 - CMAKE_C_COMPILER
 - CMAKE_CXX_COMPILER
 - CMAKE_EXE_LINKER_FLAGS
 - CMAKE_MODULE_LINKER_FLAGS
 - CMAKE_SHARED_LINKER_FLAGS

Example:
  BUILD_DIR=/tmp/build $0 opt
EOF
    exit 1
  fi

  cmd="cmd_${cmd}"
  shift
  # Reject commands with no matching cmd_* function up front instead of failing
  # with "command not found".
  if ! declare -F "${cmd}" >/dev/null; then
    echo "Unknown command: ${cmd#cmd_}" >&2
    exit 1
  fi
  set -x
  "${cmd}" "$@"
}
|
||||||
|
|
||||||
|
main "$@"
|
||||||
|
|
@ -23,13 +23,12 @@
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
|
|
||||||
std::vector<int> tokenize(
|
std::vector<int> tokenize(const std::string& prompt_string,
|
||||||
const std::string& prompt_string,
|
const gcpp::GemmaTokenizer* tokenizer) {
|
||||||
const sentencepiece::SentencePieceProcessor* tokenizer) {
|
|
||||||
std::string formatted = "<start_of_turn>user\n" + prompt_string +
|
std::string formatted = "<start_of_turn>user\n" + prompt_string +
|
||||||
"<end_of_turn>\n<start_of_turn>model\n";
|
"<end_of_turn>\n<start_of_turn>model\n";
|
||||||
std::vector<int> tokens;
|
std::vector<int> tokens;
|
||||||
HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok());
|
HWY_ASSERT(tokenizer->Encode(formatted, &tokens));
|
||||||
tokens.insert(tokens.begin(), 2); // BOS token
|
tokens.insert(tokens.begin(), 2); // BOS token
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
@ -58,14 +57,14 @@ int main(int argc, char** argv) {
|
||||||
size_t ntokens = tokens.size();
|
size_t ntokens = tokens.size();
|
||||||
|
|
||||||
// This callback function gets invoked every time a token is generated
|
// This callback function gets invoked every time a token is generated
|
||||||
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
|
auto stream_token = [&pos, &ntokens, tokenizer = model.Tokenizer()](int token,
|
||||||
int token, float) {
|
float) {
|
||||||
++pos;
|
++pos;
|
||||||
if (pos < ntokens) {
|
if (pos < ntokens) {
|
||||||
// print feedback
|
// print feedback
|
||||||
} else if (token != gcpp::EOS_ID) {
|
} else if (token != gcpp::EOS_ID) {
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
|
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||||
std::cout << token_text << std::flush;
|
std::cout << token_text << std::flush;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -78,5 +77,5 @@ int main(int argc, char** argv) {
|
||||||
.verbosity = 0},
|
.verbosity = 0},
|
||||||
tokens, /*KV cache position = */ 0, kv_cache, pool,
|
tokens, /*KV cache position = */ 0, kv_cache, pool,
|
||||||
stream_token, gen);
|
stream_token, gen);
|
||||||
std::cout << std::endl;
|
std::cout << "\n";
|
||||||
}
|
}
|
||||||
|
|
|
||||||
168
gemma.cc
168
gemma.cc
|
|
@ -31,6 +31,9 @@
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h" // Path
|
#include "util/args.h" // Path
|
||||||
|
// copybara:import_next_line:sentencepiece
|
||||||
|
#include "src/sentencepiece_processor.h"
|
||||||
|
// copybara:end
|
||||||
|
|
||||||
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
||||||
// compile pass, whereas we want this defined in the first.
|
// compile pass, whereas we want this defined in the first.
|
||||||
|
|
@ -49,6 +52,7 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <regex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -330,7 +334,7 @@ struct Activations {
|
||||||
struct GemmaInterface {
|
struct GemmaInterface {
|
||||||
virtual ~GemmaInterface() = default;
|
virtual ~GemmaInterface() = default;
|
||||||
|
|
||||||
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
virtual const GemmaTokenizer* Tokenizer() const = 0;
|
||||||
|
|
||||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||||
float temperature, const std::vector<int>& prompt,
|
float temperature, const std::vector<int>& prompt,
|
||||||
|
|
@ -339,6 +343,12 @@ struct GemmaInterface {
|
||||||
const StreamFunc& stream_token,
|
const StreamFunc& stream_token,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||||
int verbosity) = 0;
|
int verbosity) = 0;
|
||||||
|
|
||||||
|
virtual float ComputeCrossEntropy(size_t max_tokens,
|
||||||
|
const std::vector<int>& prompt,
|
||||||
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
|
hwy::ThreadPool& inner_pool,
|
||||||
|
int verbosity) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class Config>
|
template <class Config>
|
||||||
|
|
@ -358,6 +368,29 @@ KVCache CreateKVCache(Model type) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class GemmaTokenizerImpl : public GemmaTokenizer {
|
||||||
|
public:
|
||||||
|
GemmaTokenizerImpl(
|
||||||
|
std::unique_ptr<sentencepiece::SentencePieceProcessor>&& impl)
|
||||||
|
: impl_(std::move(impl)) {}
|
||||||
|
bool Encode(const std::string& input,
|
||||||
|
std::vector<std::string>* pieces) const override {
|
||||||
|
return impl_->Encode(input, pieces).ok();
|
||||||
|
}
|
||||||
|
bool Encode(const std::string& input,
|
||||||
|
std::vector<int>* pieces) const override {
|
||||||
|
return impl_->Encode(input, pieces).ok();
|
||||||
|
}
|
||||||
|
// Given a sequence of ids, decodes it into a detokenized output.
|
||||||
|
bool Decode(const std::vector<int>& ids,
|
||||||
|
std::string* detokenized) const override {
|
||||||
|
return impl_->Decode(ids, detokenized).ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <class Config>
|
template <class Config>
|
||||||
void DeleteLayersPtrs(CompressedWeights<Config>* c_weights) {
|
void DeleteLayersPtrs(CompressedWeights<Config>* c_weights) {
|
||||||
|
|
@ -381,8 +414,8 @@ struct GemmaImpl : public GemmaInterface {
|
||||||
DeleteLayersPtrs(weights);
|
DeleteLayersPtrs(weights);
|
||||||
}
|
}
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
|
const GemmaTokenizer* Tokenizer() const override {
|
||||||
return tokenizer.get();
|
return &tokenizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||||
|
|
@ -392,7 +425,12 @@ struct GemmaImpl : public GemmaInterface {
|
||||||
const AcceptFunc& accept_token, std::mt19937&,
|
const AcceptFunc& accept_token, std::mt19937&,
|
||||||
int verbosity) override;
|
int verbosity) override;
|
||||||
|
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
||||||
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
|
hwy::ThreadPool& inner_pool,
|
||||||
|
int verbosity) override;
|
||||||
|
|
||||||
|
GemmaTokenizerImpl tokenizer;
|
||||||
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
|
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
|
||||||
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
||||||
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
||||||
|
|
@ -804,6 +842,78 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class TConfig>
|
||||||
|
std::string TokenString(GemmaImpl<TConfig>& gemma, int token) {
|
||||||
|
std::string token_str;
|
||||||
|
gemma.Tokenizer()->Decode({token}, &token_str);
|
||||||
|
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
|
||||||
|
}
|
||||||
|
|
||||||
|
#define TOKEN(token_id) TokenString(gemma, token_id).c_str()
|
||||||
|
|
||||||
|
template <class TConfig>
|
||||||
|
void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
|
||||||
|
size_t k) {
|
||||||
|
std::vector<std::pair<float, int>> sorted(len);
|
||||||
|
for (int i = 0; i < len; ++i) {
|
||||||
|
sorted[i] = std::make_pair(dist[i], i);
|
||||||
|
}
|
||||||
|
std::sort(sorted.begin(), sorted.end(),
|
||||||
|
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
||||||
|
if (a.first != b.first) {
|
||||||
|
return a.first > b.first;
|
||||||
|
}
|
||||||
|
return a.second < b.second;
|
||||||
|
});
|
||||||
|
for (int i = 0; i < k; ++i) {
|
||||||
|
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", i + 1, sorted[i].second,
|
||||||
|
TOKEN(sorted[i].second), sorted[i].first, logits[sorted[i].second]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class TConfig>
|
||||||
|
float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
|
hwy::ThreadPool& pool,
|
||||||
|
hwy::ThreadPool& inner_pool, int verbosity) {
|
||||||
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||||
|
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||||
|
const WeightsT<TConfig>& weights =
|
||||||
|
*reinterpret_cast<const WeightsT<TConfig>*>(gemma.weights_u8.get());
|
||||||
|
std::vector<float> logits(kVocabSize);
|
||||||
|
Softmax(activations.logits.data(), kVocabSize);
|
||||||
|
float total_entropy = 0.0f;
|
||||||
|
for (size_t pos = 0; pos < max_tokens && pos < prompt.size(); ++pos) {
|
||||||
|
if (verbosity >= 4) {
|
||||||
|
LogTopK(gemma, logits.data(), activations.logits.data(), kVocabSize, 10);
|
||||||
|
}
|
||||||
|
const int token = prompt[pos];
|
||||||
|
const float prob = activations.logits[token];
|
||||||
|
if (verbosity >= 3) {
|
||||||
|
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
|
||||||
|
TOKEN(token), prob, -std::log(prob) / std::log(2.0));
|
||||||
|
}
|
||||||
|
total_entropy -= std::max(std::log(prob), -64.0f);
|
||||||
|
if (verbosity >= 2 && pos % 100 == 99) {
|
||||||
|
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
||||||
|
total_entropy / std::log(2.0) / (pos + 1));
|
||||||
|
}
|
||||||
|
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
|
||||||
|
MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
|
||||||
|
activations.x.data(),
|
||||||
|
activations.logits.data(), pool);
|
||||||
|
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
||||||
|
memcpy(logits.data(), activations.logits.data(),
|
||||||
|
kVocabSize * sizeof(logits[0]));
|
||||||
|
Softmax(activations.logits.data(), kVocabSize);
|
||||||
|
}
|
||||||
|
return total_entropy / std::log(2.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef TOKEN
|
||||||
|
|
||||||
|
|
||||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||||
size_t max_generated_tokens, float temperature,
|
size_t max_generated_tokens, float temperature,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
|
|
@ -828,6 +938,23 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||||
accept_token, gen, verbosity);
|
accept_token, gen, verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||||
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
|
int verbosity) {
|
||||||
|
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||||
|
inner_pool, verbosity);
|
||||||
|
}
|
||||||
|
|
||||||
|
float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||||
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
|
int verbosity) {
|
||||||
|
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||||
|
inner_pool, verbosity);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||||
// if weights = null, which happens during the first call where we attempt to
|
// if weights = null, which happens during the first call where we attempt to
|
||||||
// load from cache.
|
// load from cache.
|
||||||
|
|
@ -983,6 +1110,8 @@ HWY_EXPORT(LoadWeightsT);
|
||||||
HWY_EXPORT(CompressWeightsT);
|
HWY_EXPORT(CompressWeightsT);
|
||||||
HWY_EXPORT(Generate2B);
|
HWY_EXPORT(Generate2B);
|
||||||
HWY_EXPORT(Generate7B);
|
HWY_EXPORT(Generate7B);
|
||||||
|
HWY_EXPORT(ComputeCrossEntropy2B);
|
||||||
|
HWY_EXPORT(ComputeCrossEntropy7B);
|
||||||
|
|
||||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
|
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
|
||||||
KVCache kv_cache = {};
|
KVCache kv_cache = {};
|
||||||
|
|
@ -995,7 +1124,7 @@ template <class Config>
|
||||||
GemmaImpl<Config>::GemmaImpl(
|
GemmaImpl<Config>::GemmaImpl(
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||||
hwy::AlignedFreeUniquePtr<uint8_t[]>& weights_u8, hwy::ThreadPool& pool)
|
hwy::AlignedFreeUniquePtr<uint8_t[]>& weights_u8, hwy::ThreadPool& pool)
|
||||||
: tokenizer(std::move(tokenizer)),
|
: tokenizer(GemmaTokenizerImpl(std::move(tokenizer))),
|
||||||
weights_u8(std::move(weights_u8)),
|
weights_u8(std::move(weights_u8)),
|
||||||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
||||||
|
|
@ -1023,6 +1152,22 @@ void GemmaImpl<ConfigGemma7B>::Generate(
|
||||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
|
||||||
|
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
|
||||||
|
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
|
||||||
|
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
|
||||||
|
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
|
||||||
|
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
|
||||||
|
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||||
|
}
|
||||||
|
|
||||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||||
|
|
@ -1056,7 +1201,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||||
|
|
||||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
const GemmaTokenizer* Gemma::Tokenizer() const {
|
||||||
return impl_->Tokenizer();
|
return impl_->Tokenizer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1090,5 +1235,16 @@ void CompressWeights(gcpp::Model model, const Path& weights,
|
||||||
(model, weights, compressed_weights, pool);
|
(model, weights, compressed_weights, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||||
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
|
int verbosity) {
|
||||||
|
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
|
const float result = gemma.impl_->ComputeCrossEntropy(
|
||||||
|
max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||||
|
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
#endif // HWY_ONCE
|
#endif // HWY_ONCE
|
||||||
|
|
|
||||||
17
gemma.h
17
gemma.h
|
|
@ -60,11 +60,21 @@ struct RuntimeConfig {
|
||||||
|
|
||||||
struct GemmaInterface;
|
struct GemmaInterface;
|
||||||
|
|
||||||
|
class GemmaTokenizer {
|
||||||
|
public:
|
||||||
|
virtual bool Encode(const std::string& input,
|
||||||
|
std::vector<std::string>* pieces) const = 0;
|
||||||
|
virtual bool Encode(const std::string& input,
|
||||||
|
std::vector<int>* pieces) const = 0;
|
||||||
|
virtual bool Decode(const std::vector<int>& ids,
|
||||||
|
std::string* detokenized) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
struct Gemma {
|
struct Gemma {
|
||||||
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||||
hwy::ThreadPool& pool);
|
hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
const GemmaTokenizer* Tokenizer() const;
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -95,6 +105,11 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||||
const Path& compressed_weights, hwy::ThreadPool& pool);
|
const Path& compressed_weights, hwy::ThreadPool& pool);
|
||||||
|
|
||||||
|
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||||
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
|
int verbosity);
|
||||||
|
|
||||||
constexpr int EOS_ID = 1;
|
constexpr int EOS_ID = 1;
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,304 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "gemma.h"
|
||||||
|
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "ops.h"
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "util/args.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class GemmaTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
GemmaTest()
|
||||||
|
: weights("./2b-it-mqa.sbs"),
|
||||||
|
tokenizer("./tokenizer.spm"),
|
||||||
|
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
|
||||||
|
inner_pool(0),
|
||||||
|
model_type(gcpp::Model::GEMMA_2B),
|
||||||
|
model(tokenizer, weights, model_type, pool) {
|
||||||
|
kv_cache = CreateKVCache(model_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GemmaReply(const std::string& prompt_string) {
|
||||||
|
std::mt19937 gen;
|
||||||
|
gen.seed(42);
|
||||||
|
|
||||||
|
std::vector<int> prompt;
|
||||||
|
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||||
|
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||||
|
// if needed.
|
||||||
|
prompt.insert(prompt.begin(), 2);
|
||||||
|
|
||||||
|
std::vector<int> response;
|
||||||
|
auto stream_token = [&response](int token, float) {
|
||||||
|
response.push_back(token);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
gcpp::GenerateGemma(
|
||||||
|
model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
|
||||||
|
/*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||||
|
inner_pool, stream_token,
|
||||||
|
/*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
|
||||||
|
std::string response_text;
|
||||||
|
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
|
||||||
|
return response_text;
|
||||||
|
}
|
||||||
|
|
||||||
|
float GemmaCrossEntropy(const std::string& prompt_string) {
|
||||||
|
std::vector<int> prompt;
|
||||||
|
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||||
|
return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
|
||||||
|
kv_cache, pool, inner_pool,
|
||||||
|
/*verbosity=*/0) /
|
||||||
|
prompt_string.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestQuestions(const char* kQA[][2], size_t num_questions) {
|
||||||
|
for (size_t i = 0; i < num_questions; ++i) {
|
||||||
|
std::cout << "Question " << i + 1 << "\n\n";
|
||||||
|
std::string response = GemmaReply(kQA[i][0]);
|
||||||
|
std::cout << response << "\n\n";
|
||||||
|
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gcpp::Path weights;
|
||||||
|
gcpp::Path tokenizer;
|
||||||
|
gcpp::KVCache kv_cache;
|
||||||
|
hwy::ThreadPool pool;
|
||||||
|
hwy::ThreadPool inner_pool;
|
||||||
|
gcpp::Model model_type = {};
|
||||||
|
gcpp::Gemma model;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(GemmaTest, Geography) {
|
||||||
|
static const char* kQA[][2] = {
|
||||||
|
{"What is the capital of Hungary?", "Budapest"},
|
||||||
|
{"How many states does the US have?", "50"},
|
||||||
|
{"list me ten biggest cities in the world", "Tokyo"},
|
||||||
|
};
|
||||||
|
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||||
|
TestQuestions(kQA, kNum);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GemmaTest, History) {
|
||||||
|
static const char* kQA[][2] = {
|
||||||
|
{"When was the Battle of Hastings?", "1066"},
|
||||||
|
{"Who fought at the Battle of Marathon?", "Greek"},
|
||||||
|
};
|
||||||
|
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||||
|
TestQuestions(kQA, kNum);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GemmaTest, Arithmetic) {
|
||||||
|
static const char* kQA[][2] = {
|
||||||
|
{"what is 13 + 14?", "27"},
|
||||||
|
{"what is 7 * 8", "56"},
|
||||||
|
};
|
||||||
|
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||||
|
TestQuestions(kQA, kNum);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char kJingleBells[] = R"(
|
||||||
|
Dashing through the snow
|
||||||
|
In a one-horse open sleigh
|
||||||
|
O'er the fields we go
|
||||||
|
Laughing all the way
|
||||||
|
Bells on bobtails ring
|
||||||
|
Making spirits bright
|
||||||
|
What fun it is to ride and sing
|
||||||
|
A sleighing song tonight
|
||||||
|
)";
|
||||||
|
|
||||||
|
// The "Hay Draft" of the Gettysburg Address.
|
||||||
|
static const char kGettysburg[] = {
|
||||||
|
"Four score and seven years ago our fathers brought forth, upon this "
|
||||||
|
"continent, a new nation, conceived in Liberty, and dedicated to the "
|
||||||
|
"proposition that all men are created equal.\n\nNow we are engaged in a "
|
||||||
|
"great civil war, testing whether that nation, or any nation, so "
|
||||||
|
"conceived, and so dedicated, can long endure. We are met here on a great "
|
||||||
|
"battlefield of that war. We have come to dedicate a portion of it as a "
|
||||||
|
"final resting place for those who here gave their lives that that nation "
|
||||||
|
"might live. It is altogether fitting and proper that we should do "
|
||||||
|
"this.\n\nBut in a larger sense we can not dedicate -- we can not "
|
||||||
|
"consecrate -- we can not hallow this ground. The brave men, living and "
|
||||||
|
"dead, who struggled, here, have consecrated it far above our poor power "
|
||||||
|
"to add or detract. The world will little note, nor long remember, what we "
|
||||||
|
"say here, but can never forget what they did here. It is for us, the "
|
||||||
|
"living, rather to be dedicated here to the unfinished work which they "
|
||||||
|
"have, thus far, so nobly carried on. It is rather for us to be here "
|
||||||
|
"dedicated to the great task remaining before us -- that from these "
|
||||||
|
"honored dead we take increased devotion to that cause for which they here "
|
||||||
|
"gave the last full measure of devotion -- that we here highly resolve "
|
||||||
|
"that these dead shall not have died in vain; that this nation shall have "
|
||||||
|
"a new birth of freedom; and that this government of the people, by the "
|
||||||
|
"people, for the people, shall not perish from the earth.\n"};
|
||||||
|
|
||||||
|
// The Declaration of Independence.
|
||||||
|
static const char kDeclaration[] = {
|
||||||
|
"IN CONGRESS, July 4, 1776.\n\nThe unanimous Declaration of the thirteen "
|
||||||
|
"united States of America,\n\nWhen in the Course of human events, it "
|
||||||
|
"becomes necessary for one people to dissolve the political bands which "
|
||||||
|
"have connected them with another, and to assume among the powers of the "
|
||||||
|
"earth, the separate and equal station to which the Laws of Nature and of "
|
||||||
|
"Nature's God entitle them, a decent respect to the opinions of mankind "
|
||||||
|
"requires that they should declare the causes which impel them to the "
|
||||||
|
"separation.\n\nWe hold these truths to be self-evident, that all men are "
|
||||||
|
"created equal, that they are endowed by their Creator with certain "
|
||||||
|
"unalienable Rights, that among these are Life, Liberty and the pursuit of "
|
||||||
|
"Happiness.--That to secure these rights, Governments are instituted among "
|
||||||
|
"Men, deriving their just powers from the consent of the governed, --That "
|
||||||
|
"whenever any Form of Government becomes destructive of these ends, it is "
|
||||||
|
"the Right of the People to alter or to abolish it, and to institute new "
|
||||||
|
"Government, laying its foundation on such principles and organizing its "
|
||||||
|
"powers in such form, as to them shall seem most likely to effect their "
|
||||||
|
"Safety and Happiness. Prudence, indeed, will dictate that Governments "
|
||||||
|
"long established should not be changed for light and transient causes; "
|
||||||
|
"and accordingly all experience hath shewn, that mankind are more disposed "
|
||||||
|
"to suffer, while evils are sufferable, than to right themselves by "
|
||||||
|
"abolishing the forms to which they are accustomed. But when a long train "
|
||||||
|
"of abuses and usurpations, pursuing invariably the same Object evinces a "
|
||||||
|
"design to reduce them under absolute Despotism, it is their right, it is "
|
||||||
|
"their duty, to throw off such Government, and to provide new Guards for "
|
||||||
|
"their future security.--Such has been the patient sufferance of these "
|
||||||
|
"Colonies; and such is now the necessity which constrains them to alter "
|
||||||
|
"their former Systems of Government. The history of the present King of "
|
||||||
|
"Great Britain is a history of repeated injuries and usurpations, all "
|
||||||
|
"having in direct object the establishment of an absolute Tyranny over "
|
||||||
|
"these States. To prove this, let Facts be submitted to a candid "
|
||||||
|
"world.\n\nHe has refused his Assent to Laws, the most wholesome and "
|
||||||
|
"necessary for the public good.\nHe has forbidden his Governors to pass "
|
||||||
|
"Laws of immediate and pressing importance, unless suspended in their "
|
||||||
|
"operation till his Assent should be obtained; and when so suspended, he "
|
||||||
|
"has utterly neglected to attend to them.\nHe has refused to pass other "
|
||||||
|
"Laws for the accommodation of large districts of people, unless those "
|
||||||
|
"people would relinquish the right of Representation in the Legislature, a "
|
||||||
|
"right inestimable to them and formidable to tyrants only.\nHe has called "
|
||||||
|
"together legislative bodies at places unusual, uncomfortable, and distant "
|
||||||
|
"from the depository of their public Records, for the sole purpose of "
|
||||||
|
"fatiguing them into compliance with his measures.\nHe has dissolved "
|
||||||
|
"Representative Houses repeatedly, for opposing with manly firmness his "
|
||||||
|
"invasions on the rights of the people.\nHe has refused for a long time, "
|
||||||
|
"after such dissolutions, to cause others to be elected; whereby the "
|
||||||
|
"Legislative powers, incapable of Annihilation, have returned to the "
|
||||||
|
"People at large for their exercise; the State remaining in the mean time "
|
||||||
|
"exposed to all the dangers of invasion from without, and convulsions "
|
||||||
|
"within.\nHe has endeavoured to prevent the population of these States; "
|
||||||
|
"for that purpose obstructing the Laws for Naturalization of Foreigners; "
|
||||||
|
"refusing to pass others to encourage their migrations hither, and raising "
|
||||||
|
"the conditions of new Appropriations of Lands.\nHe has obstructed the "
|
||||||
|
"Administration of Justice, by refusing his Assent to Laws for "
|
||||||
|
"establishing Judiciary powers.\nHe has made Judges dependent on his Will "
|
||||||
|
"alone, for the tenure of their offices, and the amount and payment of "
|
||||||
|
"their salaries.\nHe has erected a multitude of New Offices, and sent "
|
||||||
|
"hither swarms of Officers to harrass our people, and eat out their "
|
||||||
|
"substance.\nHe has kept among us, in times of peace, Standing Armies "
|
||||||
|
"without the Consent of our legislatures.\nHe has affected to render the "
|
||||||
|
"Military independent of and superior to the Civil power.\nHe has combined "
|
||||||
|
"with others to subject us to a jurisdiction foreign to our constitution, "
|
||||||
|
"and unacknowledged by our laws; giving his Assent to their Acts of "
|
||||||
|
"pretended Legislation:\nFor Quartering large bodies of armed troops among "
|
||||||
|
"us:\nFor protecting them, by a mock Trial, from punishment for any "
|
||||||
|
"Murders which they should commit on the Inhabitants of these States:\nFor "
|
||||||
|
"cutting off our Trade with all parts of the world:\nFor imposing Taxes on "
|
||||||
|
"us without our Consent:\nFor depriving us in many cases, of the benefits "
|
||||||
|
"of Trial by Jury:\nFor transporting us beyond Seas to be tried for "
|
||||||
|
"pretended offences\nFor abolishing the free System of English Laws in a "
|
||||||
|
"neighbouring Province, establishing therein an Arbitrary government, and "
|
||||||
|
"enlarging its Boundaries so as to render it at once an example and fit "
|
||||||
|
"instrument for introducing the same absolute rule into these "
|
||||||
|
"Colonies:\nFor taking away our Charters, abolishing our most valuable "
|
||||||
|
"Laws, and altering fundamentally the Forms of our Governments:\nFor "
|
||||||
|
"suspending our own Legislatures, and declaring themselves invested with "
|
||||||
|
"power to legislate for us in all cases whatsoever.\nHe has abdicated "
|
||||||
|
"Government here, by declaring us out of his Protection and waging War "
|
||||||
|
"against us.\nHe has plundered our seas, ravaged our Coasts, burnt our "
|
||||||
|
"towns, and destroyed the lives of our people.\nHe is at this time "
|
||||||
|
"transporting large Armies of foreign Mercenaries to compleat the works of "
|
||||||
|
"death, desolation and tyranny, already begun with circumstances of "
|
||||||
|
"Cruelty & perfidy scarcely paralleled in the most barbarous ages, and "
|
||||||
|
"totally unworthy the Head of a civilized nation.\nHe has constrained our "
|
||||||
|
"fellow Citizens taken Captive on the high Seas to bear Arms against their "
|
||||||
|
"Country, to become the executioners of their friends and Brethren, or to "
|
||||||
|
"fall themselves by their Hands.\nHe has excited domestic insurrections "
|
||||||
|
"amongst us, and has endeavoured to bring on the inhabitants of our "
|
||||||
|
"frontiers, the merciless Indian Savages, whose known rule of warfare, is "
|
||||||
|
"an undistinguished destruction of all ages, sexes and conditions.\n\nIn "
|
||||||
|
"every stage of these Oppressions We have Petitioned for Redress in the "
|
||||||
|
"most humble terms: Our repeated Petitions have been answered only by "
|
||||||
|
"repeated injury. A Prince whose character is thus marked by every act "
|
||||||
|
"which may define a Tyrant, is unfit to be the ruler of a free "
|
||||||
|
"people.\n\nNor have We been wanting in attentions to our Brittish "
|
||||||
|
"brethren. We have warned them from time to time of attempts by their "
|
||||||
|
"legislature to extend an unwarrantable jurisdiction over us. We have "
|
||||||
|
"reminded them of the circumstances of our emigration and settlement here. "
|
||||||
|
"We have appealed to their native justice and magnanimity, and we have "
|
||||||
|
"conjured them by the ties of our common kindred to disavow these "
|
||||||
|
"usurpations, which, would inevitably interrupt our connections and "
|
||||||
|
"correspondence. They too have been deaf to the voice of justice and of "
|
||||||
|
"consanguinity. We must, therefore, acquiesce in the necessity, which "
|
||||||
|
"denounces our Separation, and hold them, as we hold the rest of mankind, "
|
||||||
|
"Enemies in War, in Peace Friends.\n\nWe, therefore, the Representatives "
|
||||||
|
"of the united States of America, in General Congress, Assembled, "
|
||||||
|
"appealing to the Supreme Judge of the world for the rectitude of our "
|
||||||
|
"intentions, do, in the Name, and by Authority of the good People of these "
|
||||||
|
"Colonies, solemnly publish and declare, That these United Colonies are, "
|
||||||
|
"and of Right ought to be Free and Independent States; that they are "
|
||||||
|
"Absolved from all Allegiance to the British Crown, and that all political "
|
||||||
|
"connection between them and the State of Great Britain, is and ought to "
|
||||||
|
"be totally dissolved; and that as Free and Independent States, they have "
|
||||||
|
"full Power to levy War, conclude Peace, contract Alliances, establish "
|
||||||
|
"Commerce, and to do all other Acts and Things which Independent States "
|
||||||
|
"may of right do. And for the support of this Declaration, with a firm "
|
||||||
|
"reliance on the protection of divine Providence, we mutually pledge to "
|
||||||
|
"each other our Lives, our Fortunes and our sacred Honor.\n"};
|
||||||
|
|
||||||
|
// Feeds a single short English sentence to GemmaCrossEntropy, prints the
// returned value (labelled "per-byte entropy") and expects it to be < 1.6.
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||||
|
static const char kSmall[] =
|
||||||
|
"The capital of Hungary is Budapest which is located in Europe.";
|
||||||
|
float entropy = GemmaCrossEntropy(kSmall);
|
||||||
|
std::cout << "per-byte entropy: " << entropy << "\n";
|
||||||
|
// Strict upper bound; the test fails if the reported value is >= 1.6.
EXPECT_LT(entropy, 1.6f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Passes kJingleBells (defined outside this view) to GemmaCrossEntropy,
// prints the returned value (labelled "per-byte entropy") and expects it to
// be < 2.3.
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||||
|
float entropy = GemmaCrossEntropy(kJingleBells);
|
||||||
|
std::cout << "per-byte entropy: " << entropy << "\n";
|
||||||
|
// Strict upper bound; the test fails if the reported value is >= 2.3.
EXPECT_LT(entropy, 2.3f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Passes kGettysburg (defined outside this view) to GemmaCrossEntropy,
// prints the returned value (labelled "per-byte entropy") and expects it to
// be < 1.2.
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||||
|
float entropy = GemmaCrossEntropy(kGettysburg);
|
||||||
|
std::cout << "per-byte entropy: " << entropy << "\n";
|
||||||
|
// Strict upper bound; the test fails if the reported value is >= 1.2.
EXPECT_LT(entropy, 1.2f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Passes kDeclaration to GemmaCrossEntropy, prints the returned value
// (labelled "per-byte entropy") and expects it to be < 1.0.
// NOTE(review): kDeclaration is presumably the long string literal that ends
// just above these tests; its declaration is outside this view, so confirm.
TEST_F(GemmaTest, CrossEntropyDeclaration) {
|
||||||
|
float entropy = GemmaCrossEntropy(kDeclaration);
|
||||||
|
std::cout << "per-byte entropy: " << entropy << "\n";
|
||||||
|
// Strict upper bound; the test fails if the reported value is >= 1.0.
EXPECT_LT(entropy, 1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gcpp
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
+++
|
||||||
|
list me ten biggest cities
|
||||||
|
---
|
||||||
|
8. Mexico City
|
||||||
|
+++
|
||||||
|
What is the capital of Switzerland?
|
||||||
|
---
|
||||||
|
The capital of Switzerland is Bern.
|
||||||
|
+++
|
||||||
|
What is the name of the president of the US?
|
||||||
|
---
|
||||||
|
I cannot answer this question
|
||||||
|
+++
|
||||||
|
Is it raining frequently in China?
|
||||||
|
---
|
||||||
|
average of 1,200 rainy days per year
|
||||||
|
+++
|
||||||
|
tell me a french joke
|
||||||
|
---
|
||||||
|
What do you call a French person who's always complaining?
|
||||||
|
+++
|
||||||
|
which year did the 21th century started?
|
||||||
|
---
|
||||||
|
The 21st century started in the year **1900**
|
||||||
|
+++
|
||||||
|
what's your favorite pokemon?
|
||||||
|
---
|
||||||
|
I am unable to provide a favorite Pokémon
|
||||||
|
+++
|
||||||
|
How to bake a tasty cake?
|
||||||
|
---
|
||||||
|
* 1 and 1/2 cups all-purpose flour
|
||||||
|
+++
|
||||||
|
Which is the richest country in the world?
|
||||||
|
---
|
||||||
|
richest country in the world is the United States
|
||||||
|
+++
|
||||||
|
do you like electronic music?
|
||||||
|
---
|
||||||
|
Electronic music is a broad genre
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
+++
|
||||||
|
list me ten biggest cities
|
||||||
|
---
|
||||||
|
6. Mexico City, Mexico
|
||||||
|
+++
|
||||||
|
What is the capital of Switzerland?
|
||||||
|
---
|
||||||
|
Bern is the capital of Switzerland.
|
||||||
|
+++
|
||||||
|
What is the name of the president of the US?
|
||||||
|
---
|
||||||
|
The answer is: Joe Biden.
|
||||||
32
run.cc
32
run.cc
|
|
@ -40,13 +40,14 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
static constexpr std::string_view kAsciiArtBanner =
|
static constexpr std::string_view kAsciiArtBanner = R""(
|
||||||
" __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n"
|
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
|
||||||
" / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
|
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
|
||||||
"| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n"
|
| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |
|
||||||
" \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n"
|
\__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
|
||||||
" __/ | | | | |\n"
|
__/ | | | | |
|
||||||
" |___/ |_| |_|";
|
|___/ |_| |_|
|
||||||
|
)"";
|
||||||
|
|
||||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
loader.Print(app.verbosity);
|
loader.Print(app.verbosity);
|
||||||
|
|
@ -60,7 +61,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
|
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
|
||||||
<< "\n"
|
<< "\n"
|
||||||
<< "Hardware concurrency : "
|
<< "Hardware concurrency : "
|
||||||
<< std::thread::hardware_concurrency() << std::endl
|
<< std::thread::hardware_concurrency() << "\n"
|
||||||
<< "Instruction set : "
|
<< "Instruction set : "
|
||||||
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
|
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
|
||||||
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
|
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
|
||||||
|
|
@ -132,13 +133,13 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
|
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||||
// +1 since position is incremented above
|
// +1 since position is incremented above
|
||||||
if (current_pos == prompt_size + 1) {
|
if (current_pos == prompt_size + 1) {
|
||||||
// first token of response
|
// first token of response
|
||||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||||
if (verbosity >= 1) {
|
if (verbosity >= 1) {
|
||||||
std::cout << std::endl << std::endl;
|
std::cout << "\n\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::cout << token_text << std::flush;
|
std::cout << token_text << std::flush;
|
||||||
|
|
@ -189,7 +190,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
|
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||||
|
|
||||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||||
// if needed.
|
// if needed.
|
||||||
|
|
@ -199,7 +200,8 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
|
|
||||||
prompt_size = prompt.size();
|
prompt_size = prompt.size();
|
||||||
|
|
||||||
std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
|
std::cerr << "\n"
|
||||||
|
<< "[ Reading prompt ] " << std::flush;
|
||||||
|
|
||||||
const double time_start = hwy::platform::Now();
|
const double time_start = hwy::platform::Now();
|
||||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||||
|
|
@ -209,10 +211,10 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
const double tok_sec = current_pos / (time_end - time_start);
|
const double tok_sec = current_pos / (time_end - time_start);
|
||||||
if (verbosity >= 2) {
|
if (verbosity >= 2) {
|
||||||
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
||||||
<< std::endl
|
<< "\n"
|
||||||
<< tok_sec << " tokens / sec" << std::endl;
|
<< tok_sec << " tokens / sec" << "\n";
|
||||||
}
|
}
|
||||||
std::cout << std::endl << std::endl;
|
std::cout << "\n\n";
|
||||||
}
|
}
|
||||||
std::cout
|
std::cout
|
||||||
<< "max_tokens (" << args.max_tokens
|
<< "max_tokens (" << args.max_tokens
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@
|
||||||
#include <cctype>
|
#include <cctype>
|
||||||
#include <cerrno> // IDE does not recognize errno.h as providing errno.
|
#include <cerrno> // IDE does not recognize errno.h as providing errno.
|
||||||
#include <string>
|
#include <string>
|
||||||
#endif
|
#endif // HWY_OS_LINUX
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
|
|
@ -170,8 +170,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
weights = compressed_weights;
|
weights = compressed_weights;
|
||||||
} else {
|
} else {
|
||||||
return "Only one of --weights and --compressed_weights can be "
|
return "Only one of --weights and --compressed_weights can be "
|
||||||
"specified. To create compressed weights use the compress_weights "
|
"specified. To create compressed weights use the "
|
||||||
"tool.";
|
"compress_weights tool.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (weights.path.empty()) {
|
if (weights.path.empty()) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue