mirror of https://github.com/google/gemma.cpp.git
Merge pull request #131 from veluca93:benchmark-and-test
PiperOrigin-RevId: 622452794
This commit is contained in:
commit
a3a0f78fda
|
|
@ -90,6 +90,7 @@ Checks: "-*,\
|
|||
-concurrency-mt-unsafe,\
|
||||
-cppcoreguidelines-avoid-c-arrays,\
|
||||
-cppcoreguidelines-avoid-const-or-ref-data-members,\
|
||||
-cppcoreguidelines-avoid-do-while,\
|
||||
-cppcoreguidelines-avoid-goto,\
|
||||
-cppcoreguidelines-avoid-magic-numbers,\
|
||||
-cppcoreguidelines-avoid-non-const-global-variables,\
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
name: build
|
||||
|
||||
# Trigger on push, pull request, or via manual dispatch.
|
||||
on: [push, pull_request, workflow_dispatch]
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
types: [opened, reopened, labeled, unlabeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
.cache/
|
||||
bazel-*/
|
||||
build-*/
|
||||
python/*/__pycache__
|
||||
38
BUILD.bazel
38
BUILD.bazel
|
|
@ -86,6 +86,25 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "gemma_test",
|
||||
srcs = ["gemma_test.cc"],
|
||||
# Requires model files
|
||||
tags = [
|
||||
"local",
|
||||
"manual",
|
||||
"no_tap",
|
||||
],
|
||||
deps = [
|
||||
":args",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
"@googletest//:gtest_main",
|
||||
"@hwy//:hwy_test_util",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "app",
|
||||
hdrs = [
|
||||
|
|
@ -132,3 +151,22 @@ cc_binary(
|
|||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "benchmark",
|
||||
srcs = [
|
||||
"benchmark.cc",
|
||||
],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":gemma_lib",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
"//third_party/json",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,9 @@ FetchContent_MakeAvailable(highway)
|
|||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
||||
FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03 EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(json)
|
||||
|
||||
set(SOURCES
|
||||
gemma.cc
|
||||
compression/blob_store.cc
|
||||
|
|
@ -60,17 +63,7 @@ if (WEIGHT_TYPE)
|
|||
add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE})
|
||||
endif()
|
||||
|
||||
# Executable Target
|
||||
|
||||
add_executable(gemma run.cc)
|
||||
target_sources(gemma PRIVATE ${SOURCES})
|
||||
set_property(TARGET gemma PROPERTY CXX_STANDARD 17)
|
||||
target_link_libraries(gemma hwy hwy_contrib sentencepiece)
|
||||
target_include_directories(gemma PRIVATE ./)
|
||||
FetchContent_GetProperties(sentencepiece)
|
||||
target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
target_compile_definitions(gemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
||||
target_compile_options(gemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
|
||||
## Library Target
|
||||
|
||||
|
|
@ -84,11 +77,21 @@ target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
|
|||
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
||||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
|
||||
# Executable Target
|
||||
|
||||
add_executable(gemma run.cc)
|
||||
target_link_libraries(gemma libgemma hwy hwy_contrib)
|
||||
|
||||
add_executable(benchmark benchmark.cc)
|
||||
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||
|
||||
## Tests
|
||||
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
|
||||
if (GEMMA_ENABLE_TESTS)
|
||||
|
||||
set(GEMMA_TEST_FILES
|
||||
ops_test.cc
|
||||
gemma_test.cc
|
||||
)
|
||||
|
||||
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,13 @@ http_archive(
|
|||
strip_prefix = "highway-1.1.0",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "nlohmann_json",
|
||||
urls = ["https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip"],
|
||||
integrity = "sha256-BAIrBdgG61/3MCPCgLaGl9Erk+G3JnoLIqGjnsdXgGk=",
|
||||
strip_prefix = "json-3.11.3",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "com_google_sentencepiece",
|
||||
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,293 @@
|
|||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/app.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/per_target.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> {
|
||||
public:
|
||||
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
gcpp::Path goldens;
|
||||
gcpp::Path summarize_text;
|
||||
gcpp::Path cross_entropy;
|
||||
gcpp::Path trivia_qa;
|
||||
size_t max_questions;
|
||||
size_t batch_tokens;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(goldens.path, "goldens_dir", std::string(""),
|
||||
"Directory containing golden files", 2);
|
||||
visitor(summarize_text.path, "summarize_text", std::string(""),
|
||||
"Path to text file to summarize", 2);
|
||||
visitor(cross_entropy.path, "cross_entropy", std::string(""),
|
||||
"Path to text file to compute cross entropy on", 2);
|
||||
visitor(trivia_qa.path, "trivia_qa", std::string(""),
|
||||
"Path to json file containing TriviaQA entries", 2);
|
||||
visitor(max_questions, "max_questions", (size_t)20,
|
||||
"Maximum number of questions to ask from --trivial_qa input", 2);
|
||||
visitor(batch_tokens, "batch_tokens", (size_t)0,
|
||||
"If not zero, break prompt into batches of this size and compute "
|
||||
"cross entropy on them independently.",
|
||||
2);
|
||||
}
|
||||
};
|
||||
|
||||
void LogSpeedStats(const double time_start, size_t total_tokens) {
|
||||
const double time_end = hwy::platform::Now();
|
||||
const double time_elapsed = time_end - time_start;
|
||||
const double tok_sec = total_tokens / time_elapsed;
|
||||
std::cout << total_tokens << " tokens in " << time_elapsed << " seconds"
|
||||
<< " [" << tok_sec << " tokens / sec" << "]\n";
|
||||
}
|
||||
|
||||
std::pair<std::string, int> QueryModel(
|
||||
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const std::string& input) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
prompt.insert(prompt.begin(), 2);
|
||||
std::string res;
|
||||
size_t total_tokens = 0;
|
||||
auto accept_token = [](int) { return true; };
|
||||
std::mt19937 gen;
|
||||
gen.seed(42);
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
auto stream_token = [&res, &total_tokens, &time_start, &app,
|
||||
tokenizer = model.Tokenizer()](int token, float) {
|
||||
++total_tokens;
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||
res += token_text;
|
||||
if (app.verbosity >= 1 && total_tokens % 100 == 0) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (app.verbosity >= 2) {
|
||||
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
|
||||
<< args.temperature;
|
||||
}
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||
inner_pool, stream_token, accept_token, gen, app.verbosity);
|
||||
if (app.verbosity >= 1) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
return {res, total_tokens};
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> load_goldens(
|
||||
const std::string& path) {
|
||||
std::ifstream goldens_file(path);
|
||||
if (!goldens_file) {
|
||||
std::cout << "Could not load goldens file: " << path << "\n" << std::flush;
|
||||
return {};
|
||||
}
|
||||
std::vector<std::pair<std::string, std::string>> res;
|
||||
std::string query_separator;
|
||||
std::string query;
|
||||
std::string answer_separator;
|
||||
std::string answer;
|
||||
while (std::getline(goldens_file, query_separator) &&
|
||||
std::getline(goldens_file, query) &&
|
||||
std::getline(goldens_file, answer_separator) &&
|
||||
std::getline(goldens_file, answer)) {
|
||||
res.push_back({query, answer});
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string ReadFile(const gcpp::Path& path) {
|
||||
std::ifstream text_file(path.path);
|
||||
if (!text_file) {
|
||||
std::cout << "Could not open file: " << path.path << "\n" << std::flush;
|
||||
return {};
|
||||
}
|
||||
std::stringstream buffer;
|
||||
buffer << text_file.rdbuf();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const std::string& golden_path) {
|
||||
const std::vector<std::pair<std::string, std::string>> queries_answers =
|
||||
load_goldens(golden_path);
|
||||
int correct_answers = 0;
|
||||
int total_tokens = 0;
|
||||
const double time_start = hwy::platform::Now();
|
||||
for (const auto& [question, expected_answer] : queries_answers) {
|
||||
const auto [answer, token_count] =
|
||||
QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
|
||||
total_tokens += token_count;
|
||||
if (answer.find(expected_answer) != std::string::npos) {
|
||||
correct_answers++;
|
||||
} else {
|
||||
std::cout << "Wrong!\n";
|
||||
std::cout << "Input: " << question << "\n";
|
||||
std::cout << "Expected: " << expected_answer << "\n";
|
||||
std::cout << "Output: " << answer << "\n\n" << std::flush;
|
||||
}
|
||||
}
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
|
||||
std::cout << "Correct: " << correct_answers << " out of "
|
||||
<< queries_answers.size() << "\n"
|
||||
<< std::flush;
|
||||
if (correct_answers != queries_answers.size()) {
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const gcpp::Path& text) {
|
||||
std::string prompt("Here is some text to summarize:\n");
|
||||
prompt.append(ReadFile(text));
|
||||
prompt.append("\nSummarize this text.\n");
|
||||
const double time_start = hwy::platform::Now();
|
||||
const auto [answer, token_count] =
|
||||
QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
|
||||
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
|
||||
LogSpeedStats(time_start, token_count);
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
||||
gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const gcpp::Path& text, size_t batch_tokens) {
|
||||
std::string input = ReadFile(text);
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||
prompt.resize(std::min<size_t>(args.max_tokens, prompt.size()));
|
||||
std::cout << "Number of input tokens: " << prompt.size() << "\n";
|
||||
const double time_start = hwy::platform::Now();
|
||||
float total_entropy = 0.0f;
|
||||
size_t total_input_len = 0;
|
||||
if (batch_tokens == 0) batch_tokens = prompt.size();
|
||||
for (size_t pos = 0; pos < prompt.size(); pos += batch_tokens) {
|
||||
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
||||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||
prompt.begin() + pos + num_tokens);
|
||||
auto kv_cache = CreateKVCache(model_type);
|
||||
float entropy =
|
||||
ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
|
||||
inner_pool, app.verbosity);
|
||||
total_entropy += entropy;
|
||||
LogSpeedStats(time_start, pos + num_tokens);
|
||||
std::string text_slice;
|
||||
HWY_ASSERT(model.Tokenizer()->Decode(prompt_slice, &text_slice));
|
||||
total_input_len += text_slice.size();
|
||||
printf("Cross entropy per byte: %f [cumulative: %f]\n",
|
||||
entropy / text_slice.size(), total_entropy / total_input_len);
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const gcpp::Path& json_file, size_t max_questions) {
|
||||
std::ifstream trivia_file(json_file.path);
|
||||
if (!trivia_file) {
|
||||
std::cout << "Could not load file: " << json_file.path << "\n"
|
||||
<< std::flush;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
std::string line;
|
||||
size_t correct_answers = 0;
|
||||
size_t i = 0;
|
||||
while (std::getline(trivia_file, line)) {
|
||||
json data = json::parse(line);
|
||||
const auto [answer, token_count] = QueryModel(
|
||||
model, args, app, kv_cache, inner_pool, pool, data["question"]);
|
||||
std::cout << answer << "\n";
|
||||
bool correct = false;
|
||||
for (const std::string expected : data["answer"]["aliases"]) {
|
||||
if (answer.find(expected) != std::string::npos) {
|
||||
correct = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (correct) {
|
||||
++correct_answers;
|
||||
std::cout << "CORRECT\n\n";
|
||||
} else {
|
||||
std::cout << "WRONG\n\n";
|
||||
}
|
||||
if (++i >= max_questions) break;
|
||||
}
|
||||
printf("Correct answers: %zu / %zu\n", correct_answers, i);
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
/* Run this in the same way as gemma, p.ex.:
|
||||
./benchmark --tokenizer tokenizer.spm --model 2b-it --weights \
|
||||
2b-it-sfp.sbs --goldens_dir "../goldens"
|
||||
*/
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
gcpp::InferenceArgs args(argc, argv); // inference
|
||||
gcpp::AppArgs app(argc, argv);
|
||||
BenchmarkArgs benchmark_args(argc, argv);
|
||||
|
||||
hwy::ThreadPool inner_pool(0);
|
||||
hwy::ThreadPool pool(app.num_threads);
|
||||
// For many-core, pinning threads to cores helps.
|
||||
if (app.num_threads > 10) {
|
||||
gcpp::PinThreadToCore(app.num_threads - 1); // Main thread
|
||||
|
||||
pool.Run(0, pool.NumThreads(), [](uint64_t /*task*/, size_t thread) {
|
||||
gcpp::PinThreadToCore(thread);
|
||||
});
|
||||
}
|
||||
|
||||
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
|
||||
if (!benchmark_args.goldens.path.empty()) {
|
||||
const std::string golden_path =
|
||||
benchmark_args.goldens.path + "/" + loader.model_type + ".txt";
|
||||
return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
|
||||
golden_path);
|
||||
} else if (!benchmark_args.summarize_text.path.empty()) {
|
||||
return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
|
||||
benchmark_args.summarize_text);
|
||||
} else if (!benchmark_args.cross_entropy.path.empty()) {
|
||||
return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
|
||||
inner_pool, pool, benchmark_args.cross_entropy,
|
||||
benchmark_args.batch_tokens);
|
||||
} else if (!benchmark_args.trivia_qa.path.empty()) {
|
||||
return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
|
||||
benchmark_args.trivia_qa,
|
||||
benchmark_args.max_questions);
|
||||
}
|
||||
std::cout << "No benchmark command given." << "\n" << std::flush;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
|
@ -0,0 +1,299 @@
|
|||
#!/usr/bin/env bash
|
||||
# Copyright 2024 Google LLC
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
MYDIR=$(dirname $(realpath "$0"))
|
||||
BUILD_DIR="${BUILD_DIR:-${MYDIR}/build}"
|
||||
|
||||
CMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE:-RelWithDebInfo}"
|
||||
CMAKE_C_COMPILER="${CMAKE_C_COMPILER:-clang-14}"
|
||||
CMAKE_CXX_COMPILER="${CMAKE_CXX_COMPILER:-clang++-14}"
|
||||
# Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS
|
||||
CMAKE_FLAGS="${CMAKE_FLAGS:-}"
|
||||
CMAKE_C_FLAGS="${CMAKE_C_FLAGS:-} ${CMAKE_FLAGS}"
|
||||
CMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS:-} ${CMAKE_FLAGS}"
|
||||
CMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS:-}"
|
||||
CMAKE_MODULE_LINKER_FLAGS="${CMAKE_MODULE_LINKER_FLAGS:-}"
|
||||
CMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS:-}"
|
||||
|
||||
# Local flags passed to sanitizers.
|
||||
UBSAN_FLAGS=(
|
||||
-fsanitize=alignment
|
||||
-fsanitize=bool
|
||||
-fsanitize=bounds
|
||||
-fsanitize=builtin
|
||||
-fsanitize=enum
|
||||
-fsanitize=float-cast-overflow
|
||||
-fsanitize=float-divide-by-zero
|
||||
-fsanitize=integer-divide-by-zero
|
||||
-fsanitize=null
|
||||
-fsanitize=object-size
|
||||
-fsanitize=pointer-overflow
|
||||
-fsanitize=return
|
||||
-fsanitize=returns-nonnull-attribute
|
||||
-fsanitize=shift-base
|
||||
-fsanitize=shift-exponent
|
||||
-fsanitize=unreachable
|
||||
-fsanitize=vla-bound
|
||||
|
||||
-fno-sanitize-recover=undefined
|
||||
-fsanitize-recover=alignment
|
||||
)
|
||||
|
||||
CLANG_VERSION="${CLANG_VERSION:-}"
|
||||
# Detect the clang version suffix and store it in CLANG_VERSION. For example,
|
||||
# "6.0" for clang 6 or "7" for clang 7.
|
||||
detect_clang_version() {
|
||||
if [[ -n "${CLANG_VERSION}" ]]; then
|
||||
return 0
|
||||
fi
|
||||
local clang_version=$("${CMAKE_C_COMPILER:-clang}" --version | head -n1)
|
||||
clang_version=${clang_version#"Debian "}
|
||||
clang_version=${clang_version#"Ubuntu "}
|
||||
local llvm_tag
|
||||
case "${clang_version}" in
|
||||
"clang version 6."*)
|
||||
CLANG_VERSION="6.0"
|
||||
;;
|
||||
"clang version "*)
|
||||
# Any other clang version uses just the major version number.
|
||||
local suffix="${clang_version#clang version }"
|
||||
CLANG_VERSION="${suffix%%.*}"
|
||||
;;
|
||||
"emcc"*)
|
||||
# We can't use asan or msan in the emcc case.
|
||||
;;
|
||||
*)
|
||||
echo "Unknown clang version: ${clang_version}" >&2
|
||||
return 1
|
||||
esac
|
||||
}
|
||||
|
||||
# Temporary files cleanup hooks.
|
||||
CLEANUP_FILES=()
|
||||
cleanup() {
|
||||
if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then
|
||||
rm -fr "${CLEANUP_FILES[@]}"
|
||||
fi
|
||||
}
|
||||
|
||||
# Executed on exit.
|
||||
on_exit() {
|
||||
local retcode="$1"
|
||||
# Always cleanup the CLEANUP_FILES.
|
||||
cleanup
|
||||
}
|
||||
|
||||
trap 'retcode=$?; { set +x; } 2>/dev/null; on_exit ${retcode}' INT TERM EXIT
|
||||
|
||||
|
||||
# Install libc++ libraries compiled with msan in the msan_prefix for the current
|
||||
# compiler version.
|
||||
cmd_msan_install() {
|
||||
local tmpdir=$(mktemp -d)
|
||||
CLEANUP_FILES+=("${tmpdir}")
|
||||
# Detect the llvm to install:
|
||||
detect_clang_version
|
||||
# Allow overriding the LLVM checkout.
|
||||
local llvm_root="${LLVM_ROOT:-}"
|
||||
if [ -z "${llvm_root}" ]; then
|
||||
local llvm_tag="llvmorg-${CLANG_VERSION}.0.0"
|
||||
case "${CLANG_VERSION}" in
|
||||
"6.0")
|
||||
llvm_tag="llvmorg-6.0.1"
|
||||
;;
|
||||
"7")
|
||||
llvm_tag="llvmorg-7.0.1"
|
||||
;;
|
||||
esac
|
||||
local llvm_targz="${tmpdir}/${llvm_tag}.tar.gz"
|
||||
curl -L --show-error -o "${llvm_targz}" \
|
||||
"https://github.com/llvm/llvm-project/archive/${llvm_tag}.tar.gz"
|
||||
tar -C "${tmpdir}" -zxf "${llvm_targz}"
|
||||
llvm_root="${tmpdir}/llvm-project-${llvm_tag}"
|
||||
fi
|
||||
|
||||
local msan_prefix="${HOME}/.msan/${CLANG_VERSION}"
|
||||
rm -rf "${msan_prefix}"
|
||||
|
||||
declare -A CMAKE_EXTRAS
|
||||
CMAKE_EXTRAS[libcxx]="\
|
||||
-DLIBCXX_CXX_ABI=libstdc++ \
|
||||
-DLIBCXX_INSTALL_EXPERIMENTAL_LIBRARY=ON \
|
||||
-DLIBCXX_INCLUDE_BENCHMARKS=OFF"
|
||||
|
||||
for project in libcxx; do
|
||||
local proj_build="${tmpdir}/build-${project}"
|
||||
local proj_dir="${llvm_root}/${project}"
|
||||
mkdir -p "${proj_build}"
|
||||
cmake -B"${proj_build}" -H"${proj_dir}" \
|
||||
-G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DLLVM_USE_SANITIZER=Memory \
|
||||
-DLLVM_PATH="${llvm_root}/llvm" \
|
||||
-DLLVM_CONFIG_PATH="$(which llvm-config llvm-config-7 llvm-config-6.0 | \
|
||||
head -n1)" \
|
||||
-DCMAKE_C_COMPILER="${CMAKE_C_COMPILER}" \
|
||||
-DCMAKE_CXX_COMPILER="${CMAKE_CXX_COMPILER}" \
|
||||
-DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}" \
|
||||
-DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS}" \
|
||||
-DCMAKE_INSTALL_PREFIX="${msan_prefix}" \
|
||||
${CMAKE_EXTRAS[${project}]}
|
||||
cmake --build "${proj_build}"
|
||||
ninja -C "${proj_build}" install
|
||||
done
|
||||
}
|
||||
|
||||
cmd_msan() {
|
||||
detect_clang_version
|
||||
local msan_prefix="${HOME}/.msan/${CLANG_VERSION}"
|
||||
if [[ ! -d "${msan_prefix}" || -e "${msan_prefix}/lib/libc++abi.a" ]]; then
|
||||
# Install msan libraries for this version if needed or if an older version
|
||||
# with libc++abi was installed.
|
||||
cmd_msan_install
|
||||
fi
|
||||
|
||||
local msan_c_flags=(
|
||||
-fsanitize=memory
|
||||
-fno-omit-frame-pointer
|
||||
|
||||
-g
|
||||
-DMEMORY_SANITIZER
|
||||
|
||||
# Force gtest to not use the cxxbai.
|
||||
-DGTEST_HAS_CXXABI_H_=0
|
||||
|
||||
-fsanitize-memory-track-origins
|
||||
)
|
||||
|
||||
local msan_cxx_flags=(
|
||||
"${msan_c_flags[@]}"
|
||||
|
||||
# Some C++ sources don't use the std at all, so the -stdlib=libc++ is unused
|
||||
# in those cases. Ignore the warning.
|
||||
-Wno-unused-command-line-argument
|
||||
-stdlib=libc++
|
||||
|
||||
# We include the libc++ from the msan directory instead, so we don't want
|
||||
# the std includes.
|
||||
-nostdinc++
|
||||
-cxx-isystem"${msan_prefix}/include/c++/v1"
|
||||
)
|
||||
|
||||
local msan_linker_flags=(
|
||||
-L"${msan_prefix}"/lib
|
||||
-Wl,-rpath -Wl,"${msan_prefix}"/lib/
|
||||
)
|
||||
|
||||
CMAKE_C_FLAGS+=" ${msan_c_flags[@]} ${UBSAN_FLAGS[@]}"
|
||||
CMAKE_CXX_FLAGS+=" ${msan_cxx_flags[@]} ${UBSAN_FLAGS[@]}"
|
||||
CMAKE_EXE_LINKER_FLAGS+=" ${msan_linker_flags[@]}"
|
||||
CMAKE_MODULE_LINKER_FLAGS+=" ${msan_linker_flags[@]}"
|
||||
CMAKE_SHARED_LINKER_FLAGS+=" ${msan_linker_flags[@]}"
|
||||
cmake_configure "$@" \
|
||||
-DCMAKE_CROSSCOMPILING=1 -DRUN_HAVE_STD_REGEX=0 -DRUN_HAVE_POSIX_REGEX=0 \
|
||||
-DCMAKE_REQUIRED_LINK_OPTIONS="${msan_linker_flags[@]}"
|
||||
}
|
||||
|
||||
cmake_configure() {
|
||||
local args=(
|
||||
-B"${BUILD_DIR}" -H"${MYDIR}"
|
||||
-DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}"
|
||||
-G Ninja
|
||||
-DCMAKE_C_COMPILER="${CMAKE_C_COMPILER}"
|
||||
-DCMAKE_CXX_COMPILER="${CMAKE_CXX_COMPILER}"
|
||||
-DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}"
|
||||
-DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}"
|
||||
-DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}"
|
||||
-DCMAKE_MODULE_LINKER_FLAGS="${CMAKE_MODULE_LINKER_FLAGS}"
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS}"
|
||||
-DGEMMA_ENABLE_TESTS=ON
|
||||
)
|
||||
|
||||
cmake "${args[@]}" "$@"
|
||||
}
|
||||
|
||||
cmd_opt() {
|
||||
CMAKE_BUILD_TYPE="RelWithDebInfo"
|
||||
cmake_configure "$@"
|
||||
}
|
||||
|
||||
cmd_asan() {
|
||||
CMAKE_C_FLAGS+=" -g -DADDRESS_SANITIZER -fsanitize=address ${UBSAN_FLAGS[@]}"
|
||||
CMAKE_CXX_FLAGS+=" -g -DADDRESS_SANITIZER -fsanitize=address \
|
||||
${UBSAN_FLAGS[@]}"
|
||||
cmake_configure "$@"
|
||||
}
|
||||
|
||||
cmd_tsan() {
|
||||
SANITIZER="tsan"
|
||||
local tsan_args=(
|
||||
-g
|
||||
-DTHREAD_SANITIZER
|
||||
${UBSAN_FLAGS[@]}
|
||||
-fsanitize=thread
|
||||
)
|
||||
CMAKE_C_FLAGS+=" ${tsan_args[@]}"
|
||||
CMAKE_CXX_FLAGS+=" ${tsan_args[@]}"
|
||||
|
||||
CMAKE_BUILD_TYPE="RelWithDebInfo"
|
||||
cmake_configure "$@"
|
||||
}
|
||||
|
||||
main() {
|
||||
local cmd="${1:-}"
|
||||
if [[ -z "${cmd}" ]]; then
|
||||
cat >&2 <<EOF
|
||||
Use: $0 CMD
|
||||
|
||||
Where cmd is one of:
|
||||
opt Build and test a Release with symbols build.
|
||||
asan Build and test an ASan (AddressSanitizer) build.
|
||||
msan Build and test an MSan (MemorySanitizer) build. Needs to have msan
|
||||
c++ libs installed with msan_install first.
|
||||
msan_install Install the libc++ libraries required to build in msan mode. This
|
||||
needs to be done once.
|
||||
tsan Build and test a TSan (ThreadSanitizer) build.
|
||||
|
||||
You can pass some optional environment variables as well:
|
||||
- BUILD_DIR: The output build directory (by default "$$repo/build")
|
||||
- CMAKE_FLAGS: Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS.
|
||||
|
||||
These optional environment variables are forwarded to the cmake call as
|
||||
parameters:
|
||||
- CMAKE_BUILD_TYPE
|
||||
- CMAKE_C_FLAGS
|
||||
- CMAKE_CXX_FLAGS
|
||||
- CMAKE_C_COMPILER
|
||||
- CMAKE_CXX_COMPILER
|
||||
- CMAKE_EXE_LINKER_FLAGS
|
||||
- CMAKE_MODULE_LINKER_FLAGS
|
||||
- CMAKE_SHARED_LINKER_FLAGS
|
||||
|
||||
Example:
|
||||
BUILD_DIR=/tmp/build $0 opt
|
||||
EOF
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cmd="cmd_${cmd}"
|
||||
shift
|
||||
set -x
|
||||
"${cmd}" "$@"
|
||||
}
|
||||
|
||||
main "$@"
|
||||
|
|
@ -23,13 +23,12 @@
|
|||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h"
|
||||
|
||||
std::vector<int> tokenize(
|
||||
const std::string& prompt_string,
|
||||
const sentencepiece::SentencePieceProcessor* tokenizer) {
|
||||
std::vector<int> tokenize(const std::string& prompt_string,
|
||||
const gcpp::GemmaTokenizer* tokenizer) {
|
||||
std::string formatted = "<start_of_turn>user\n" + prompt_string +
|
||||
"<end_of_turn>\n<start_of_turn>model\n";
|
||||
std::vector<int> tokens;
|
||||
HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok());
|
||||
HWY_ASSERT(tokenizer->Encode(formatted, &tokens));
|
||||
tokens.insert(tokens.begin(), 2); // BOS token
|
||||
return tokens;
|
||||
}
|
||||
|
|
@ -58,14 +57,14 @@ int main(int argc, char** argv) {
|
|||
size_t ntokens = tokens.size();
|
||||
|
||||
// This callback function gets invoked everytime a token is generated
|
||||
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
|
||||
int token, float) {
|
||||
auto stream_token = [&pos, &ntokens, tokenizer = model.Tokenizer()](int token,
|
||||
float) {
|
||||
++pos;
|
||||
if (pos < ntokens) {
|
||||
// print feedback
|
||||
} else if (token != gcpp::EOS_ID) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||
std::cout << token_text << std::flush;
|
||||
}
|
||||
return true;
|
||||
|
|
@ -78,5 +77,5 @@ int main(int argc, char** argv) {
|
|||
.verbosity = 0},
|
||||
tokens, /*KV cache position = */ 0, kv_cache, pool,
|
||||
stream_token, gen);
|
||||
std::cout << std::endl;
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
|
|
|||
172
gemma.cc
172
gemma.cc
|
|
@ -25,12 +25,15 @@
|
|||
#include "compression/compress-inl.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "ops.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // Path
|
||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // Path
|
||||
// copybara:import_next_line:sentencepiece
|
||||
#include "src/sentencepiece_processor.h"
|
||||
// copybara:end
|
||||
|
||||
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
||||
// compile pass, whereas we want this defined in the first.
|
||||
|
|
@ -49,6 +52,7 @@
|
|||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -330,7 +334,7 @@ struct Activations {
|
|||
struct GemmaInterface {
|
||||
virtual ~GemmaInterface() = default;
|
||||
|
||||
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
||||
virtual const GemmaTokenizer* Tokenizer() const = 0;
|
||||
|
||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
|
|
@ -339,6 +343,12 @@ struct GemmaInterface {
|
|||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) = 0;
|
||||
|
||||
virtual float ComputeCrossEntropy(size_t max_tokens,
|
||||
const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
int verbosity) = 0;
|
||||
};
|
||||
|
||||
template <class Config>
|
||||
|
|
@ -358,6 +368,29 @@ KVCache CreateKVCache(Model type) {
|
|||
}
|
||||
}
|
||||
|
||||
class GemmaTokenizerImpl : public GemmaTokenizer {
|
||||
public:
|
||||
GemmaTokenizerImpl(
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor>&& impl)
|
||||
: impl_(std::move(impl)) {}
|
||||
bool Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const override {
|
||||
return impl_->Encode(input, pieces).ok();
|
||||
}
|
||||
bool Encode(const std::string& input,
|
||||
std::vector<int>* pieces) const override {
|
||||
return impl_->Encode(input, pieces).ok();
|
||||
}
|
||||
// Given a sequence of ids, decodes it into a detokenized output.
|
||||
bool Decode(const std::vector<int>& ids,
|
||||
std::string* detokenized) const override {
|
||||
return impl_->Decode(ids, detokenized).ok();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> impl_;
|
||||
};
|
||||
|
||||
namespace {
|
||||
template <class Config>
|
||||
void DeleteLayersPtrs(CompressedWeights<Config>* c_weights) {
|
||||
|
|
@ -381,9 +414,7 @@ struct GemmaImpl : public GemmaInterface {
|
|||
DeleteLayersPtrs(weights);
|
||||
}
|
||||
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
|
||||
return tokenizer.get();
|
||||
}
|
||||
const GemmaTokenizer* Tokenizer() const override { return &tokenizer; }
|
||||
|
||||
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
|
|
@ -392,7 +423,12 @@ struct GemmaImpl : public GemmaInterface {
|
|||
const AcceptFunc& accept_token, std::mt19937&,
|
||||
int verbosity) override;
|
||||
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
int verbosity) override;
|
||||
|
||||
GemmaTokenizerImpl tokenizer;
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
|
||||
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
||||
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
||||
|
|
@ -804,6 +840,77 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
}
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
std::string TokenString(GemmaImpl<TConfig>& gemma, int token) {
|
||||
std::string token_str;
|
||||
gemma.Tokenizer()->Decode({token}, &token_str);
|
||||
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
|
||||
}
|
||||
|
||||
#define TOKEN(token_id) TokenString(gemma, token_id).c_str()
|
||||
|
||||
template <class TConfig>
|
||||
void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
|
||||
size_t k) {
|
||||
std::vector<std::pair<float, int>> sorted(len);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
sorted[i] = std::make_pair(dist[i], i);
|
||||
}
|
||||
std::sort(sorted.begin(), sorted.end(),
|
||||
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
for (int i = 0; i < k; ++i) {
|
||||
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", i + 1, sorted[i].second,
|
||||
TOKEN(sorted[i].second), sorted[i].first, logits[sorted[i].second]);
|
||||
}
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, int verbosity) {
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||
const WeightsT<TConfig>& weights =
|
||||
*reinterpret_cast<const WeightsT<TConfig>*>(gemma.weights_u8.get());
|
||||
std::vector<float> logits(kVocabSize);
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
float total_entropy = 0.0f;
|
||||
for (size_t pos = 0; pos < max_tokens && pos < prompt.size(); ++pos) {
|
||||
if (verbosity >= 4) {
|
||||
LogTopK(gemma, logits.data(), activations.logits.data(), kVocabSize, 10);
|
||||
}
|
||||
const int token = prompt[pos];
|
||||
const float prob = activations.logits[token];
|
||||
if (verbosity >= 3) {
|
||||
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
|
||||
TOKEN(token), prob, -std::log(prob) / std::log(2.0));
|
||||
}
|
||||
total_entropy -= std::max(std::log(prob), -64.0f);
|
||||
if (verbosity >= 2 && pos % 100 == 99) {
|
||||
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
||||
total_entropy / std::log(2.0) / (pos + 1));
|
||||
}
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
|
||||
MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
|
||||
activations.x.data(),
|
||||
activations.logits.data(), pool);
|
||||
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
||||
memcpy(logits.data(), activations.logits.data(),
|
||||
kVocabSize * sizeof(logits[0]));
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
}
|
||||
return total_entropy / std::log(2.0);
|
||||
}
|
||||
|
||||
#undef TOKEN
|
||||
|
||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
|
|
@ -828,6 +935,22 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
|||
accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
int verbosity) {
|
||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||
inner_pool, verbosity);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
int verbosity) {
|
||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||
inner_pool, verbosity);
|
||||
}
|
||||
|
||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||
// if weights = null, which happens during the first call where we attempt to
|
||||
// load from cache.
|
||||
|
|
@ -983,6 +1106,8 @@ HWY_EXPORT(LoadWeightsT);
|
|||
HWY_EXPORT(CompressWeightsT);
|
||||
HWY_EXPORT(Generate2B);
|
||||
HWY_EXPORT(Generate7B);
|
||||
HWY_EXPORT(ComputeCrossEntropy2B);
|
||||
HWY_EXPORT(ComputeCrossEntropy7B);
|
||||
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
|
||||
KVCache kv_cache = {};
|
||||
|
|
@ -995,7 +1120,7 @@ template <class Config>
|
|||
GemmaImpl<Config>::GemmaImpl(
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]>& weights_u8, hwy::ThreadPool& pool)
|
||||
: tokenizer(std::move(tokenizer)),
|
||||
: tokenizer(GemmaTokenizerImpl(std::move(tokenizer))),
|
||||
weights_u8(std::move(weights_u8)),
|
||||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
||||
|
|
@ -1023,6 +1148,22 @@ void GemmaImpl<ConfigGemma7B>::Generate(
|
|||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
template <>
|
||||
float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
|
||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
|
||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
|
||||
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||
}
|
||||
|
||||
template <>
|
||||
float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
|
||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
|
||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
|
||||
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||
}
|
||||
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||
|
|
@ -1056,9 +1197,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
|||
|
||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||
|
||||
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
||||
return impl_->Tokenizer();
|
||||
}
|
||||
const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
|
||||
|
||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
|
|
@ -1090,5 +1229,16 @@ void CompressWeights(gcpp::Model model, const Path& weights,
|
|||
(model, weights, compressed_weights, pool);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
int verbosity) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
const float result = gemma.impl_->ComputeCrossEntropy(
|
||||
max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
17
gemma.h
17
gemma.h
|
|
@ -60,11 +60,21 @@ struct RuntimeConfig {
|
|||
|
||||
struct GemmaInterface;
|
||||
|
||||
class GemmaTokenizer {
|
||||
public:
|
||||
virtual bool Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const = 0;
|
||||
virtual bool Encode(const std::string& input,
|
||||
std::vector<int>* pieces) const = 0;
|
||||
virtual bool Decode(const std::vector<int>& ids,
|
||||
std::string* detokenized) const = 0;
|
||||
};
|
||||
|
||||
struct Gemma {
|
||||
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||
hwy::ThreadPool& pool);
|
||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||
const GemmaTokenizer* Tokenizer() const;
|
||||
std::unique_ptr<GemmaInterface> impl_;
|
||||
};
|
||||
|
||||
|
|
@ -95,6 +105,11 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
|||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
const Path& compressed_weights, hwy::ThreadPool& pool);
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
int verbosity);
|
||||
|
||||
constexpr int EOS_ID = 1;
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -0,0 +1,304 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "gemma.h"
|
||||
|
||||
#include <thread>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "ops.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
class GemmaTest : public ::testing::Test {
|
||||
protected:
|
||||
GemmaTest()
|
||||
: weights("./2b-it-mqa.sbs"),
|
||||
tokenizer("./tokenizer.spm"),
|
||||
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
|
||||
inner_pool(0),
|
||||
model_type(gcpp::Model::GEMMA_2B),
|
||||
model(tokenizer, weights, model_type, pool) {
|
||||
kv_cache = CreateKVCache(model_type);
|
||||
}
|
||||
|
||||
std::string GemmaReply(const std::string& prompt_string) {
|
||||
std::mt19937 gen;
|
||||
gen.seed(42);
|
||||
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
prompt.insert(prompt.begin(), 2);
|
||||
|
||||
std::vector<int> response;
|
||||
auto stream_token = [&response](int token, float) {
|
||||
response.push_back(token);
|
||||
return true;
|
||||
};
|
||||
gcpp::GenerateGemma(
|
||||
model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
|
||||
/*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||
inner_pool, stream_token,
|
||||
/*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
|
||||
std::string response_text;
|
||||
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
|
||||
return response_text;
|
||||
}
|
||||
|
||||
float GemmaCrossEntropy(const std::string& prompt_string) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||
return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
|
||||
kv_cache, pool, inner_pool,
|
||||
/*verbosity=*/0) /
|
||||
prompt_string.size();
|
||||
}
|
||||
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions) {
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
std::cout << "Question " << i + 1 << "\n\n";
|
||||
std::string response = GemmaReply(kQA[i][0]);
|
||||
std::cout << response << "\n\n";
|
||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);
|
||||
}
|
||||
}
|
||||
|
||||
gcpp::Path weights;
|
||||
gcpp::Path tokenizer;
|
||||
gcpp::KVCache kv_cache;
|
||||
hwy::ThreadPool pool;
|
||||
hwy::ThreadPool inner_pool;
|
||||
gcpp::Model model_type = {};
|
||||
gcpp::Gemma model;
|
||||
};
|
||||
|
||||
TEST_F(GemmaTest, Geography) {
|
||||
static const char* kQA[][2] = {
|
||||
{"What is the capital of Hungary?", "Budapest"},
|
||||
{"How many states does the US have?", "50"},
|
||||
{"list me ten biggest cities in the world", "Tokyo"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, History) {
|
||||
static const char* kQA[][2] = {
|
||||
{"When was the Battle of Hastings?", "1066"},
|
||||
{"Who fought at the Battle of Marathon?", "Greek"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, Arithmetic) {
|
||||
static const char* kQA[][2] = {
|
||||
{"what is 13 + 14?", "27"},
|
||||
{"what is 7 * 8", "56"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum);
|
||||
}
|
||||
|
||||
static const char kJingleBells[] = R"(
|
||||
Dashing through the snow
|
||||
In a one-horse open sleigh
|
||||
O'er the fields we go
|
||||
Laughing all the way
|
||||
Bells on bobtails ring
|
||||
Making spirits bright
|
||||
What fun it is to ride and sing
|
||||
A sleighing song tonight
|
||||
)";
|
||||
|
||||
// The "Hay Draft" of the Gettysburg Address.
|
||||
static const char kGettysburg[] = {
|
||||
"Four score and seven years ago our fathers brought forth, upon this "
|
||||
"continent, a new nation, conceived in Liberty, and dedicated to the "
|
||||
"proposition that all men are created equal.\n\nNow we are engaged in a "
|
||||
"great civil war, testing whether that nation, or any nation, so "
|
||||
"conceived, and so dedicated, can long endure. We are met here on a great "
|
||||
"battlefield of that war. We have come to dedicate a portion of it as a "
|
||||
"final resting place for those who here gave their lives that that nation "
|
||||
"might live. It is altogether fitting and proper that we should do "
|
||||
"this.\n\nBut in a larger sense we can not dedicate -- we can not "
|
||||
"consecrate -- we can not hallow this ground. The brave men, living and "
|
||||
"dead, who struggled, here, have consecrated it far above our poor power "
|
||||
"to add or detract. The world will little note, nor long remember, what we "
|
||||
"say here, but can never forget what they did here. It is for us, the "
|
||||
"living, rather to be dedicated here to the unfinished work which they "
|
||||
"have, thus far, so nobly carried on. It is rather for us to be here "
|
||||
"dedicated to the great task remaining before us -- that from these "
|
||||
"honored dead we take increased devotion to that cause for which they here "
|
||||
"gave the last full measure of devotion -- that we here highly resolve "
|
||||
"that these dead shall not have died in vain; that this nation shall have "
|
||||
"a new birth of freedom; and that this government of the people, by the "
|
||||
"people, for the people, shall not perish from the earth.\n"};
|
||||
|
||||
// The Declaration of Independence.
|
||||
static const char kDeclaration[] = {
|
||||
"IN CONGRESS, July 4, 1776.\n\nThe unanimous Declaration of the thirteen "
|
||||
"united States of America,\n\nWhen in the Course of human events, it "
|
||||
"becomes necessary for one people to dissolve the political bands which "
|
||||
"have connected them with another, and to assume among the powers of the "
|
||||
"earth, the separate and equal station to which the Laws of Nature and of "
|
||||
"Nature's God entitle them, a decent respect to the opinions of mankind "
|
||||
"requires that they should declare the causes which impel them to the "
|
||||
"separation.\n\nWe hold these truths to be self-evident, that all men are "
|
||||
"created equal, that they are endowed by their Creator with certain "
|
||||
"unalienable Rights, that among these are Life, Liberty and the pursuit of "
|
||||
"Happiness.--That to secure these rights, Governments are instituted among "
|
||||
"Men, deriving their just powers from the consent of the governed, --That "
|
||||
"whenever any Form of Government becomes destructive of these ends, it is "
|
||||
"the Right of the People to alter or to abolish it, and to institute new "
|
||||
"Government, laying its foundation on such principles and organizing its "
|
||||
"powers in such form, as to them shall seem most likely to effect their "
|
||||
"Safety and Happiness. Prudence, indeed, will dictate that Governments "
|
||||
"long established should not be changed for light and transient causes; "
|
||||
"and accordingly all experience hath shewn, that mankind are more disposed "
|
||||
"to suffer, while evils are sufferable, than to right themselves by "
|
||||
"abolishing the forms to which they are accustomed. But when a long train "
|
||||
"of abuses and usurpations, pursuing invariably the same Object evinces a "
|
||||
"design to reduce them under absolute Despotism, it is their right, it is "
|
||||
"their duty, to throw off such Government, and to provide new Guards for "
|
||||
"their future security.--Such has been the patient sufferance of these "
|
||||
"Colonies; and such is now the necessity which constrains them to alter "
|
||||
"their former Systems of Government. The history of the present King of "
|
||||
"Great Britain is a history of repeated injuries and usurpations, all "
|
||||
"having in direct object the establishment of an absolute Tyranny over "
|
||||
"these States. To prove this, let Facts be submitted to a candid "
|
||||
"world.\n\nHe has refused his Assent to Laws, the most wholesome and "
|
||||
"necessary for the public good.\nHe has forbidden his Governors to pass "
|
||||
"Laws of immediate and pressing importance, unless suspended in their "
|
||||
"operation till his Assent should be obtained; and when so suspended, he "
|
||||
"has utterly neglected to attend to them.\nHe has refused to pass other "
|
||||
"Laws for the accommodation of large districts of people, unless those "
|
||||
"people would relinquish the right of Representation in the Legislature, a "
|
||||
"right inestimable to them and formidable to tyrants only.\nHe has called "
|
||||
"together legislative bodies at places unusual, uncomfortable, and distant "
|
||||
"from the depository of their public Records, for the sole purpose of "
|
||||
"fatiguing them into compliance with his measures.\nHe has dissolved "
|
||||
"Representative Houses repeatedly, for opposing with manly firmness his "
|
||||
"invasions on the rights of the people.\nHe has refused for a long time, "
|
||||
"after such dissolutions, to cause others to be elected; whereby the "
|
||||
"Legislative powers, incapable of Annihilation, have returned to the "
|
||||
"People at large for their exercise; the State remaining in the mean time "
|
||||
"exposed to all the dangers of invasion from without, and convulsions "
|
||||
"within.\nHe has endeavoured to prevent the population of these States; "
|
||||
"for that purpose obstructing the Laws for Naturalization of Foreigners; "
|
||||
"refusing to pass others to encourage their migrations hither, and raising "
|
||||
"the conditions of new Appropriations of Lands.\nHe has obstructed the "
|
||||
"Administration of Justice, by refusing his Assent to Laws for "
|
||||
"establishing Judiciary powers.\nHe has made Judges dependent on his Will "
|
||||
"alone, for the tenure of their offices, and the amount and payment of "
|
||||
"their salaries.\nHe has erected a multitude of New Offices, and sent "
|
||||
"hither swarms of Officers to harrass our people, and eat out their "
|
||||
"substance.\nHe has kept among us, in times of peace, Standing Armies "
|
||||
"without the Consent of our legislatures.\nHe has affected to render the "
|
||||
"Military independent of and superior to the Civil power.\nHe has combined "
|
||||
"with others to subject us to a jurisdiction foreign to our constitution, "
|
||||
"and unacknowledged by our laws; giving his Assent to their Acts of "
|
||||
"pretended Legislation:\nFor Quartering large bodies of armed troops among "
|
||||
"us:\nFor protecting them, by a mock Trial, from punishment for any "
|
||||
"Murders which they should commit on the Inhabitants of these States:\nFor "
|
||||
"cutting off our Trade with all parts of the world:\nFor imposing Taxes on "
|
||||
"us without our Consent:\nFor depriving us in many cases, of the benefits "
|
||||
"of Trial by Jury:\nFor transporting us beyond Seas to be tried for "
|
||||
"pretended offences\nFor abolishing the free System of English Laws in a "
|
||||
"neighbouring Province, establishing therein an Arbitrary government, and "
|
||||
"enlarging its Boundaries so as to render it at once an example and fit "
|
||||
"instrument for introducing the same absolute rule into these "
|
||||
"Colonies:\nFor taking away our Charters, abolishing our most valuable "
|
||||
"Laws, and altering fundamentally the Forms of our Governments:\nFor "
|
||||
"suspending our own Legislatures, and declaring themselves invested with "
|
||||
"power to legislate for us in all cases whatsoever.\nHe has abdicated "
|
||||
"Government here, by declaring us out of his Protection and waging War "
|
||||
"against us.\nHe has plundered our seas, ravaged our Coasts, burnt our "
|
||||
"towns, and destroyed the lives of our people.\nHe is at this time "
|
||||
"transporting large Armies of foreign Mercenaries to compleat the works of "
|
||||
"death, desolation and tyranny, already begun with circumstances of "
|
||||
"Cruelty & perfidy scarcely paralleled in the most barbarous ages, and "
|
||||
"totally unworthy the Head of a civilized nation.\nHe has constrained our "
|
||||
"fellow Citizens taken Captive on the high Seas to bear Arms against their "
|
||||
"Country, to become the executioners of their friends and Brethren, or to "
|
||||
"fall themselves by their Hands.\nHe has excited domestic insurrections "
|
||||
"amongst us, and has endeavoured to bring on the inhabitants of our "
|
||||
"frontiers, the merciless Indian Savages, whose known rule of warfare, is "
|
||||
"an undistinguished destruction of all ages, sexes and conditions.\n\nIn "
|
||||
"every stage of these Oppressions We have Petitioned for Redress in the "
|
||||
"most humble terms: Our repeated Petitions have been answered only by "
|
||||
"repeated injury. A Prince whose character is thus marked by every act "
|
||||
"which may define a Tyrant, is unfit to be the ruler of a free "
|
||||
"people.\n\nNor have We been wanting in attentions to our Brittish "
|
||||
"brethren. We have warned them from time to time of attempts by their "
|
||||
"legislature to extend an unwarrantable jurisdiction over us. We have "
|
||||
"reminded them of the circumstances of our emigration and settlement here. "
|
||||
"We have appealed to their native justice and magnanimity, and we have "
|
||||
"conjured them by the ties of our common kindred to disavow these "
|
||||
"usurpations, which, would inevitably interrupt our connections and "
|
||||
"correspondence. They too have been deaf to the voice of justice and of "
|
||||
"consanguinity. We must, therefore, acquiesce in the necessity, which "
|
||||
"denounces our Separation, and hold them, as we hold the rest of mankind, "
|
||||
"Enemies in War, in Peace Friends.\n\nWe, therefore, the Representatives "
|
||||
"of the united States of America, in General Congress, Assembled, "
|
||||
"appealing to the Supreme Judge of the world for the rectitude of our "
|
||||
"intentions, do, in the Name, and by Authority of the good People of these "
|
||||
"Colonies, solemnly publish and declare, That these United Colonies are, "
|
||||
"and of Right ought to be Free and Independent States; that they are "
|
||||
"Absolved from all Allegiance to the British Crown, and that all political "
|
||||
"connection between them and the State of Great Britain, is and ought to "
|
||||
"be totally dissolved; and that as Free and Independent States, they have "
|
||||
"full Power to levy War, conclude Peace, contract Alliances, establish "
|
||||
"Commerce, and to do all other Acts and Things which Independent States "
|
||||
"may of right do. And for the support of this Declaration, with a firm "
|
||||
"reliance on the protection of divine Providence, we mutually pledge to "
|
||||
"each other our Lives, our Fortunes and our sacred Honor.\n"};
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||
static const char kSmall[] =
|
||||
"The capital of Hungary is Budapest which is located in Europe.";
|
||||
float entropy = GemmaCrossEntropy(kSmall);
|
||||
std::cout << "per-byte entropy: " << entropy << "\n";
|
||||
EXPECT_LT(entropy, 1.6f);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||
float entropy = GemmaCrossEntropy(kJingleBells);
|
||||
std::cout << "per-byte entropy: " << entropy << "\n";
|
||||
EXPECT_LT(entropy, 2.3f);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||
float entropy = GemmaCrossEntropy(kGettysburg);
|
||||
std::cout << "per-byte entropy: " << entropy << "\n";
|
||||
EXPECT_LT(entropy, 1.2f);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyDeclaration) {
|
||||
float entropy = GemmaCrossEntropy(kDeclaration);
|
||||
std::cout << "per-byte entropy: " << entropy << "\n";
|
||||
EXPECT_LT(entropy, 1.0f);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
+++
|
||||
list me ten biggest cities
|
||||
---
|
||||
8. Mexico City
|
||||
+++
|
||||
What is the capital of Switzerland?
|
||||
---
|
||||
The capital of Switzerland is Bern.
|
||||
+++
|
||||
What is the name of the president of the US?
|
||||
---
|
||||
I cannot answer this question
|
||||
+++
|
||||
Is it raining frequently in China?
|
||||
---
|
||||
average of 1,200 rainy days per year
|
||||
+++
|
||||
tell me a french joke
|
||||
---
|
||||
What do you call a French person who's always complaining?
|
||||
+++
|
||||
which year did the 21th century started?
|
||||
---
|
||||
The 21st century started in the year **1900**
|
||||
+++
|
||||
what's your favorite pokemon?
|
||||
---
|
||||
I am unable to provide a favorite Pokémon
|
||||
+++
|
||||
How to bake a tasty cake?
|
||||
---
|
||||
* 1 and 1/2 cups all-purpose flour
|
||||
+++
|
||||
Which is the richest country in the world?
|
||||
---
|
||||
richest country in the world is the United States
|
||||
+++
|
||||
do you like electronic music?
|
||||
---
|
||||
Electronic music is a broad genre
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
+++
|
||||
list me ten biggest cities
|
||||
---
|
||||
6. Mexico City, Mexico
|
||||
+++
|
||||
What is the capital of Switzerland?
|
||||
---
|
||||
Bern is the capital of Switzerland.
|
||||
+++
|
||||
What is the name of the president of the US?
|
||||
---
|
||||
The answer is: Joe Biden.
|
||||
32
run.cc
32
run.cc
|
|
@ -40,13 +40,14 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
static constexpr std::string_view kAsciiArtBanner =
|
||||
" __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n"
|
||||
" / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
|
||||
"| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n"
|
||||
" \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n"
|
||||
" __/ | | | | |\n"
|
||||
" |___/ |_| |_|";
|
||||
static constexpr std::string_view kAsciiArtBanner = R""(
|
||||
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
|
||||
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
|
||||
| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |
|
||||
\__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
|
||||
__/ | | | | |
|
||||
|___/ |_| |_|
|
||||
)"";
|
||||
|
||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||
loader.Print(app.verbosity);
|
||||
|
|
@ -60,7 +61,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
|
||||
<< "\n"
|
||||
<< "Hardware concurrency : "
|
||||
<< std::thread::hardware_concurrency() << std::endl
|
||||
<< std::thread::hardware_concurrency() << "\n"
|
||||
<< "Instruction set : "
|
||||
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
|
||||
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
|
||||
|
|
@ -132,13 +133,13 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
}
|
||||
} else {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||
// +1 since position is incremented above
|
||||
if (current_pos == prompt_size + 1) {
|
||||
// first token of response
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
if (verbosity >= 1) {
|
||||
std::cout << std::endl << std::endl;
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
}
|
||||
std::cout << token_text << std::flush;
|
||||
|
|
@ -189,7 +190,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
}
|
||||
}
|
||||
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
|
|
@ -199,7 +200,8 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
|
||||
prompt_size = prompt.size();
|
||||
|
||||
std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
|
||||
std::cerr << "\n"
|
||||
<< "[ Reading prompt ] " << std::flush;
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
|
|
@ -209,10 +211,10 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
const double tok_sec = current_pos / (time_end - time_start);
|
||||
if (verbosity >= 2) {
|
||||
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
||||
<< std::endl
|
||||
<< tok_sec << " tokens / sec" << std::endl;
|
||||
<< "\n"
|
||||
<< tok_sec << " tokens / sec" << "\n";
|
||||
}
|
||||
std::cout << std::endl << std::endl;
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
std::cout
|
||||
<< "max_tokens (" << args.max_tokens
|
||||
|
|
|
|||
10
util/app.h
10
util/app.h
|
|
@ -25,7 +25,7 @@
|
|||
#include <cctype>
|
||||
#include <cerrno> // IDE does not recognize errno.h as providing errno.
|
||||
#include <string>
|
||||
#endif
|
||||
#endif // HWY_OS_LINUX
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
|
|
@ -170,8 +170,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
weights = compressed_weights;
|
||||
} else {
|
||||
return "Only one of --weights and --compressed_weights can be "
|
||||
"specified. To create compressed weights use the compress_weights "
|
||||
"tool.";
|
||||
"specified. To create compressed weights use the "
|
||||
"compress_weights tool.";
|
||||
}
|
||||
}
|
||||
if (weights.path.empty()) {
|
||||
|
|
@ -179,12 +179,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
}
|
||||
if (!weights.exists()) {
|
||||
return "Can't open file specified with --weights flag.";
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Path tokenizer;
|
||||
Path weights; // weights file location
|
||||
Path weights; // weights file location
|
||||
Path compressed_weights;
|
||||
std::string model_type;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue