mirror of https://github.com/google/gemma.cpp.git
Major duplicated code reduction in test/benchmarks

Helper functions to tokenize/wrap.
Move LayersOutputFunc into RuntimeConfig.
AcceptFunc passes the probability.
Implement StringFromType using the parser, and verify results match.

PiperOrigin-RevId: 643255119
parent c15ff9529c
commit d3c6a45b59
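The thrust of the change: per-binary boilerplate (flag validation, thread-pool setup, model and KV-cache creation, the tokenize/stream/decode loop) moves into a reusable GemmaEnv helper. A minimal sketch of the new calling pattern, using only interfaces visible in the diffs below:

#include <iostream>
#include <string>

#include "gemma/benchmark_helper.h"  // gcpp::GemmaEnv

int main(int argc, char** argv) {
  // Parses loader/inference/app flags, pins workers, loads the model.
  gcpp::GemmaEnv env(argc, argv);
  std::string prompt = "How does an inkjet printer work?";
  // Adds turn structure for -it models, tokenizes, then generates.
  const auto [answer, num_tokens] = env.QueryModel(prompt);
  std::cout << answer << " (" << num_tokens << " tokens)\n";
  return 0;
}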
BUILD.bazel (58 changes)

@@ -113,12 +113,8 @@ cc_library(
 cc_library(
     name = "cross_entropy",
-    srcs = [
-        "gemma/cross_entropy.cc",
-    ],
-    hdrs = [
-        "gemma/cross_entropy.h",
-    ],
+    srcs = ["gemma/cross_entropy.cc"],
+    hdrs = ["gemma/cross_entropy.h"],
     deps = [
         ":common",
         ":gemma_lib",
@@ -149,6 +145,25 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "benchmark_helper",
+    srcs = ["gemma/benchmark_helper.cc"],
+    hdrs = ["gemma/benchmark_helper.h"],
+    deps = [
+        ":app",
+        ":args",
+        ":common",
+        ":cross_entropy",
+        ":gemma_lib",
+        # Placeholder for internal dep, do not remove.,
+        "@benchmark//:benchmark",
+        "//compression:compress",
+        "@hwy//:hwy",
+        "@hwy//:nanobenchmark",
+        "@hwy//:thread_pool",
+    ],
+)
+
 cc_test(
     name = "gemma_test",
     srcs = ["gemma/gemma_test.cc"],
@@ -161,11 +176,11 @@ cc_test(
     deps = [
         ":app",
         ":args",
+        ":benchmark_helper",
         ":common",
         ":cross_entropy",
         ":gemma_lib",
         ":ops",
-        # Placeholder for internal dep, do not remove.,
         "@googletest//:gtest_main",
         "//compression:io",
         "@hwy//:hwy_test_util",
@@ -179,6 +194,7 @@ cc_binary(
     deps = [
         ":app",
         ":args",
+        ":benchmark_helper",
         ":common",
         ":gemma_lib",
         # Placeholder for internal dep, do not remove.,
@@ -213,10 +229,10 @@ cc_binary(
     deps = [
         ":app",
         ":args",
+        ":benchmark_helper",
         ":common",
         ":cross_entropy",
         ":gemma_lib",
-        # Placeholder for internal dep, do not remove.,
         "//compression:io",
         "@hwy//:hwy",
         "@hwy//:nanobenchmark",
@@ -230,7 +246,6 @@ cc_binary(
     srcs = ["gemma/benchmarks.cc"],
     deps = [
         ":benchmark_helper",
-        # Placeholder for internal dep, do not remove.,
         "@benchmark//:benchmark",
     ],
 )
@@ -243,8 +258,8 @@ cc_binary(
     deps = [
         ":app",
         ":args",
+        ":benchmark_helper",
         ":gemma_lib",
-        # Placeholder for internal dep, do not remove.,
         "//compression:io",
         "@hwy//:hwy",
         "@hwy//:thread_pool",
@@ -257,8 +272,10 @@ cc_binary(
     srcs = ["gemma/run_mmlu.cc"],
     deps = [
         ":app",
+        ":args",
+        ":benchmark_helper",
         ":gemma_lib",
-        # Placeholder for internal dep, do not remove.,
+        "//compression:io",
         "@hwy//:hwy",
         "@hwy//:profiler",
         "@hwy//:thread_pool",
@@ -318,25 +335,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "benchmark_helper",
-    srcs = [
-        "gemma/benchmark_helper.cc",
-    ],
-    hdrs = [
-        "gemma/benchmark_helper.h",
-    ],
-    deps = [
-        ":app",
-        ":common",
-        ":gemma_lib",
-        "@benchmark//:benchmark",
-        "@hwy//:hwy",
-        "@hwy//:nanobenchmark",
-        "@hwy//:thread_pool",
-    ],
-)
-
 cc_test(
     name = "backward_scalar_test",
     size = "large",
compression/io.h

@@ -23,6 +23,8 @@
 #include <string>
 #include <utility>  // std::move
 
+#include "hwy/base.h"
+
 namespace gcpp {
 
 // Forward-declare to break the circular dependency: OpenFileOrNull returns
@@ -77,12 +79,30 @@ struct Path {
     return path;
   }
 
+  bool Empty() const { return path.empty(); }
+
   // Returns whether the file existed when this was called.
   bool Exists() const { return !!OpenFileOrNull(*this, "r"); }
 
   std::string path;
 };
 
+static inline HWY_MAYBE_UNUSED std::string ReadFileToString(const Path& path) {
+  std::unique_ptr<File> file = OpenFileOrNull(path, "r");
+  if (!file) {
+    HWY_ABORT("Failed to open %s", path.path.c_str());
+  }
+  const size_t size = file->FileSize();
+  if (size == 0) {
+    HWY_ABORT("Empty file %s", path.path.c_str());
+  }
+  std::string content(size, ' ');
+  if (!file->Read(0, size, content.data())) {
+    HWY_ABORT("Failed to read %s", path.path.c_str());
+  }
+  return content;
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
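Both additions are small conveniences: Empty() gives flag handling a cheap way to test whether an optional Path was set, and ReadFileToString() replaces the ad-hoc ifstream/stringstream readers removed from the benchmarks below (it aborts rather than returning an error). A sketch of typical use, assuming only what this header now provides:

#include <string>

#include "compression/io.h"

std::string LoadCorpus(const gcpp::Path& path) {
  // An unset optional flag is detected before touching the filesystem.
  if (path.Empty()) return std::string();
  return gcpp::ReadFileToString(path);  // HWY_ABORTs on open/read failure
}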
debug_prompt.cc (155 changes)

@@ -1,146 +1,83 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
-#include <cstdlib>
 #include <fstream>
 #include <iostream>
-#include <random>
 #include <string>
-#include <utility>
 #include <vector>
 
-// Placeholder for internal header, do not modify.
 #include "compression/io.h"
-#include "gemma/gemma.h"
-#include "util/app.h"
+#include "gemma/benchmark_helper.h"
+#include "gemma/gemma.h"  // LayersOutputFunc
 #include "util/args.h"
 #include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "nlohmann/json.hpp"
 
 using json = nlohmann::json;
 
-class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
+namespace gcpp {
+
+class PromptArgs : public ArgsBase<PromptArgs> {
  public:
   PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
-  gcpp::Path layers_output;
+  Path layers_output;  // optional
   std::string prompt;
 
+  // Returns error string or nullptr if OK.
+  const char* Validate() const {
+    if (prompt.empty()) return "Must specify --prompt";
+    return nullptr;
+  }
+
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
-    visitor(layers_output.path, "layers_output", std::string(""),
+    visitor(layers_output, "layers_output", Path(""),
            "Path to store layers output", 2);
    visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
   }
 };
 
-std::pair<std::string, int> QueryModel(
-    gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
-    gcpp::LayersOutputT* layers_output) {
-  std::vector<int> prompt;
-  HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
-
-  // For both pre-trained and instruction-tuned models: prepend "<bos>" token
-  // if needed.
-  prompt.insert(prompt.begin(), gcpp::BOS_ID);
-  std::string res;
-  size_t total_tokens = 0;
-  std::mt19937 gen;
-  gen.seed(42);
-
-  auto stream_token = [&res, &total_tokens, &model](int token, float) {
-    ++total_tokens;
-    std::string token_text;
-    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    res += token_text;
-    return true;
-  };
-  if (app.verbosity >= 2) {
-    std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
-              << args.temperature;
-  }
-  gcpp::TimingInfo timing_info;
-  gcpp::RuntimeConfig runtime_config = {
-      .max_tokens = args.max_tokens,
-      .max_generated_tokens = args.max_generated_tokens,
-      .temperature = args.temperature,
-      .verbosity = app.verbosity,
-      .gen = &gen,
-      .stream_token = stream_token,
-  };
-  model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
-                 layers_output);
-  return {res, total_tokens};
-}
-
-class OutputJsonLogger {
- public:
+int Run(int argc, char** argv) {
+  PromptArgs prompt_args(argc, argv);
+  AbortIfInvalidArgs(prompt_args);
+
   json json_output;
-  gcpp::LayersOutputT layers_output_log_f =
-      [this](int pos, const std::string& key, const float* values,
+
+  GemmaEnv env(argc, argv);
+  env.MutableConfig().layers_output =
+      prompt_args.layers_output.Empty()
+          ? LayersOutputFunc()
+          : [&json_output](int pos, const std::string& key, const float* values,
                            size_t values_len) {
              std::vector<float> v{values, values + values_len};
              json_output[std::to_string(pos)][key] = v;
            };
-};
-
-/* Run this in the same way as gemma, p.ex.:
-./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
-2b-it-sfp.sbs --prompt "..." --layers_output [path]
-*/
-int main(int argc, char** argv) {
-  {
-    // Placeholder for internal init, do not modify.
-  }
-
-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::InferenceArgs args(argc, argv);  // inference
-  gcpp::AppArgs app(argc, argv);
-  PromptArgs prompt_args(argc, argv);
-
-  if (const char* error = loader.Validate()) {
-    HWY_ABORT("\nInvalid loader args: %s", error);
-  }
-  if (const char* error = args.Validate()) {
-    HWY_ABORT("\nInvalid inference args: %s", error);
-  }
-  const bool log_layers_output = !prompt_args.layers_output.path.empty();
-  OutputJsonLogger json_logger;
-  gcpp::LayersOutputT* layers_output =
-      log_layers_output ? &json_logger.layers_output_log_f : nullptr;
-
-  hwy::ThreadPool pool(app.num_threads);
-  // For many-core, pinning workers to cores helps.
-  if (app.num_threads > 10) {
-    gcpp::PinWorkersToCores(pool);
-  }
-
-  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
-
-  const std::string& prompt = prompt_args.prompt;
-  if (prompt.empty()) {
-    std::cout << "Please specify --prompt" << std::endl;
-    return EXIT_FAILURE;
-  }
-  const auto [answer, token_count] = QueryModel(
-      model, args, app, kv_cache, pool, prompt, layers_output);
-  std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
 
-  if (log_layers_output) {
+  const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
+  std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
+
+  if (env.MutableConfig().layers_output) {
     std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
-    if (!output_f) {
-      std::cout << "Opening file failed" << std::endl;
-      return EXIT_FAILURE;
-    }
-    output_f << json_logger.json_output.dump();
-    if (!output_f) {
-      std::cout << "Writing to file failed" << std::endl;
-      return EXIT_FAILURE;
-    }
+    if (!output_f) HWY_ABORT("Opening layer output file failed");
+    output_f << json_output.dump();
+    if (!output_f) HWY_ABORT("Writing to layer output file failed");
     output_f.close();
   }
-  return EXIT_SUCCESS;
+  return 0;
 }
+
+}  // namespace gcpp
+
+int main(int argc, char** argv) { return gcpp::Run(argc, argv); }
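Note the shape of the rewrite: because LayersOutputFunc now lives on RuntimeConfig, tapping per-layer activations no longer requires threading an extra pointer through Generate(). Any caller can do the same through GemmaEnv; a sketch following the pattern above (using <cstdio> for output):

gcpp::GemmaEnv env(argc, argv);
// Same callback signature as the JSON logger above; here we only log sizes.
env.MutableConfig().layers_output = [](int pos, const std::string& key,
                                       const float* values, size_t len) {
  fprintf(stderr, "pos=%d key=%s len=%zu\n", pos, key.c_str(), len);
};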
gemma/benchmark.cc

@@ -1,36 +1,36 @@
+#include <stdio.h>
+
 #include <algorithm>
 #include <cstdlib>  // EXIT_FAILURE
 #include <fstream>
 #include <iostream>
 #include <ostream>
-#include <random>
-#include <sstream>
 #include <string>
 #include <utility>  // std::pair
 #include <vector>
 
-// Placeholder for internal header, do not modify.
 #include "compression/io.h"  // Path
+#include "gemma/benchmark_helper.h"
+#include "gemma/common.h"
 #include "gemma/cross_entropy.h"
 #include "gemma/gemma.h"
-#include "util/app.h"
 #include "util/args.h"
 #include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/highway.h"
 #include "hwy/timer.h"
 #include "nlohmann/json.hpp"
 
+namespace gcpp {
+
 using json = nlohmann::json;
 
-class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> {
+class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
  public:
   BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
-  gcpp::Path goldens;
-  gcpp::Path summarize_text;
-  gcpp::Path cross_entropy;
-  gcpp::Path trivia_qa;
+  Path goldens;
+  Path summarize_text;
+  Path cross_entropy;
+  Path trivia_qa;
   size_t max_questions;
   size_t batch_tokens;
@@ -53,61 +53,6 @@ class BenchmarkArgs : public gcpp::ArgsBase<BenchmarkArgs> {
   }
 };
 
-void LogSpeedStats(const double time_start, size_t total_tokens) {
-  const double time_end = hwy::platform::Now();
-  const double time_elapsed = time_end - time_start;
-  const double tok_sec = total_tokens / time_elapsed;
-  std::cout << total_tokens << " tokens in " << time_elapsed << " seconds"
-            << " [" << tok_sec << " tokens / sec" << "]\n";
-}
-
-std::pair<std::string, int> QueryModel(
-    gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
-  std::vector<int> prompt;
-  HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
-
-  // For both pre-trained and instruction-tuned models: prepend "<bos>" token
-  // if needed.
-  prompt.insert(prompt.begin(), 2);
-  std::string res;
-  size_t total_tokens = 0;
-  std::mt19937 gen;
-  gen.seed(42);
-
-  const double time_start = hwy::platform::Now();
-  auto stream_token = [&res, &total_tokens, &time_start, &app, &model](
-                          int token, float) {
-    ++total_tokens;
-    std::string token_text;
-    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    res += token_text;
-    if (app.verbosity >= 1 && total_tokens % 100 == 0) {
-      LogSpeedStats(time_start, total_tokens);
-    }
-    return true;
-  };
-  if (app.verbosity >= 2) {
-    std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
-              << args.temperature;
-  }
-  gcpp::TimingInfo timing_info;
-  gcpp::RuntimeConfig runtime_config = {
-      .max_tokens = args.max_tokens,
-      .max_generated_tokens = args.max_generated_tokens,
-      .temperature = args.temperature,
-      .verbosity = app.verbosity,
-      .gen = &gen,
-      .stream_token = stream_token,
-  };
-  model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
-                 /*layers_output=*/nullptr);
-  if (app.verbosity >= 1) {
-    LogSpeedStats(time_start, total_tokens);
-  }
-  return {res, total_tokens};
-}
-
 std::vector<std::pair<std::string, std::string>> load_goldens(
     const std::string& path) {
   std::ifstream goldens_file(path);
@@ -129,28 +74,14 @@ std::vector<std::pair<std::string, std::string>> load_goldens(
   return res;
 }
 
-std::string ReadFile(const gcpp::Path& path) {
-  std::ifstream text_file(path.path);
-  if (!text_file) {
-    std::cout << "Could not open file: " << path.path << "\n" << std::flush;
-    return {};
-  }
-  std::stringstream buffer;
-  buffer << text_file.rdbuf();
-  return buffer.str();
-}
-
-int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
-                     gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& pool, const std::string& golden_path) {
-  const std::vector<std::pair<std::string, std::string>> queries_answers =
+int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
+  std::vector<std::pair<std::string, std::string>> queries_answers =
       load_goldens(golden_path);
-  int correct_answers = 0;
-  int total_tokens = 0;
+  size_t correct_answers = 0;
+  size_t total_tokens = 0;
   const double time_start = hwy::platform::Now();
-  for (const auto& [question, expected_answer] : queries_answers) {
-    const auto [answer, token_count] =
-        QueryModel(model, args, app, kv_cache, pool, question);
+  for (auto& [question, expected_answer] : queries_answers) {
+    const auto [answer, token_count] = env.QueryModel(question);
     total_tokens += token_count;
     if (answer.find(expected_answer) != std::string::npos) {
       correct_answers++;
@@ -172,28 +103,22 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   return EXIT_SUCCESS;
 }
 
-int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
-                     gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& pool, const gcpp::Path& text) {
+int BenchmarkSummary(GemmaEnv& env, const Path& text) {
   std::string prompt("Here is some text to summarize:\n");
-  prompt.append(ReadFile(text));
+  prompt.append(ReadFileToString(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
-  const auto [answer, token_count] =
-      QueryModel(model, args, app, kv_cache, pool, prompt);
+  const auto [answer, token_count] = env.QueryModel(prompt);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
   LogSpeedStats(time_start, token_count);
   return EXIT_SUCCESS;
 }
 
-int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
-                          gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-                          hwy::ThreadPool& pool, const gcpp::Path& text,
+int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
                           size_t batch_tokens) {
-  std::string input = ReadFile(text);
-  std::vector<int> prompt;
-  HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
-  prompt.resize(std::min<size_t>(args.max_tokens, prompt.size()));
+  std::string input = ReadFileToString(text);
+  std::vector<int> prompt = env.Tokenize(input);
+  prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
   std::cout << "Number of input tokens: " << prompt.size() << "\n";
   const double time_start = hwy::platform::Now();
   float total_entropy = 0.0f;
@@ -203,13 +128,12 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
     size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
-    float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice,
-                                        kv_cache, app.verbosity);
+    KVCache kv_cache = KVCache::Create(env.ModelType());
+    float entropy = ComputeCrossEntropy(
+        *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
-    std::string text_slice;
-    HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
+    std::string text_slice = env.StringFromTokens(prompt_slice);
     total_input_len += text_slice.size();
     printf("Total cross entropy: %f [cumulative: %f]\n",
            entropy, total_entropy);
@@ -219,23 +143,19 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
   return EXIT_SUCCESS;
 }
 
-int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
-                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                      hwy::ThreadPool& pool, const gcpp::Path& json_file,
+int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
                       size_t max_questions) {
   std::ifstream trivia_file(json_file.path);
   if (!trivia_file) {
-    std::cout << "Could not load file: " << json_file.path << "\n"
-              << std::flush;
-    return EXIT_FAILURE;
+    HWY_ABORT("Could not load file %s\n", json_file.path.c_str());
   }
   std::string line;
   size_t correct_answers = 0;
   size_t i = 0;
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
-    const auto [answer, token_count] = QueryModel(
-        model, args, app, kv_cache, pool, data["question"]);
+    std::string q(data["question"]);
+    const auto [answer, token_count] = env.QueryModel(q);
     std::cout << answer << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
@@ -256,52 +176,25 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   return EXIT_SUCCESS;
 }
 
-/* Run this in the same way as gemma, p.ex.:
-./benchmark --tokenizer tokenizer.spm --model 2b-it --weights \
-2b-it-sfp.sbs --goldens_dir "../goldens"
-*/
+}  // namespace gcpp
+
 int main(int argc, char** argv) {
-  {
-    // Placeholder for internal init, do not modify.
-  }
-
-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::InferenceArgs args(argc, argv);  // inference
-  gcpp::AppArgs app(argc, argv);
-  BenchmarkArgs benchmark_args(argc, argv);
-
-  if (const char* error = loader.Validate()) {
-    HWY_ABORT("\nInvalid loader args: %s", error);
-  }
-  if (const char* error = args.Validate()) {
-    HWY_ABORT("\nInvalid inference args: %s", error);
-  }
-
-  hwy::ThreadPool pool(app.num_threads);
-  // For many-core, pinning workers to cores helps.
-  if (app.num_threads > 10) {
-    gcpp::PinWorkersToCores(pool);
-  }
-
-  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
-
-  if (!benchmark_args.goldens.path.empty()) {
+  gcpp::GemmaEnv env(argc, argv);
+  gcpp::BenchmarkArgs benchmark_args(argc, argv);
+
+  if (!benchmark_args.goldens.Empty()) {
     const std::string golden_path =
-        benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
-    return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
-  } else if (!benchmark_args.summarize_text.path.empty()) {
-    return BenchmarkSummary(model, args, app, kv_cache, pool,
-                            benchmark_args.summarize_text);
-  } else if (!benchmark_args.cross_entropy.path.empty()) {
-    return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
-                                 pool, benchmark_args.cross_entropy,
-                                 benchmark_args.batch_tokens);
-  } else if (!benchmark_args.trivia_qa.path.empty()) {
-    return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
-                             benchmark_args.trivia_qa,
-                             benchmark_args.max_questions);
+        benchmark_args.goldens.path + "/" +
+        gcpp::ModelString(env.ModelType(), env.ModelTrainingType()) + ".txt";
+    return BenchmarkGoldens(env, golden_path);
+  } else if (!benchmark_args.summarize_text.Empty()) {
+    return BenchmarkSummary(env, benchmark_args.summarize_text);
+  } else if (!benchmark_args.cross_entropy.Empty()) {
+    return BenchmarkCrossEntropy(env, benchmark_args.cross_entropy,
+                                 benchmark_args.batch_tokens);
+  } else if (!benchmark_args.trivia_qa.Empty()) {
+    return BenchmarkTriviaQA(env, benchmark_args.trivia_qa,
+                             benchmark_args.max_questions);
  }
-  std::cout << "No benchmark command given." << "\n" << std::flush;
-  return EXIT_FAILURE;
+  HWY_ABORT("No benchmark command given.");
 }
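The usage comment deleted above still describes how the binary is driven; quoting the old comment verbatim (flag names not re-verified against the new BenchmarkArgs):

./benchmark --tokenizer tokenizer.spm --model 2b-it --weights \
  2b-it-sfp.sbs --goldens_dir "../goldens"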
gemma/benchmark_helper.cc

@@ -14,70 +14,95 @@
 // limitations under the License.
 
 #include "gemma/benchmark_helper.h"
-#include <cstdlib>  // EXIT_FAILURE
+
+#include <stdio.h>
+#include <time.h>
 
 #include <iostream>
 #include <memory>
 #include <ostream>
 #include <random>
 #include <string>
+#include <thread>  // NOLINT
 #include <utility>  // std::pair
 #include <vector>
 
-#include "gemma/common.h"
+// Placeholder for internal header, do not modify.
+#include "compression/compress.h"  // TypeName
+#include "gemma/common.h"          // StringFromType
+#include "gemma/cross_entropy.h"
 #include "gemma/gemma.h"
 #include "util/app.h"
+#include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
+#include "hwy/per_target.h"
 #include "hwy/timer.h"
 
 namespace gcpp {
 
+void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
+  if (inference.deterministic) {
+    // Nothing up my sleeve number, at least some upper bits set.
+    gen.seed(0x12345678);
+  } else {
+    // Depending on the library implementation, this may still be deterministic.
+    std::random_device rd;
+    gen.seed(rd());
+  }
+}
+
 GemmaEnv::GemmaEnv(int argc, char** argv)
-    : loader_(argc, argv), inference_args_(argc, argv), app_(argc, argv),
+    : loader_(argc, argv),
+      inference_args_(argc, argv),
+      app_(argc, argv),
       pool_(app_.num_threads) {
-  if (const char* error = loader_.Validate()) {
-    HWY_ABORT("\nInvalid loader args: %s", error);
-  }
-  if (const char* error = inference_args_.Validate()) {
-    HWY_ABORT("\nInvalid inference args: %s", error);
+  {
+    // Placeholder for internal init, do not modify.
   }
   // For many-core, pinning workers to cores helps.
   if (app_.num_threads > 10) {
     gcpp::PinWorkersToCores(pool_);
   }
 
-  model_ = AllocateGemma(loader_, pool_);
-  kv_cache_ = KVCache::Create(loader_.ModelType());
-  gen_.seed(42);
-}
+  AbortIfInvalidArgs(inference_args_);
+
+  if (const char* err = loader_.Validate()) {
+    loader_.Help();
+    fprintf(stderr, "Skipping model load because: %s\n", err);
+  } else {
+    fprintf(stderr, "Loading model...\n");
+    model_ = AllocateGemma(loader_, pool_);
+    kv_cache_ = KVCache::Create(loader_.ModelType());
+  }
+
+  InitGenerator(inference_args_, gen_);
+
+  runtime_config_ = {
+      .max_tokens = inference_args_.max_tokens,
+      .max_generated_tokens = inference_args_.max_generated_tokens,
+      .temperature = inference_args_.temperature,
+      .verbosity = app_.verbosity,
+      .gen = &gen_,
+  };
+}
 
-std::pair<std::string, int> GemmaEnv::QueryModel(const std::string& input) {
-  std::string prompt_string = input;
-  if (loader_.ModelTrainingType() == ModelTraining::GEMMA_IT) {
-    // For instruction-tuned models: add control tokens.
-    prompt_string = "<start_of_turn>user\n" + input +
-                    "<end_of_turn>\n<start_of_turn>model\n";
-  }
-  std::vector<int> prompt;
-  HWY_ASSERT(model_->Tokenizer().Encode(input, &prompt));
-
-  // For both pre-trained and instruction-tuned models: prepend "<bos>" token
-  // if needed.
-  prompt.insert(prompt.begin(), gcpp::BOS_ID);
+std::pair<std::string, size_t> GemmaEnv::QueryModel(
+    const std::vector<int>& tokens) {
   std::string res;
   size_t total_tokens = 0;
-  auto accept_token = [](int) { return true; };
-  std::mt19937 gen;
-  gen.seed(42);
 
   const double time_start = hwy::platform::Now();
-  auto stream_token = [&res, &total_tokens, &time_start, this](
+  const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
                           int token, float) {
     ++total_tokens;
     std::string token_text;
-    HWY_ASSERT(model_->Tokenizer().Decode(std::vector<int>{token},
-                                          &token_text));
+    HWY_ASSERT(
+        model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
     res += token_text;
-    if (app_.verbosity >= 1 && total_tokens % 100 == 0) {
+    if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
       LogSpeedStats(time_start, total_tokens);
     }
     return true;
@@ -88,24 +113,32 @@ std::pair<std::string, int> GemmaEnv::QueryModel(const std::string& input) {
               << inference_args_.temperature;
   }
   gcpp::TimingInfo timing_info;
-  gcpp::RuntimeConfig runtime_config = {
-      .max_tokens = inference_args_.max_tokens,
-      .max_generated_tokens = inference_args_.max_generated_tokens,
-      .temperature = inference_args_.temperature,
-      .verbosity = app_.verbosity,
-      .gen = &gen,
-      .stream_token = stream_token,
-      .accept_token = accept_token,
-  };
-  model_->Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache_,
-                   timing_info, /*layers_output=*/nullptr);
+  runtime_config_.stream_token = stream_token;
+  model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_cache_,
+                   timing_info);
   if (app_.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
   return {res, total_tokens};
 }
 
-void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const {
+std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
+  const std::vector<int> prompt =
+      WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(),
+                      /*pos=*/0, input);
+  return QueryModel(prompt);
+}
+
+float GemmaEnv::CrossEntropy(const std::string& input) {
+  std::vector<int> prompt = Tokenize(input);
+  prompt.insert(prompt.begin(), BOS_ID);
+  return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
+                             MutableKVCache(),
+                             /*verbosity=*/0) /
+         static_cast<int>(input.size());
+}
+
+void LogSpeedStats(double time_start, size_t total_tokens) {
   const double time_end = hwy::platform::Now();
   const double time_elapsed = time_end - time_start;
   const double tok_sec = total_tokens / time_elapsed;
@@ -113,6 +146,53 @@ void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const {
             << " [" << tok_sec << " tokens / sec" << "]\n";
 }
 
+void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
+  loader.Print(app.verbosity);
+  inference.Print(app.verbosity);
+  app.Print(app.verbosity);
+
+  if (app.verbosity >= 2) {
+    time_t now = time(nullptr);
+    char* dt = ctime(&now);  // NOLINT
+    std::cout << "Date & Time              : " << dt
+              << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
+              << "\n"
+              << "Hardware concurrency     : "
+              << std::thread::hardware_concurrency() << "\n"
+              << "Instruction set          : "
+              << hwy::TargetName(hwy::DispatchedTarget()) << " ("
+              << hwy::VectorBytes() * 8 << " bits)" << "\n";
+    char cpu100[100];
+    if (hwy::platform::GetCpuString(cpu100)) {
+      std::cout << "CPU                      : " << cpu100 << "\n";
+    }
+    std::cout << "Compiled config          : " << CompiledConfig() << "\n"
+              << "Weight Type              : "
+              << gcpp::StringFromType(loader.WeightType()) << "\n"
+              << "EmbedderInput Type       : "
+              << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
+  }
+}
+
+void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
+              gcpp::AppArgs& app) {
+  std::cerr
+      << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
+         "==========================================================\n\n"
+         "To run gemma.cpp, you need to "
+         "specify 3 required model loading arguments:\n"
+         "    --tokenizer\n"
+         "    --weights\n"
+         "    --model.\n";
+  std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
               "--weights 2b-it-sfp.sbs --model 2b-it\n";
+  std::cerr << "\n*Model Loading Arguments*\n\n";
+  loader.Help();
+  std::cerr << "\n*Inference Arguments*\n\n";
+  inference.Help();
+  std::cerr << "\n*Application Arguments*\n\n";
+  app.Help();
+  std::cerr << "\n";
+}
+
 }  // namespace gcpp
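InitGenerator centralizes seeding: reproducible runs now come from the existing --deterministic inference flag instead of the hard-coded gen.seed(42) that each binary used to carry. A sketch of standalone use, assuming an InferenceArgs instance parsed elsewhere:

std::mt19937 gen;
gcpp::InitGenerator(inference_args, gen);  // fixed seed iff --deterministic

Per the commit message, AcceptFunc also changes to pass the sampled token's probability, which is why the trivial accept_token lambda could simply be dropped above; the new signature itself is not visible in these hunks.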
gemma/benchmark_helper.h

@@ -16,11 +16,15 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
 
+#include <stddef.h>
+
 #include <memory>
 #include <random>
 #include <string>
 #include <utility>
+#include <vector>
 
+#include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "util/app.h"
 #include "hwy/base.h"
@@ -28,24 +32,60 @@
 
 namespace gcpp {
 
+void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
+
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
   GemmaEnv(int argc, char** argv);
 
+  size_t MaxTokens() const { return inference_args_.max_tokens; }
   // Sets the maximum number of output tokens to generate.
-  void set_max_generated_tokens(int max_tokens) {
+  void SetMaxGeneratedTokens(size_t max_tokens) {
     inference_args_.max_generated_tokens = max_tokens;
   }
 
+  std::vector<int> Tokenize(const std::string& input) const {
+    std::vector<int> tokens;
+    HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens));
+    return tokens;
+  }
+
+  std::vector<int> TokenizeAndPrependBOS(const std::string& input) const {
+    std::vector<int> tokens = Tokenize(input);
+    tokens.insert(tokens.begin(), BOS_ID);
+    return tokens;
+  }
+
+  std::string StringFromTokens(const std::vector<int>& tokens) const {
+    std::string string;
+    HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
+    return string;
+  }
+
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
-  std::pair<std::string, int> QueryModel(const std::string& input);
+  std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
+  // Adds turn structure to input, tokenizes and calls the above overload.
+  std::pair<std::string, size_t> QueryModel(std::string& input);
+
+  // Runs inference on the given input and returns the cross entropy, a measure
+  // of how well the model predicts the correct output. It is the average
+  // number of bits per token.
+  float CrossEntropy(const std::string& input);
+
+  // Returns nullptr if the model failed to load.
+  Gemma* GetModel() const { return model_.get(); }
+
+  Model ModelType() const { return loader_.ModelType(); }
+  ModelTraining ModelTrainingType() const {
+    return loader_.ModelTrainingType();
+  }
+  int Verbosity() const { return app_.verbosity; }
+  gcpp::RuntimeConfig& MutableConfig() { return runtime_config_; }
+  std::mt19937& MutableGen() { return gen_; }
+  KVCache& MutableKVCache() { return kv_cache_; }
 
  private:
-  // Logs the inference speed in tokens/sec.
-  void LogSpeedStats(double time_start, size_t total_tokens) const;
-
   // Arguments to the model loader: file locations, etc.
   LoaderArgs loader_;
   // Arguments to the inference function: max tokens, etc.
@@ -60,10 +100,16 @@ class GemmaEnv {
   std::unique_ptr<Gemma> model_;
   // The KV cache to use for inference.
   KVCache kv_cache_;
+  gcpp::RuntimeConfig runtime_config_;
 };
 
+// Logs the inference speed in tokens/sec.
+void LogSpeedStats(double time_start, size_t total_tokens);
+
+void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
+void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
+              gcpp::AppArgs& app);
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
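Taken together, the new accessors let a test or benchmark drive the model without touching LoaderArgs, the pool, or the tokenizer directly. A sketch using only the methods declared above (RunEval is a hypothetical wrapper, not part of the commit):

int RunEval(int argc, char** argv) {
  gcpp::GemmaEnv env(argc, argv);
  if (env.GetModel() == nullptr) return 1;  // model failed to load

  const std::vector<int> ids = env.TokenizeAndPrependBOS("Hello world");
  const auto [text, generated] = env.QueryModel(ids);
  const float xent = env.CrossEntropy("The quick brown fox");
  printf("%s (%zu tokens), xent=%f\n", text.c_str(), generated, xent);
  return 0;
}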
gemma/benchmarks.cc

@@ -13,86 +13,60 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <fstream>
-#include <iostream>
-#include <ostream>
-#include <random>
-#include <sstream>
+#include <stddef.h>
+#include <stdio.h>
 #include <string>
 
-// Placeholder for internal header, do not modify.
 #include "benchmark/benchmark.h"
 #include "gemma/benchmark_helper.h"
 
-void run_gemma_prompt(const std::string& prompt_string,
-                      gcpp::GemmaEnv& env,
-                      benchmark::State& state) {
-  std::mt19937 gen;
-
-  if (prompt_string.empty()) return;
-
-  int token_counter = 0;
+namespace gcpp {
+
+// Shared state for benchmarks - unfortunately the library does not allow
+// passing context nor closures. Raw pointer because style guide forbids
+// non-local static objects with dtors.
+GemmaEnv* s_env = nullptr;
+
+void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
+  size_t total_tokens = 0;
   for (auto s : state) {
-    auto [response, n] = env.QueryModel(prompt_string);
-    std::cout << "response: " << response << "\n";
-    std::cout << "n: " << n << "\n";
-    token_counter += n;
+    std::string prompt = original_prompt;  // reset from original
+    auto [response, n] = s_env->QueryModel(prompt);
+    if (s_env->Verbosity() != 0) {
+      fprintf(stdout, "|%s|\n", response.c_str());
+    }
+    total_tokens += n;
   }
 
-  state.SetItemsProcessed(token_counter);
+  state.SetItemsProcessed(total_tokens);
 }
 
-// Awkward global because benchmarks don't support additional state, so it is
-// either this or cast to int64_t.
-gcpp::GemmaEnv* global_env = nullptr;
+}  // namespace gcpp
 
 static void BM_short_prompt(benchmark::State& state) {
-  run_gemma_prompt("What is the capital of Spain?", *global_env,
-                   state);
+  gcpp::RunPrompt("What is the capital of Spain?", state);
 }
 
 static void BM_factuality_prompt(benchmark::State& state) {
-  run_gemma_prompt("How does an inkjet printer work?",
-                   *global_env, state);
+  gcpp::RunPrompt("How does an inkjet printer work?", state);
 }
 
 static void BM_creative_prompt(benchmark::State& state) {
-  run_gemma_prompt(
-      "Tell me a story about a magical bunny and their TRS-80.",
-      *global_env, state);
+  gcpp::RunPrompt("Tell me a story about a magical bunny and their TRS-80.",
                  state);
 }
 
 static void BM_coding_prompt(benchmark::State& state) {
-  run_gemma_prompt(
-      "Write a python program to generate a fibonacci sequence.",
-      *global_env, state);
+  gcpp::RunPrompt("Write a python program to generate a fibonacci sequence.",
                  state);
 }
 
-static void BM_long_coding_prompt(benchmark::State& state) {
-  std::ifstream t("benchmarks.cc", std::ios_base::in);
-  std::stringstream buffer;
-  buffer << t.rdbuf();
-  std::string prompt_string = buffer.str();
-  t.close();
-
-  run_gemma_prompt("Make improvements to the following code:\n " +
-                   prompt_string, *global_env, state);
-}
-
-int main(int argc, char** argv) {
-  {
-    // Placeholder for internal init, do not modify.
-  }
-  gcpp::GemmaEnv env(argc, argv);
-
-  env.set_max_generated_tokens(128);
-  global_env = &env;
 BENCHMARK(BM_short_prompt)
     ->Iterations(3)
     ->Unit(benchmark::kMillisecond)
     ->UseRealTime();
 
-  env.set_max_generated_tokens(256);
 BENCHMARK(BM_factuality_prompt)
     ->Iterations(3)
     ->Unit(benchmark::kMillisecond)
@@ -108,11 +82,10 @@ int main(int argc, char** argv) {
     ->Unit(benchmark::kMillisecond)
     ->UseRealTime();
 
-  env.set_max_generated_tokens(1024);
-  BENCHMARK(BM_long_coding_prompt)
-      ->Iterations(3)
-      ->Unit(benchmark::kMillisecond)
-      ->UseRealTime();
+int main(int argc, char** argv) {
+  gcpp::GemmaEnv env(argc, argv);
+  env.SetMaxGeneratedTokens(256);
+  gcpp::s_env = &env;
 
   ::benchmark::RunSpecifiedBenchmarks();
   ::benchmark::Shutdown();
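With the shared s_env in place, adding another prompt benchmark is a few lines; a hypothetical example following the pattern above (BM_translation_prompt is not part of the commit):

static void BM_translation_prompt(benchmark::State& state) {
  gcpp::RunPrompt("Translate 'good morning' to French.", state);
}

BENCHMARK(BM_translation_prompt)
    ->Iterations(3)
    ->Unit(benchmark::kMillisecond)
    ->UseRealTime();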
gemma/common.cc

@@ -28,7 +28,8 @@
 namespace gcpp {
 
-namespace {
+const char* ParseModelTypeAndTraining(const std::string& model_flag,
+                                      Model& model, ModelTraining& training) {
 constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
                                        "7b-it", "gr2b-it", "tiny"};
 constexpr Model kModelTypes[] = {
@@ -38,10 +39,7 @@ constexpr ModelTraining kModelTraining[] = {
     ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
     ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
     ModelTraining::GEMMA_IT};
-}  // namespace
-
-const char* ParseModelTypeAndTraining(const std::string& model_flag,
-                                      Model& model, ModelTraining& training) {
 constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
 static char kErrorMessageBuffer[kNum * 8 + 1024] =
     "Invalid or missing model flag, need to specify one of ";
@@ -51,37 +49,58 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
 }
 strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]);  // NOLINT
 strcat(kErrorMessageBuffer, ".");  // NOLINT
 
 std::string model_type_lc = model_flag;
 std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
                [](unsigned char c) { return std::tolower(c); });
 
 for (size_t i = 0; i < kNum; i++) {
   if (kModelFlags[i] == model_type_lc) {
     model = kModelTypes[i];
     training = kModelTraining[i];
+    HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc);
     return nullptr;
   }
 }
 return kErrorMessageBuffer;
 }
 
+const char* ModelString(Model model, ModelTraining training) {
+  if (model == Model::GEMMA_TINY) return "tiny";
+  static_assert(static_cast<size_t>(ModelTraining::GEMMA_IT) == 0);
+  constexpr const char* k2B[] = {"2b-it", "2b-pt"};
+  constexpr const char* k7B[] = {"7b-it", "7b-pt"};
+  constexpr const char* kGr2B[] = {"gr2b-it", "gr2b-pt"};
+  if (model == Model::GEMMA_2B) return k2B[static_cast<size_t>(training)];
+  if (model == Model::GEMMA_7B) return k7B[static_cast<size_t>(training)];
+  if (model == Model::GRIFFIN_2B) return kGr2B[static_cast<size_t>(training)];
+  HWY_ABORT("Unknown model %d training %d\n", static_cast<int>(model),
+            static_cast<int>(training));
+}
+
+constexpr const char* kTypeStrings[] = {"f32", "bf16", "sfp"};
+
+const char* StringFromType(Type type) {
+  return kTypeStrings[static_cast<size_t>(type)];
+}
+
 const char* ParseType(const std::string& type_string, Type& type) {
-  constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP};
-  constexpr const char* kStrings[] = {"f32", "bf16", "sfp"};
-  constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings);
+  constexpr size_t kNum = std::end(kTypeStrings) - std::begin(kTypeStrings);
   static char kErrorMessageBuffer[kNum * 8 + 100] =
       "Invalid or missing type, need to specify one of ";
   for (size_t i = 0; i + 1 < kNum; i++) {
-    strcat(kErrorMessageBuffer, kStrings[i]);  // NOLINT
+    strcat(kErrorMessageBuffer, kTypeStrings[i]);  // NOLINT
     strcat(kErrorMessageBuffer, ", ");  // NOLINT
   }
-  strcat(kErrorMessageBuffer, kStrings[kNum - 1]);  // NOLINT
+  strcat(kErrorMessageBuffer, kTypeStrings[kNum - 1]);  // NOLINT
   strcat(kErrorMessageBuffer, ".");  // NOLINT
   std::string type_lc = type_string;
   std::transform(begin(type_lc), end(type_lc), begin(type_lc),
                  [](unsigned char c) { return std::tolower(c); });
   for (size_t i = 0; i < kNum; i++) {
-    if (kStrings[i] == type_lc) {
-      type = kTypes[i];
+    if (kTypeStrings[i] == type_lc) {
+      type = static_cast<Type>(i);
+      HWY_ASSERT(std::string(StringFromType(type)) == type_lc);
       return nullptr;
     }
   }
@@ -154,18 +154,12 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
                                       Model& model, ModelTraining& training);
 const char* ParseType(const std::string& type_string, Type& type);
 
-static inline const char* StringFromType(Type type) {
-  switch (type) {
-    case Type::kF32:
-      return "f32";
-    case Type::kBF16:
-      return "bf16";
-    case Type::kSFP:
-      return "sfp";
-    default:
-      return "?";
-  }
-}
+// Inverse of ParseModelTypeAndTraining.
+const char* ModelString(Model model, ModelTraining training);
+const char* StringFromType(Type type);
+
+// ----------------------------------------------------------------------------
+//
 
 // __builtin_sqrt is not constexpr as of Clang 17.
 #if HWY_COMPILER_GCC_ACTUAL
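With the switch-based formatter replaced by the shared kTypeStrings table, parsing and formatting are exact inverses, and the parser now asserts this on every successful parse. A minimal round-trip sketch (hypothetical standalone main, assuming the gcpp headers above):

    #include <cstdio>
    #include "gemma/common.h"

    int main() {
      gcpp::Type type;
      // nullptr return means success; otherwise it is an error message.
      if (const char* err = gcpp::ParseType("bf16", type)) {
        fprintf(stderr, "%s\n", err);
        return 1;
      }
      // Prints "bf16": StringFromType indexes the same kTypeStrings table
      // that ParseType matched against.
      printf("%s\n", gcpp::StringFromType(type));
      return 0;
    }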
@@ -111,7 +111,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   };
   TimingInfo timing_info;
 
-  gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr);
+  gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
 
   const float scale = 1.0f / std::log(2.0f);
   return cross_entropy * scale;
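The 1.0f / std::log(2.0f) factor converts cross-entropy accumulated with natural logarithms into bits: H_bits = H_nats / ln 2, since log2(x) = ln(x) / ln 2. The tests below further divide by the prompt's byte count to report per-byte entropy.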
@@ -17,7 +17,6 @@
 
 // Compiles this file for multiple architectures via "foreach_target.h", to
 // which we pass the filename via macro 'argument'.
-#include <cstdio>
 #undef HWY_TARGET_INCLUDE
 #define HWY_TARGET_INCLUDE "gemma/gemma.cc"  // NOLINT
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
@@ -208,8 +207,8 @@ bool GemmaTokenizer::Encode(const std::string& input,
 }
 
 bool GemmaTokenizer::Encode(const std::string& input,
-                            std::vector<int>* pieces) const {
-  return impl_->Encode(input, pieces);
+                            std::vector<int>* ids) const {
+  return impl_->Encode(input, ids);
 }
 
 // Given a sequence of ids, decodes it into a detokenized output.
@@ -653,12 +652,12 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
                               const WeightArrayT& weights,
                               Activations<TConfig, kBatchSize>& activations,
                               KVCache& kv_cache, hwy::ThreadPool& pool,
-                              LayersOutputT* layers_output) {
+                              const LayersOutputFunc& layers_output) {
   HWY_ASSERT(num_tokens <= kBatchSize);
-  if (layers_output != nullptr) {
+  if (layers_output) {
     for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
       float token_f = tokens[token_idx];
-      (*layers_output)(pos + token_idx, "Tokens", &token_f, 1);
+      layers_output(pos + token_idx, "Tokens", &token_f, 1);
     }
   }
   static constexpr size_t kModelDim = TConfig::kModelDim;
@@ -713,12 +712,11 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
     }
     AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
                                activations.x.data(), kModelDim);
-    if (layers_output != nullptr) {
+    if (layers_output) {
       std::string block_name = "blocks." + std::to_string(layer);
       for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-        (*layers_output)(pos + token_idx, block_name,
-                         activations.x.data() + token_idx * kModelDim,
-                         kModelDim);
+        layers_output(pos + token_idx, block_name,
+                      activations.x.data() + token_idx * kModelDim, kModelDim);
       }
     }
   }
@@ -727,9 +725,9 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
 
   RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
                                     activations.x.data(), kModelDim);
-  if (layers_output != nullptr) {
+  if (layers_output) {
     for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-      (*layers_output)(pos + token_idx, "final_norm",
-                       activations.x.data() + token_idx * kModelDim, kModelDim);
+      layers_output(pos + token_idx, "final_norm",
+                    activations.x.data() + token_idx * kModelDim, kModelDim);
     }
   }
@@ -782,8 +780,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
                const ByteStorageT& decode_u8,
                const RuntimeConfig& runtime_config,
                const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
-               hwy::ThreadPool& pool, TimingInfo& timing_info,
-               LayersOutputT* layers_output) {
+               hwy::ThreadPool& pool, TimingInfo& timing_info) {
   const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8);
   auto& prefill_activations =
       GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
@@ -860,7 +857,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     Transformer<kDecodeBatchSize>(&token, kDecodeBatchSize, pos, weights,
-                                  activations, kv_cache, pool, layers_output);
+                                  activations, kv_cache, pool,
+                                  runtime_config.layers_output);
     float token_logit = 0.0f;
     // The condition below is always true if we are doing Prefill above.
     // We keep it here for clarity so that the code is correct even if Prefill
@@ -953,17 +951,37 @@ Gemma::~Gemma() {
 
 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const std::vector<int>& prompt, size_t start_pos,
-                     KVCache& kv_cache, TimingInfo& timing_info,
-                     LayersOutputT* layers_output) {
+                     KVCache& kv_cache, TimingInfo& timing_info) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
 
   GEMMA_EXPORT_AND_DISPATCH(
       model_type_, weight_type_, GenerateT,
       (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
-       kv_cache, pool_, timing_info, layers_output));
+       kv_cache, pool_, timing_info));
 
   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const ModelTraining training, size_t pos,
+                                 std::string& prompt) {
+  // Instruction-tuned models are trained to expect control tokens.
+  if (training == ModelTraining::GEMMA_IT) {
+    // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
+    const std::string start = (pos == 0)
+                                  ? "<start_of_turn>user\n"
+                                  : "<end_of_turn>\n<start_of_turn>user\n";
+    prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
+  }
+
+  std::vector<int> tokens;
+  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
+  // Both pre-trained and instruction-tuned require BOS as first token.
+  if (pos == 0) {
+    tokens.insert(tokens.begin(), gcpp::BOS_ID);
+  }
+  return tokens;
+}
+
 }  // namespace gcpp
 #endif  // HWY_ONCE
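WrapAndTokenize replaces the hand-rolled turn markup at the call sites (run.cc and the MMLU benchmark below). A minimal multi-turn sketch; `gemma` and `abs_pos` are assumed to exist in the caller:

    std::string turn1 = "Tell me a joke.";
    // pos == 0: wraps the text in "<start_of_turn>user\n...<start_of_turn>model\n"
    // and prepends BOS.
    std::vector<int> tokens = gcpp::WrapAndTokenize(
        gemma.Tokenizer(), gcpp::ModelTraining::GEMMA_IT, /*pos=*/0, turn1);

    std::string turn2 = "Another one, please.";
    // pos > 0: additionally prepends "<end_of_turn>\n" to close the model turn.
    std::vector<int> more = gcpp::WrapAndTokenize(
        gemma.Tokenizer(), gcpp::ModelTraining::GEMMA_IT, abs_pos, turn2);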
@@ -74,10 +74,17 @@ class GemmaTokenizer {
 using StreamFunc = std::function<bool(int, float)>;
 // If not empty, AcceptFunc is called with token. It should return false for
 // tokens you don't want to generate and true for tokens you want to generate.
-using AcceptFunc = std::function<bool(int)>;
+using AcceptFunc = std::function<bool(int, float)>;
 // If not empty, SampleFunc is called with the probability distribution for the
 // next token, and its return value is used as the next generated token.
 using SampleFunc = std::function<int(const float*, size_t)>;
+// Will be called for layers output with:
+// - position in the tokens sequence
+// - name of the data, p.ex. "tokens", "block.1", "final_norm"
+// - pointer to the data array
+// - size of the data array
+using LayersOutputFunc =
+    std::function<void(int, const std::string&, const float*, size_t)>;
 
 struct RuntimeConfig {
   size_t max_tokens;
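Since AcceptFunc now also receives the candidate's probability, a constraint can depend on model confidence as well as on token identity. An illustrative sketch; the cutoff value is an arbitrary example, not part of this commit:

    gcpp::AcceptFunc accept_token = [](int token, float prob) {
      // Reject tokens the model considers very unlikely; accept the rest.
      return prob >= 1e-3f;  // example threshold
    };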
@@ -88,6 +95,7 @@ struct RuntimeConfig {
   StreamFunc stream_token;
   AcceptFunc accept_token;  // if empty, accepts all tokens.
   SampleFunc sample_func;   // if empty, uses SampleTopK.
+  LayersOutputFunc layers_output;  // if not empty, called after each layer.
   int eos_id = EOS_ID;
 };
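With the callback stored in RuntimeConfig, per-layer probes no longer need their own Generate parameter. A minimal sketch of a probe that logs the L2 norm of each reported array (assumes <cmath> and <cstdio> are included):

    runtime_config.layers_output = [](int pos, const std::string& name,
                                      const float* data, size_t size) {
      double sum_sq = 0.0;
      for (size_t i = 0; i < size; ++i) sum_sq += data[i] * data[i];
      fprintf(stderr, "pos=%d %s L2=%f\n", pos, name.c_str(),
              std::sqrt(sum_sq));
    };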
@@ -97,14 +105,6 @@ struct TimingInfo {
   double time_to_first_token = 0.0;
 };
 
-// Will be called for layers output with:
-// - position in the tokens sequence
-// - name of the data, p.ex. "tokens", "block.1", "final_norm"
-// - pointer to the data array
-// - size of the data array
-using LayersOutputT =
-    std::function<void(int, const std::string&, const float*, size_t)>;
-
 class Gemma {
  public:
   Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@@ -121,12 +121,9 @@ class Gemma {
   const ByteStorageT& Prefill() const { return prefill_u8_; }
   const ByteStorageT& Decode() const { return decode_u8_; }
 
-  // layers_output is optional; if set - it will be called with the activations
-  // output after applying each layer.
   void Generate(const RuntimeConfig& runtime_config,
                 const std::vector<int>& prompt, size_t start_pos,
-                KVCache& kv_cache, TimingInfo& timing_info,
-                LayersOutputT* layers_output = nullptr);
+                KVCache& kv_cache, TimingInfo& timing_info);
 
  private:
   hwy::ThreadPool& pool_;
@@ -141,14 +138,19 @@ class Gemma {
   Type weight_type_;
 };
 
+// Adds BOS token and possibly 'turn' annotations, which depend on `training`
+// and `pos`, the number of tokens decoded so far; returns the corresponding
+// tokens. Asserts that tokenization is successful.
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 ModelTraining training, size_t pos,
+                                 std::string& prompt);
+
 // DEPRECATED, call Gemma::Generate directly.
 HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
                               const std::vector<int>& prompt, size_t start_pos,
                               KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
-                              TimingInfo& timing_info,
-                              LayersOutputT* layers_output) {
-  gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info,
-                 layers_output);
+                              TimingInfo& timing_info) {
+  gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info);
 }
 
 }  // namespace gcpp
@@ -17,89 +17,35 @@
 
 #include <stdio.h>
 
-#include <memory>
-#include <random>
 #include <string>
 #include <vector>
 
-// Placeholder for internal header, do not modify.
+#include "gemma/benchmark_helper.h"
 #include "gemma/common.h"
-#include "gemma/cross_entropy.h"
-#include "gemma/ops.h"
-#include "util/app.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/tests/test_util-inl.h"
+#include "hwy/tests/hwy_gtest.h"
 
 namespace gcpp {
 namespace {
 
-int s_argc = 0;
-char** s_argv = nullptr;
+// Shared state. Requires argc/argv, so construct in main and use the same raw
+// pointer approach as in benchmarks.cc. Note that the style guide forbids
+// non-local static variables with dtors.
+GemmaEnv* s_env = nullptr;
 
 class GemmaTest : public ::testing::Test {
  protected:
-  static void SetUpTestSuite() {
-    gcpp::AppArgs app(s_argc, s_argv);
-    gcpp::LoaderArgs loader(s_argc, s_argv);
-    if (const char* err = loader.Validate()) {
-      fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
-    } else {
-      fprintf(stderr, "Loading model..\n");
-      s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
-      s_gemma = AllocateGemma(loader, *s_pool);
-      s_kv_cache = KVCache::Create(loader.ModelType());
-      s_model = loader.ModelType();
-    }
-  }
-
-  static void TearDownTestSuite() {
-    s_pool.reset();
-    s_gemma.reset();
-  }
-
-  std::string GemmaReply(const std::string& prompt_string) {
-    std::mt19937 gen;
-    gen.seed(42);
-
-    std::vector<int> prompt;
-    HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
-    // For both pre-trained and instruction-tuned models: prepend "<bos>" token
-    // if needed.
-    prompt.insert(prompt.begin(), BOS_ID);
-
-    std::vector<int> response;
-    auto stream_token = [&response](int token, float) {
-      response.push_back(token);
-      return true;
-    };
-    gcpp::RuntimeConfig runtime_config = {
-        .max_tokens = 3072,
-        .max_generated_tokens = 2048,
-        .temperature = 1.0,
-        .verbosity = 0,
-        .gen = &gen,
-        .stream_token = stream_token,
-    };
-    gcpp::TimingInfo timing_info;
-    s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
-                      timing_info, /*layers_output=*/nullptr);
-    std::string response_text;
-    HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text));
-    return response_text;
-  }
-
-  float GemmaCrossEntropy(const std::string& prompt_string) {
-    std::vector<int> prompt;
-    HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
-    prompt.insert(prompt.begin(), BOS_ID);
-    return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt,
-                               s_kv_cache,
-                               /*verbosity=*/0) /
-           prompt_string.size();
+  std::string GemmaReply(const std::string& prompt) {
+    s_env->SetMaxGeneratedTokens(2048);
+    s_env->MutableConfig().temperature = 0.0f;  // deterministic
+    s_env->MutableConfig().verbosity = 0;
+    // Using the turn structure worsens results.
+    const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
+    auto [response, n] = s_env->QueryModel(tokens);
+    return response;
   }
 
   void TestQuestions(const char* kQA[][2], size_t num_questions) {
-    if (!s_gemma) return;
+    if (!s_env->GetModel()) return;
     for (size_t i = 0; i < num_questions; ++i) {
       fprintf(stderr, "Question %zu\n\n", i + 1);
       std::string response = GemmaReply(kQA[i][0]);
@@ -107,18 +53,8 @@ class GemmaTest : public ::testing::Test {
       EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);  // NOLINT
     }
   }
 
-  static std::unique_ptr<hwy::ThreadPool> s_pool;
-  static std::unique_ptr<gcpp::Gemma> s_gemma;
-  static gcpp::KVCache s_kv_cache;
-  static gcpp::Model s_model;
 };
 
-/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
-/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_gemma;
-/*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
-/*static*/ gcpp::Model GemmaTest::s_model;
-
 TEST_F(GemmaTest, Geography) {
   static const char* kQA[][2] = {
       {"What is the capital of Hungary?", "Budapest"},
@@ -130,7 +66,7 @@ TEST_F(GemmaTest, Geography) {
 
 TEST_F(GemmaTest, History) {
   static const char* kQA[][2] = {
-      {"When was the Battle of Hastings?", "1066"},
+      {"When was the battle of Hastings?", "1066"},
   };
   static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
   TestQuestions(kQA, kNum);
@@ -181,42 +117,39 @@ static const char kGettysburg[] = {
     "people, for the people, shall not perish from the earth.\n"};
 
 TEST_F(GemmaTest, CrossEntropySmall) {
-  if (!s_gemma) return;
+  if (!s_env->GetModel()) return;
   static const char kSmall[] =
       "The capital of Hungary is Budapest which is located in Europe.";
-  float entropy = GemmaCrossEntropy(kSmall);
+  float entropy = s_env->CrossEntropy(kSmall);
   fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
+  EXPECT_LT(entropy,
+            (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
 }
 
 TEST_F(GemmaTest, CrossEntropyJingleBells) {
-  if (!s_gemma) return;
-  float entropy = GemmaCrossEntropy(kJingleBells);
+  if (!s_env->GetModel()) return;
+  float entropy = s_env->CrossEntropy(kJingleBells);
   fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
+  EXPECT_LT(entropy,
+            (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
 }
 
 TEST_F(GemmaTest, CrossEntropyGettysburg) {
-  if (!s_gemma) return;
-  float entropy = GemmaCrossEntropy(kGettysburg);
+  if (!s_env->GetModel()) return;
+  float entropy = s_env->CrossEntropy(kGettysburg);
   fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
+  EXPECT_LT(entropy,
+            (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
 }
 
 }  // namespace
 }  // namespace gcpp
 
 int main(int argc, char** argv) {
-  {
-    // Placeholder for internal init, do not modify.
-  }
+  gcpp::GemmaEnv env(argc, argv);
+  gcpp::s_env = &env;
 
-  // For later use by SetUp.
-  gcpp::s_argc = argc;
-  gcpp::s_argv = argv;
-
-  // Probably should be called before SetUpTestSuite.
-  testing::InitGoogleTest(&gcpp::s_argc, argv);
+  testing::InitGoogleTest(&argc, argv);
 
   return RUN_ALL_TESTS();
 }
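The tests now get all shared state from GemmaEnv in benchmark_helper.h. Its declaration is not part of this diff; inferred purely from the call sites above, the relevant surface is roughly:

    // Hypothetical outline reconstructed from call sites only; see
    // gemma/benchmark_helper.h for the actual declaration.
    class GemmaEnv {
     public:
      GemmaEnv(int argc, char** argv);  // parses loader/inference/app args
      Gemma* GetModel();                // nullptr if no weights were given
      Model ModelType() const;
      void SetMaxGeneratedTokens(size_t max);
      RuntimeConfig& MutableConfig();
      std::vector<int> TokenizeAndPrependBOS(const std::string& text);
      std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
      float CrossEntropy(const std::string& text);  // per-byte, in bits
    };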
@@ -1668,12 +1668,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
   std::array<int, k> indices{};
   for (size_t i = 0; i < vocab_size; ++i) {
     if (probabilities[i] < top_k[k - 1] &&
-        (!accept_token || accept_token(StaticCast<int>(i)))) {
+        (!accept_token || accept_token(StaticCast<int>(i), probabilities[i]))) {
       continue;
     }
     for (size_t j = 0; j < k; ++j) {
       if (probabilities[i] > top_k[j] &&
-          (!accept_token || accept_token(StaticCast<int>(i)))) {
+          (!accept_token ||
+           accept_token(StaticCast<int>(i), probabilities[i]))) {
         // shift elements by 1, insert the new value, move on to next value
         for (size_t idx = k - 1; idx > j; --idx) {
           top_k[idx] = top_k[idx - 1];
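The loop above keeps top_k sorted by shifting elements down one slot before inserting, O(vocab_size * k) overall. A standalone sketch of the same insertion scheme, stripped of the Highway types and the accept callback (a greedy variant for clarity):

    #include <array>
    #include <cstddef>

    template <size_t k>
    int TopKArgMax(const float* probabilities, size_t vocab_size) {
      std::array<float, k> top_k{};  // sorted descending, zero-initialized
      std::array<int, k> indices{};
      for (size_t i = 0; i < vocab_size; ++i) {
        if (probabilities[i] < top_k[k - 1]) continue;  // below the cutoff
        for (size_t j = 0; j < k; ++j) {
          if (probabilities[i] > top_k[j]) {
            // Shift elements by 1, insert the new value, move on.
            for (size_t idx = k - 1; idx > j; --idx) {
              top_k[idx] = top_k[idx - 1];
              indices[idx] = indices[idx - 1];
            }
            top_k[j] = probabilities[i];
            indices[j] = static_cast<int>(i);
            break;
          }
        }
      }
      return indices[0];  // the real SampleTopK samples among the k survivors
    }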
97
gemma/run.cc
@@ -15,27 +15,22 @@
 
 // Command line text interface to gemma.
 
-#include <ctime>
 #include <iostream>
 #include <random>
 #include <string>
 #include <string_view>
-#include <thread>  // NOLINT
 #include <vector>
 
 // Placeholder for internal header, do not modify.
-#include "compression/compress.h"
+#include "gemma/benchmark_helper.h"
 #include "gemma/common.h"
-#include "gemma/configs.h"
 #include "gemma/gemma.h"  // Gemma
 #include "util/app.h"
 #include "util/args.h"  // HasHelp
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
-#include "hwy/per_target.h"
 #include "hwy/profiler.h"
-#include "hwy/timer.h"
 
 #if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE
 #error "Please update to version 1.2 of github.com/google/highway."
@@ -57,56 +52,6 @@ static constexpr std::string_view kAsciiArtBanner = R""(
   |___/                  |_|   |_|
 )"";
 
-void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
-  loader.Print(app.verbosity);
-  inference.Print(app.verbosity);
-  app.Print(app.verbosity);
-
-  if (app.verbosity >= 2) {
-    time_t now = time(nullptr);
-    char* dt = ctime(&now);  // NOLINT
-    std::cout << "Date & Time : " << dt
-              << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
-              << "\n"
-              << "Hardware concurrency : "
-              << std::thread::hardware_concurrency() << "\n"
-              << "Instruction set : "
-              << hwy::TargetName(hwy::DispatchedTarget()) << " ("
-              << hwy::VectorBytes() * 8 << " bits)" << "\n";
-    char cpu100[100];
-    if (hwy::platform::GetCpuString(cpu100)) {
-      std::cout << "CPU : " << cpu100 << "\n";
-    }
-    std::cout << "Compiled config : " << CompiledConfig() << "\n"
-              << "Weight Type : "
-              << gcpp::StringFromType(loader.WeightType()) << "\n"
-              << "EmbedderInput Type : "
-              << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
-  }
-}
-
-void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
-              gcpp::AppArgs& app) {
-  std::cerr
-      << kAsciiArtBanner
-      << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
-         "==========================================================\n\n"
-         "To run gemma.cpp, you need to "
-         "specify 3 required model loading arguments:\n"
-         "    --tokenizer\n"
-         "    --weights\n"
-         "    --model.\n";
-  std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
-               "--weights 2b-it-sfp.sbs --model 2b-it\n";
-  std::cerr << "\n*Model Loading Arguments*\n\n";
-  loader.Help();
-  std::cerr << "\n*Inference Arguments*\n\n";
-  inference.Help();
-  std::cerr << "\n*Application Arguments*\n\n";
-  app.Help();
-  std::cerr << "\n";
-}
-
 // The main Read-Eval-Print Loop.
 void ReplGemma(gcpp::Gemma& model, ModelTraining training,
                gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
@@ -118,12 +63,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
   int prompt_size{};
 
   std::mt19937 gen;
-  if (args.deterministic) {
-    gen.seed(42);
-  } else {
-    std::random_device rd;
-    gen.seed(rd());
-  }
+  InitGenerator(args, gen);
 
   // callback function invoked for each generated token.
   auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
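InitGenerator is one of the new shared helpers (reached via benchmark_helper.h, per the includes above). Judging from the inline code it replaces, it is presumably equivalent to:

    void InitGenerator(const InferenceArgs& args, std::mt19937& gen) {
      if (args.deterministic) {
        gen.seed(42);  // fixed seed for reproducible runs
      } else {
        std::random_device rd;
        gen.seed(rd());
      }
    }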
@@ -162,7 +102,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
 
   while (abs_pos < args.max_tokens) {
     std::string prompt_string;
-    std::vector<int> prompt;
     current_pos = 0;
     {
       PROFILER_ZONE("Gen.input");
@@ -192,30 +131,11 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
       continue;
     }
 
-    if (training == ModelTraining::GEMMA_IT) {
-      // For instruction-tuned models: add control tokens.
-      prompt_string = "<start_of_turn>user\n" + prompt_string +
-                      "<end_of_turn>\n<start_of_turn>model\n";
-      if (abs_pos != 0) {
-        // Prepend "<end_of_turn>" token if this is a multi-turn dialogue
-        // continuation.
-        prompt_string = "<end_of_turn>\n" + prompt_string;
-      }
-    }
-
-    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
-
-    // For both pre-trained and instruction-tuned models: prepend "<bos>" token
-    // if needed.
-    if (abs_pos == 0) {
-      prompt.insert(prompt.begin(), gcpp::BOS_ID);
-    }
-
+    const std::vector<int> prompt =
+        WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
     prompt_size = prompt.size();
 
     std::cerr << "\n"
               << "[ Reading prompt ] " << std::flush;
 
     if constexpr (kVerboseLogTokens) {
       for (int i = 0; i < prompt_size; ++i) {
        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
@@ -301,17 +221,20 @@ int main(int argc, char** argv) {
   gcpp::AppArgs app(argc, argv);
 
   if (gcpp::HasHelp(argc, argv)) {
-    ShowHelp(loader, inference, app);
+    std::cerr << gcpp::kAsciiArtBanner;
+    gcpp::ShowHelp(loader, inference, app);
     return 0;
   }
 
   if (const char* error = loader.Validate()) {
-    ShowHelp(loader, inference, app);
+    std::cerr << gcpp::kAsciiArtBanner;
+    gcpp::ShowHelp(loader, inference, app);
     HWY_ABORT("\nInvalid args: %s", error);
   }
 
   if (const char* error = inference.Validate()) {
-    ShowHelp(loader, inference, app);
+    std::cerr << gcpp::kAsciiArtBanner;
+    gcpp::ShowHelp(loader, inference, app);
     HWY_ABORT("\nInvalid args: %s", error);
   }
@@ -13,19 +13,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Command line text interface to gemma.
+#include <stdio.h>
 
-#include <fstream>
-#include <iostream>
-#include <random>
-#include <set>
-#include <sstream>
+#include <algorithm>
 #include <string>
 #include <vector>
 
-// Placeholder for internal header, do not modify.
+#include "compression/io.h"  // Path
+#include "gemma/benchmark_helper.h"
 #include "gemma/gemma.h"  // Gemma
-#include "util/app.h"
+#include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -34,164 +31,134 @@
 
 namespace gcpp {
 
-void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
-               hwy::ThreadPool& pool,
-               const InferenceArgs& args, int verbosity,
-               std::string& eot_line) {
-  PROFILER_ZONE("Gen.misc");
-  // token index within the current turn
-  int max_tokens = 4096;
+struct JsonArgs : public ArgsBase<JsonArgs> {
+  JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
-  std::mt19937 gen;
-  if (args.deterministic) {
-    gen.seed(42);
-  } else {
-    std::random_device rd;
-    gen.seed(rd());
+  Path input;
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() const {
+    if (input.Empty()) return "Must specify --input";
+    if (!input.Exists()) return "--input file does not exist";
+    return nullptr;
   }
 
-  float answers = 0.0;
-  float correct_answers = 0.0;
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(input, "input", Path(), "Full pathname of mmlu.json.");
+  }
+};
 
-  std::ifstream fJson("/tmp/mmlu.json");
-  std::stringstream buffer;
-  buffer << fJson.rdbuf();
-  auto json = nlohmann::json::parse(buffer.str());
-
-  std::vector<std::string> accept_tokens = {"A", "B", "C", "D"};
-  std::set<int> accept_token_set{};
-  for (const std::string& accept_token : accept_tokens) {
-    std::vector<int> accept_token_ids;
-    HWY_ASSERT(model.Tokenizer().Encode(accept_token, &accept_token_ids));
-    accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
+// Linear search for a few tokens is faster than std::set.
+// TODO: instead of accepting for each vocab entry, filter the logits once.
+class TokenSet {
+ public:
+  TokenSet(const GemmaTokenizer& tokenizer,
+           const std::vector<std::string>& strings) {
+    all_tokens_.reserve(strings.size());
+    for (const std::string& str : strings) {
+      std::vector<int> tokens;
+      fprintf(stderr, "%s -> ", str.c_str());
+      HWY_ASSERT(tokenizer.Encode(str, &tokens));
+      for (int token : tokens) {
+        fprintf(stderr, "%d, ", token);
+        all_tokens_.push_back(token);
+      }
+      fprintf(stderr, "\n");
+    }
   }
 
-  for (auto sample : json["samples"]) {
-    int abs_pos = 0;  // absolute token index over all turns
-    int current_pos = 0;
-    int prompt_size{};
-
-    // cout << "prompt:" << sample["prompt"] << endl;
-    const std::string& prompt_string = sample["prompt"];
-    std::vector<int> prompt;
+  bool Contains(int token) const {
+    return std::find(all_tokens_.begin(), all_tokens_.end(), token) !=
+           all_tokens_.end();
+  }
+
+ private:
+  std::vector<int> all_tokens_;
+};
 
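TokenSet pairs naturally with the two-argument AcceptFunc, although the benchmark below deliberately leaves accept_token unset (see the risk comment there). An illustrative wiring, assuming a tokenizer and runtime_config in scope:

    const gcpp::TokenSet accept_set(tokenizer, {"A", "B", "C", "D"});
    runtime_config.accept_token = [&accept_set](int token, float /*prob*/) {
      return accept_set.Contains(token);
    };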
-    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
-    prompt_size = prompt.size();
+void Run(GemmaEnv& env, JsonArgs& json) {
+  PROFILER_ZONE("Run.all");
 
-    const std::string& correct_answer = accept_tokens[sample["input_label"]];
+  float answers = 0.0f;
+  float correct_answers = 0.0f;
 
-    // max_tokens = prompt_size + max_tokens;
+  auto json_data = nlohmann::json::parse(ReadFileToString(json.input));
 
+  const std::vector<std::string> accept_strings = {
+      "A", "B", "C", "D",      //
+      " A", " B", " C", " D",  //
+      "**", "**:", ":**", "The", "Answer", "is", ":", "."};
+  const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
+
+  for (auto sample : json_data["samples"]) {
+    const int id = sample["i"];
+    fprintf(stderr, "Processing question %d\n", id);
+    const std::string& correct_answer = accept_strings[sample["input_label"]];
+    std::string prompt_string = sample["prompt"];
+    // AcceptFunc restricts the output to one of these four tokens, so make an
+    // effort to steer the model towards that. See
+    // https://huggingface.co/blog/open-llm-leaderboard-mmlu
+    prompt_string +=
+        "What is start of the line with the correct answer? "
+        "Do not include any justifications or explanations. Reply only with a "
+        "letter.";
+    const std::vector<int> prompt =
+        WrapAndTokenize(env.GetModel()->Tokenizer(), env.ModelTrainingType(),
+                        /*pos=*/0, prompt_string);
+    const size_t prompt_size = prompt.size();
 
     std::vector<int> predicted_token_ids;
-    predicted_token_ids.reserve(max_tokens);
-    auto stream_token = [&current_pos, &prompt_size, &predicted_token_ids,
-                         &accept_token_set](int token, float proba) {
+    predicted_token_ids.reserve(4096);
+    size_t current_pos = 0;
+    const StreamFunc stream_token = [&current_pos, prompt_size,
+                                     &predicted_token_ids](int token,
+                                                           float proba) {
+      PROFILER_ZONE("Stream");
       ++current_pos;
       if (current_pos > prompt_size) {
         predicted_token_ids.push_back(token);
-
-        // If the generated token is in the accepted token set, return False.
-        // This will stop further generation.
-        return accept_token_set.find(token) == accept_token_set.end();
       }
 
       return true;
     };
 
-    const AcceptFunc accept_token = [&current_pos, &prompt_size,
-                                     &accept_token_set](int token) {
-      // i.e. we have no constraints on accepted tokens
-      if (accept_token_set.empty()) {
-        return true;
-      }
-
-      if (current_pos >= prompt_size) {
-        return accept_token_set.find(token) != accept_token_set.end();
-      } else {
-        // auto-accept early tokens
-        return true;
-      }
-    };
+    // Although " A" is a token, it is difficult to associate that with the
+    // correct answer. Only accepting certain tokens is risky: (A) is easily
+    // confused with the word "A".
 
     gcpp::TimingInfo timing_info;
     gcpp::RuntimeConfig runtime_config = {
-        .max_tokens = args.max_tokens,
-        .max_generated_tokens = args.max_generated_tokens,
-        .temperature = args.temperature,
-        .verbosity = verbosity,
-        .gen = &gen,
+        .max_tokens = env.MaxTokens(),
+        .max_generated_tokens = 30,
+        .temperature = 0.0f,
+        .verbosity = env.Verbosity(),
+        .gen = &env.MutableGen(),
         .stream_token = stream_token,
-        .accept_token = accept_token,
     };
-    model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
+    env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
+                             env.MutableKVCache(), timing_info);
 
-    std::string output_string;
-    HWY_ASSERT(model.Tokenizer().Decode(predicted_token_ids, &output_string));
-    std::cout << "QuestionId: " << sample["i"] << "; "
-              << "Predicted Answer: " << output_string << "; "
-              << "Correct Answer: " << correct_answer << std::endl;
+    std::string output_string = env.StringFromTokens(predicted_token_ids);
+    fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
+            output_string.c_str());
 
-    answers += 1.0;
+    answers += 1.0f;
     if (output_string == correct_answer) {
-      correct_answers += 1.0;
+      correct_answers += 1.0f;
    }
-    std::cout << "Running accuracy = " << "["
-              << static_cast<int>(correct_answers) << "/"
-              << static_cast<int>(answers) << "]" << " = "
-              << correct_answers / answers << std::endl;
+    fprintf(stderr, "%.0f/%.0f = %.2f%%\n", correct_answers, answers,
+            correct_answers / answers);
   }
 }
 
-void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
-  loader.Print(app.verbosity);
-  inference.Print(app.verbosity);
-  app.Print(app.verbosity);
-}
-
-void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
-  PROFILER_ZONE("Run.misc");
-
-  hwy::ThreadPool pool(app.num_threads);
-  // For many-core, pinning workers to cores helps.
-  if (app.num_threads > 10) {
-    PinWorkersToCores(pool);
-  }
-
-  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
-
-  JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);
-}
-
 }  // namespace gcpp
 
 int main(int argc, char** argv) {
   {
-    PROFILER_ZONE("Startup.misc");
-    // Placeholder for internal init, do not modify.
-    gcpp::LoaderArgs loader(argc, argv);
-    gcpp::InferenceArgs inference(argc, argv);
-    gcpp::AppArgs app(argc, argv);
-
-    if (const char* error = loader.Validate()) {
-      fprintf(stderr,
-              "\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
-              "specify 3 required model loading arguments: --tokenizer, "
-              "--compressed_weights, "
-              "and --model.\n\nModel Loading Arguments\n\n");
-
-      loader.Help();
-      fprintf(stderr, "\nInference Arguments\n\n");
-      inference.Help();
-      fprintf(stderr, "\nApplication Arguments\n\n");
-      app.Help();
-      fprintf(stderr, "\n\n");
-      HWY_ABORT("\nInvalid args: %s", error);
-    }
-
-    gcpp::Run(loader, inference, app);
+    PROFILER_ZONE("Startup.all");
+    gcpp::GemmaEnv env(argc, argv);
+    gcpp::JsonArgs json(argc, argv);
+    gcpp::AbortIfInvalidArgs(json);
+    gcpp::Run(env, json);
   }
   PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
   return 0;
@@ -197,6 +197,14 @@ static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) {
   return false;
 }
 
+template <class TArgs>
+static inline HWY_MAYBE_UNUSED void AbortIfInvalidArgs(TArgs& args) {
+  if (const char* err = args.Validate()) {
+    args.Help();
+    HWY_ABORT("Problem with args: %s\n", err);
+  }
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_
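AbortIfInvalidArgs works with any ArgsBase-derived struct that provides Validate() and Help(), as JsonArgs above does. A sketch of a new tool following the same pattern; the struct and its field are hypothetical:

    struct MyArgs : public gcpp::ArgsBase<MyArgs> {
      MyArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

      gcpp::Path output;  // hypothetical flag

      const char* Validate() const {
        return output.Empty() ? "Must specify --output" : nullptr;
      }

      template <class Visitor>
      void ForEach(const Visitor& visitor) {
        visitor(output, "output", gcpp::Path(), "Where to write results.");
      }
    };

    // In main: MyArgs args(argc, argv); gcpp::AbortIfInvalidArgs(args);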