diff --git a/BUILD.bazel b/BUILD.bazel index 598d92f..fc64dc5 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -113,12 +113,8 @@ cc_library( cc_library( name = "cross_entropy", - srcs = [ - "gemma/cross_entropy.cc", - ], - hdrs = [ - "gemma/cross_entropy.h", - ], + srcs = ["gemma/cross_entropy.cc"], + hdrs = ["gemma/cross_entropy.h"], deps = [ ":common", ":gemma_lib", @@ -149,6 +145,25 @@ cc_library( ], ) +cc_library( + name = "benchmark_helper", + srcs = ["gemma/benchmark_helper.cc"], + hdrs = ["gemma/benchmark_helper.h"], + deps = [ + ":app", + ":args", + ":common", + ":cross_entropy", + ":gemma_lib", + # Placeholder for internal dep, do not remove., + "@benchmark//:benchmark", + "//compression:compress", + "@hwy//:hwy", + "@hwy//:nanobenchmark", + "@hwy//:thread_pool", + ], +) + cc_test( name = "gemma_test", srcs = ["gemma/gemma_test.cc"], @@ -161,11 +176,11 @@ cc_test( deps = [ ":app", ":args", + ":benchmark_helper", ":common", ":cross_entropy", ":gemma_lib", ":ops", - # Placeholder for internal dep, do not remove., "@googletest//:gtest_main", "//compression:io", "@hwy//:hwy_test_util", @@ -179,6 +194,7 @@ cc_binary( deps = [ ":app", ":args", + ":benchmark_helper", ":common", ":gemma_lib", # Placeholder for internal dep, do not remove., @@ -213,10 +229,10 @@ cc_binary( deps = [ ":app", ":args", + ":benchmark_helper", ":common", ":cross_entropy", ":gemma_lib", - # Placeholder for internal dep, do not remove., "//compression:io", "@hwy//:hwy", "@hwy//:nanobenchmark", @@ -230,7 +246,6 @@ cc_binary( srcs = ["gemma/benchmarks.cc"], deps = [ ":benchmark_helper", - # Placeholder for internal dep, do not remove., "@benchmark//:benchmark", ], ) @@ -243,8 +258,8 @@ cc_binary( deps = [ ":app", ":args", + ":benchmark_helper", ":gemma_lib", - # Placeholder for internal dep, do not remove., "//compression:io", "@hwy//:hwy", "@hwy//:thread_pool", @@ -257,8 +272,10 @@ cc_binary( srcs = ["gemma/run_mmlu.cc"], deps = [ ":app", + ":args", + ":benchmark_helper", ":gemma_lib", - # Placeholder for internal dep, do not remove., + "//compression:io", "@hwy//:hwy", "@hwy//:profiler", "@hwy//:thread_pool", @@ -318,25 +335,6 @@ cc_library( ], ) -cc_library( - name = "benchmark_helper", - srcs = [ - "gemma/benchmark_helper.cc", - ], - hdrs = [ - "gemma/benchmark_helper.h", - ], - deps = [ - ":app", - ":common", - ":gemma_lib", - "@benchmark//:benchmark", - "@hwy//:hwy", - "@hwy//:nanobenchmark", - "@hwy//:thread_pool", - ], -) - cc_test( name = "backward_scalar_test", size = "large", diff --git a/compression/io.h b/compression/io.h index c5287b8..1d47143 100644 --- a/compression/io.h +++ b/compression/io.h @@ -23,6 +23,8 @@ #include #include // std::move +#include "hwy/base.h" + namespace gcpp { // Forward-declare to break the circular dependency: OpenFileOrNull returns @@ -77,12 +79,30 @@ struct Path { return path; } + bool Empty() const { return path.empty(); } + // Returns whether the file existed when this was called. 
  bool Exists() const { return !!OpenFileOrNull(*this, "r"); }

   std::string path;
 };

+static inline HWY_MAYBE_UNUSED std::string ReadFileToString(const Path& path) {
+  std::unique_ptr<File> file = OpenFileOrNull(path, "r");
+  if (!file) {
+    HWY_ABORT("Failed to open %s", path.path.c_str());
+  }
+  const size_t size = file->FileSize();
+  if (size == 0) {
+    HWY_ABORT("Empty file %s", path.path.c_str());
+  }
+  std::string content(size, ' ');
+  if (!file->Read(0, size, content.data())) {
+    HWY_ABORT("Failed to read %s", path.path.c_str());
+  }
+  return content;
+}
+
 }  // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
diff --git a/debug_prompt.cc b/debug_prompt.cc
index adb9d92..7bc33e5 100644
--- a/debug_prompt.cc
+++ b/debug_prompt.cc
@@ -1,146 +1,83 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.

-#include
 #include
 #include
-#include
 #include
-#include
 #include

-// Placeholder for internal header, do not modify.
 #include "compression/io.h"
-#include "gemma/gemma.h"
-#include "util/app.h"
+#include "gemma/benchmark_helper.h"
+#include "gemma/gemma.h"  // LayersOutputFunc
 #include "util/args.h"
 #include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "nlohmann/json.hpp"

 using json = nlohmann::json;

-class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
+namespace gcpp {
+
+class PromptArgs : public ArgsBase<PromptArgs> {
  public:
   PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

-  gcpp::Path layers_output;
+  Path layers_output;  // optional
   std::string prompt;

+  // Returns error string or nullptr if OK.
+  const char* Validate() const {
+    if (prompt.empty()) return "Must specify --prompt";
+    return nullptr;
+  }
+
   template <typename Visitor>
   void ForEach(const Visitor& visitor) {
-    visitor(layers_output.path, "layers_output", std::string(""),
+    visitor(layers_output, "layers_output", Path(""),
             "Path to store layers output", 2);
     visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
   }
 };

-std::pair<std::string, size_t> QueryModel(
-    gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
-    gcpp::LayersOutputT* layers_output) {
-  std::vector<int> prompt;
-  HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
-
-  // For both pre-trained and instruction-tuned models: prepend "<bos>" token
-  // if needed.
- prompt.insert(prompt.begin(), gcpp::BOS_ID); - std::string res; - size_t total_tokens = 0; - std::mt19937 gen; - gen.seed(42); - - auto stream_token = [&res, &total_tokens, &model](int token, float) { - ++total_tokens; - std::string token_text; - HWY_ASSERT(model.Tokenizer().Decode(std::vector{token}, &token_text)); - res += token_text; - return true; - }; - if (app.verbosity >= 2) { - std::cout << args.max_tokens << " " << args.max_generated_tokens << " " - << args.temperature; - } - gcpp::TimingInfo timing_info; - gcpp::RuntimeConfig runtime_config = { - .max_tokens = args.max_tokens, - .max_generated_tokens = args.max_generated_tokens, - .temperature = args.temperature, - .verbosity = app.verbosity, - .gen = &gen, - .stream_token = stream_token, - }; - model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info, - layers_output); - return {res, total_tokens}; -} - -class OutputJsonLogger { - public: - json json_output; - - gcpp::LayersOutputT layers_output_log_f = - [this](int pos, const std::string& key, const float* values, - size_t values_len) { - std::vector v{values, values + values_len}; - json_output[std::to_string(pos)][key] = v; - }; -}; - -/* Run this in the same way as gemma, p.ex.: - ./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \ - 2b-it-sfp.sbs --prompt "..." --layers_output [path] -*/ -int main(int argc, char** argv) { - { - // Placeholder for internal init, do not modify. - } - - gcpp::LoaderArgs loader(argc, argv); - gcpp::InferenceArgs args(argc, argv); // inference - gcpp::AppArgs app(argc, argv); +int Run(int argc, char** argv) { PromptArgs prompt_args(argc, argv); + AbortIfInvalidArgs(prompt_args); - if (const char* error = loader.Validate()) { - HWY_ABORT("\nInvalid loader args: %s", error); - } - if (const char* error = args.Validate()) { - HWY_ABORT("\nInvalid inference args: %s", error); - } - const bool log_layers_output = !prompt_args.layers_output.path.empty(); - OutputJsonLogger json_logger; - gcpp::LayersOutputT* layers_output = - log_layers_output ? &json_logger.layers_output_log_f : nullptr; + json json_output; + GemmaEnv env(argc, argv); + env.MutableConfig().layers_output = + prompt_args.layers_output.Empty() + ? LayersOutputFunc() + : [&json_output](int pos, const std::string& key, const float* values, + size_t values_len) { + std::vector v{values, values + values_len}; + json_output[std::to_string(pos)][key] = v; + }; - hwy::ThreadPool pool(app.num_threads); - // For many-core, pinning workers to cores helps. 
- if (app.num_threads > 10) { - gcpp::PinWorkersToCores(pool); - } + const auto [answer, token_count] = env.QueryModel(prompt_args.prompt); + std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush; - gcpp::Gemma model = gcpp::CreateGemma(loader, pool); - gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType()); - - const std::string& prompt = prompt_args.prompt; - if (prompt.empty()) { - std::cout << "Please specify --prompt" << std::endl; - return EXIT_FAILURE; - } - const auto [answer, token_count] = QueryModel( - model, args, app, kv_cache, pool, prompt, layers_output); - std::cout << answer.substr(prompt.size()) << "\n" << std::flush; - - if (log_layers_output) { + if (env.MutableConfig().layers_output) { std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out); - if (!output_f) { - std::cout << "Opening file failed" << std::endl; - return EXIT_FAILURE; - } - output_f << json_logger.json_output.dump(); - if (!output_f) { - std::cout << "Writing to file failed" << std::endl; - return EXIT_FAILURE; - } + if (!output_f) HWY_ABORT("Opening layer output file failed"); + output_f << json_output.dump(); + if (!output_f) HWY_ABORT("Writing to layer output file failed"); output_f.close(); } - - return EXIT_SUCCESS; + return 0; } + +} // namespace gcpp + +int main(int argc, char** argv) { return gcpp::Run(argc, argv); } diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc index e8d9539..af05142 100644 --- a/gemma/benchmark.cc +++ b/gemma/benchmark.cc @@ -1,36 +1,36 @@ +#include + #include #include // EXIT_FAILURE #include #include #include -#include -#include #include #include // std::pair #include -// Placeholder for internal header, do not modify. #include "compression/io.h" // Path +#include "gemma/benchmark_helper.h" +#include "gemma/common.h" #include "gemma/cross_entropy.h" #include "gemma/gemma.h" -#include "util/app.h" #include "util/args.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/highway.h" #include "hwy/timer.h" #include "nlohmann/json.hpp" +namespace gcpp { + using json = nlohmann::json; -class BenchmarkArgs : public gcpp::ArgsBase { +class BenchmarkArgs : public ArgsBase { public: BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - gcpp::Path goldens; - gcpp::Path summarize_text; - gcpp::Path cross_entropy; - gcpp::Path trivia_qa; + Path goldens; + Path summarize_text; + Path cross_entropy; + Path trivia_qa; size_t max_questions; size_t batch_tokens; @@ -53,61 +53,6 @@ class BenchmarkArgs : public gcpp::ArgsBase { } }; -void LogSpeedStats(const double time_start, size_t total_tokens) { - const double time_end = hwy::platform::Now(); - const double time_elapsed = time_end - time_start; - const double tok_sec = total_tokens / time_elapsed; - std::cout << total_tokens << " tokens in " << time_elapsed << " seconds" - << " [" << tok_sec << " tokens / sec" << "]\n"; -} - -std::pair QueryModel( - gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app, - gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) { - std::vector prompt; - HWY_ASSERT(model.Tokenizer().Encode(input, &prompt)); - - // For both pre-trained and instruction-tuned models: prepend "" token - // if needed. 
- prompt.insert(prompt.begin(), 2); - std::string res; - size_t total_tokens = 0; - std::mt19937 gen; - gen.seed(42); - - const double time_start = hwy::platform::Now(); - auto stream_token = [&res, &total_tokens, &time_start, &app, &model]( - int token, float) { - ++total_tokens; - std::string token_text; - HWY_ASSERT(model.Tokenizer().Decode(std::vector{token}, &token_text)); - res += token_text; - if (app.verbosity >= 1 && total_tokens % 100 == 0) { - LogSpeedStats(time_start, total_tokens); - } - return true; - }; - if (app.verbosity >= 2) { - std::cout << args.max_tokens << " " << args.max_generated_tokens << " " - << args.temperature; - } - gcpp::TimingInfo timing_info; - gcpp::RuntimeConfig runtime_config = { - .max_tokens = args.max_tokens, - .max_generated_tokens = args.max_generated_tokens, - .temperature = args.temperature, - .verbosity = app.verbosity, - .gen = &gen, - .stream_token = stream_token, - }; - model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info, - /*layers_output=*/nullptr); - if (app.verbosity >= 1) { - LogSpeedStats(time_start, total_tokens); - } - return {res, total_tokens}; -} - std::vector> load_goldens( const std::string& path) { std::ifstream goldens_file(path); @@ -129,28 +74,14 @@ std::vector> load_goldens( return res; } -std::string ReadFile(const gcpp::Path& path) { - std::ifstream text_file(path.path); - if (!text_file) { - std::cout << "Could not open file: " << path.path << "\n" << std::flush; - return {}; - } - std::stringstream buffer; - buffer << text_file.rdbuf(); - return buffer.str(); -} - -int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args, - gcpp::AppArgs& app, gcpp::KVCache& kv_cache, - hwy::ThreadPool& pool, const std::string& golden_path) { - const std::vector> queries_answers = +int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) { + std::vector> queries_answers = load_goldens(golden_path); - int correct_answers = 0; - int total_tokens = 0; + size_t correct_answers = 0; + size_t total_tokens = 0; const double time_start = hwy::platform::Now(); - for (const auto& [question, expected_answer] : queries_answers) { - const auto [answer, token_count] = - QueryModel(model, args, app, kv_cache, pool, question); + for (auto& [question, expected_answer] : queries_answers) { + const auto [answer, token_count] = env.QueryModel(question); total_tokens += token_count; if (answer.find(expected_answer) != std::string::npos) { correct_answers++; @@ -172,28 +103,22 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args, return EXIT_SUCCESS; } -int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args, - gcpp::AppArgs& app, gcpp::KVCache& kv_cache, - hwy::ThreadPool& pool, const gcpp::Path& text) { +int BenchmarkSummary(GemmaEnv& env, const Path& text) { std::string prompt("Here is some text to summarize:\n"); - prompt.append(ReadFile(text)); + prompt.append(ReadFileToString(text)); prompt.append("\nSummarize this text.\n"); const double time_start = hwy::platform::Now(); - const auto [answer, token_count] = - QueryModel(model, args, app, kv_cache, pool, prompt); + const auto [answer, token_count] = env.QueryModel(prompt); std::cout << answer.substr(prompt.size()) << "\n" << std::flush; LogSpeedStats(time_start, token_count); return EXIT_SUCCESS; } -int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type, - gcpp::InferenceArgs& args, gcpp::AppArgs& app, - hwy::ThreadPool& pool, const gcpp::Path& text, +int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text, 
size_t batch_tokens) { - std::string input = ReadFile(text); - std::vector prompt; - HWY_ASSERT(model.Tokenizer().Encode(input, &prompt)); - prompt.resize(std::min(args.max_tokens, prompt.size())); + std::string input = ReadFileToString(text); + std::vector prompt = env.Tokenize(input); + prompt.resize(std::min(env.MaxTokens(), prompt.size())); std::cout << "Number of input tokens: " << prompt.size() << "\n"; const double time_start = hwy::platform::Now(); float total_entropy = 0.0f; @@ -203,13 +128,12 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type, size_t num_tokens = std::min(prompt.size() - pos, batch_tokens); std::vector prompt_slice(prompt.begin() + pos, prompt.begin() + pos + num_tokens); - gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type); - float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice, - kv_cache, app.verbosity); + KVCache kv_cache = KVCache::Create(env.ModelType()); + float entropy = ComputeCrossEntropy( + *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); total_entropy += entropy; LogSpeedStats(time_start, pos + num_tokens); - std::string text_slice; - HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice)); + std::string text_slice = env.StringFromTokens(prompt_slice); total_input_len += text_slice.size(); printf("Total cross entropy: %f [cumulative: %f]\n", entropy, total_entropy); @@ -219,23 +143,19 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type, return EXIT_SUCCESS; } -int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args, - gcpp::AppArgs& app, gcpp::KVCache& kv_cache, - hwy::ThreadPool& pool, const gcpp::Path& json_file, +int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file, size_t max_questions) { std::ifstream trivia_file(json_file.path); if (!trivia_file) { - std::cout << "Could not load file: " << json_file.path << "\n" - << std::flush; - return EXIT_FAILURE; + HWY_ABORT("Could not load file %s\n", json_file.path.c_str()); } std::string line; size_t correct_answers = 0; size_t i = 0; while (std::getline(trivia_file, line)) { json data = json::parse(line); - const auto [answer, token_count] = QueryModel( - model, args, app, kv_cache, pool, data["question"]); + std::string q(data["question"]); + const auto [answer, token_count] = env.QueryModel(q); std::cout << answer << "\n"; bool correct = false; for (const std::string expected : data["answer"]["aliases"]) { @@ -256,52 +176,25 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args, return EXIT_SUCCESS; } -/* Run this in the same way as gemma, p.ex.: - ./benchmark --tokenizer tokenizer.spm --model 2b-it --weights \ - 2b-it-sfp.sbs --goldens_dir "../goldens" -*/ +} // namespace gcpp + int main(int argc, char** argv) { - { - // Placeholder for internal init, do not modify. - } + gcpp::GemmaEnv env(argc, argv); + gcpp::BenchmarkArgs benchmark_args(argc, argv); - gcpp::LoaderArgs loader(argc, argv); - gcpp::InferenceArgs args(argc, argv); // inference - gcpp::AppArgs app(argc, argv); - BenchmarkArgs benchmark_args(argc, argv); - - if (const char* error = loader.Validate()) { - HWY_ABORT("\nInvalid loader args: %s", error); - } - if (const char* error = args.Validate()) { - HWY_ABORT("\nInvalid inference args: %s", error); - } - - hwy::ThreadPool pool(app.num_threads); - // For many-core, pinning workers to cores helps. 
- if (app.num_threads > 10) { - gcpp::PinWorkersToCores(pool); - } - - gcpp::Gemma model = gcpp::CreateGemma(loader, pool); - gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType()); - - if (!benchmark_args.goldens.path.empty()) { + if (!benchmark_args.goldens.Empty()) { const std::string golden_path = - benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt"; - return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path); - } else if (!benchmark_args.summarize_text.path.empty()) { - return BenchmarkSummary(model, args, app, kv_cache, pool, - benchmark_args.summarize_text); - } else if (!benchmark_args.cross_entropy.path.empty()) { - return BenchmarkCrossEntropy(model, loader.ModelType(), args, app, - pool, benchmark_args.cross_entropy, + benchmark_args.goldens.path + "/" + + gcpp::ModelString(env.ModelType(), env.ModelTrainingType()) + ".txt"; + return BenchmarkGoldens(env, golden_path); + } else if (!benchmark_args.summarize_text.Empty()) { + return BenchmarkSummary(env, benchmark_args.summarize_text); + } else if (!benchmark_args.cross_entropy.Empty()) { + return BenchmarkCrossEntropy(env, benchmark_args.cross_entropy, benchmark_args.batch_tokens); - } else if (!benchmark_args.trivia_qa.path.empty()) { - return BenchmarkTriviaQA(model, args, app, kv_cache, pool, - benchmark_args.trivia_qa, + } else if (!benchmark_args.trivia_qa.Empty()) { + return BenchmarkTriviaQA(env, benchmark_args.trivia_qa, benchmark_args.max_questions); } - std::cout << "No benchmark command given." << "\n" << std::flush; - return EXIT_FAILURE; + HWY_ABORT("No benchmark command given."); } diff --git a/gemma/benchmark_helper.cc b/gemma/benchmark_helper.cc index d48914b..34a2d26 100644 --- a/gemma/benchmark_helper.cc +++ b/gemma/benchmark_helper.cc @@ -14,70 +14,95 @@ // limitations under the License. #include "gemma/benchmark_helper.h" -#include // EXIT_FAILURE + +#include +#include + #include #include #include #include #include +#include // NOLINT #include // std::pair #include -#include "gemma/common.h" +// Placeholder for internal header, do not modify. +#include "compression/compress.h" // TypeName +#include "gemma/common.h" // StringFromType +#include "gemma/cross_entropy.h" #include "gemma/gemma.h" #include "util/app.h" +#include "util/args.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" +#include "hwy/per_target.h" #include "hwy/timer.h" namespace gcpp { - GemmaEnv::GemmaEnv(int argc, char** argv) - : loader_(argc, argv), inference_args_(argc, argv), app_(argc, argv), - pool_(app_.num_threads) { - if (const char* error = loader_.Validate()) { - HWY_ABORT("\nInvalid loader args: %s", error); - } - if (const char* error = inference_args_.Validate()) { - HWY_ABORT("\nInvalid inference args: %s", error); - } - // For many-core, pinning workers to cores helps. - if (app_.num_threads > 10) { - gcpp::PinWorkersToCores(pool_); - } + +void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) { + if (inference.deterministic) { + // Nothing up my sleeve number, at least some upper bits set. + gen.seed(0x12345678); + } else { + // Depending on the library implementation, this may still be deterministic. + std::random_device rd; + gen.seed(rd()); + } +} + +GemmaEnv::GemmaEnv(int argc, char** argv) + : loader_(argc, argv), + inference_args_(argc, argv), + app_(argc, argv), + pool_(app_.num_threads) { + { + // Placeholder for internal init, do not modify. + } + + // For many-core, pinning workers to cores helps. 
+ if (app_.num_threads > 10) { + gcpp::PinWorkersToCores(pool_); + } + + AbortIfInvalidArgs(inference_args_); + + if (const char* err = loader_.Validate()) { + loader_.Help(); + fprintf(stderr, "Skipping model load because: %s\n", err); + } else { + fprintf(stderr, "Loading model...\n"); model_ = AllocateGemma(loader_, pool_); kv_cache_ = KVCache::Create(loader_.ModelType()); - gen_.seed(42); } -std::pair GemmaEnv::QueryModel(const std::string& input) { - std::string prompt_string = input; - if (loader_.ModelTrainingType() == ModelTraining::GEMMA_IT) { - // For instruction-tuned models: add control tokens. - prompt_string = "user\n" + input + - "\nmodel\n"; - } - std::vector prompt; - HWY_ASSERT(model_->Tokenizer().Encode(input, &prompt)); + InitGenerator(inference_args_, gen_); - // For both pre-trained and instruction-tuned models: prepend "" token - // if needed. - prompt.insert(prompt.begin(), gcpp::BOS_ID); + runtime_config_ = { + .max_tokens = inference_args_.max_tokens, + .max_generated_tokens = inference_args_.max_generated_tokens, + .temperature = inference_args_.temperature, + .verbosity = app_.verbosity, + .gen = &gen_, + }; +} + +std::pair GemmaEnv::QueryModel( + const std::vector& tokens) { std::string res; size_t total_tokens = 0; - auto accept_token = [](int) { return true; }; - std::mt19937 gen; - gen.seed(42); const double time_start = hwy::platform::Now(); - auto stream_token = [&res, &total_tokens, &time_start, this]( - int token, float) { + const StreamFunc stream_token = [&res, &total_tokens, &time_start, this]( + int token, float) { ++total_tokens; std::string token_text; - HWY_ASSERT(model_->Tokenizer().Decode(std::vector{token}, - &token_text)); + HWY_ASSERT( + model_->Tokenizer().Decode(std::vector{token}, &token_text)); res += token_text; - if (app_.verbosity >= 1 && total_tokens % 100 == 0) { + if (app_.verbosity >= 1 && total_tokens % 128 == 0) { LogSpeedStats(time_start, total_tokens); } return true; @@ -88,24 +113,32 @@ std::pair GemmaEnv::QueryModel(const std::string& input) { << inference_args_.temperature; } gcpp::TimingInfo timing_info; - gcpp::RuntimeConfig runtime_config = { - .max_tokens = inference_args_.max_tokens, - .max_generated_tokens = inference_args_.max_generated_tokens, - .temperature = inference_args_.temperature, - .verbosity = app_.verbosity, - .gen = &gen, - .stream_token = stream_token, - .accept_token = accept_token, - }; - model_->Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache_, - timing_info, /*layers_output=*/nullptr); + runtime_config_.stream_token = stream_token; + model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_cache_, + timing_info); if (app_.verbosity >= 1) { LogSpeedStats(time_start, total_tokens); } return {res, total_tokens}; } -void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const { +std::pair GemmaEnv::QueryModel(std::string& input) { + const std::vector prompt = + WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(), + /*pos=*/0, input); + return QueryModel(prompt); +} + +float GemmaEnv::CrossEntropy(const std::string& input) { + std::vector prompt = Tokenize(input); + prompt.insert(prompt.begin(), BOS_ID); + return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt, + MutableKVCache(), + /*verbosity=*/0) / + static_cast(input.size()); +} + +void LogSpeedStats(double time_start, size_t total_tokens) { const double time_end = hwy::platform::Now(); const double time_elapsed = time_end - time_start; const double tok_sec = total_tokens / time_elapsed; @@ 
-113,6 +146,53 @@ void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const { << " [" << tok_sec << " tokens / sec" << "]\n"; } +void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { + loader.Print(app.verbosity); + inference.Print(app.verbosity); + app.Print(app.verbosity); + + if (app.verbosity >= 2) { + time_t now = time(nullptr); + char* dt = ctime(&now); // NOLINT + std::cout << "Date & Time : " << dt + << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize + << "\n" + << "Hardware concurrency : " + << std::thread::hardware_concurrency() << "\n" + << "Instruction set : " + << hwy::TargetName(hwy::DispatchedTarget()) << " (" + << hwy::VectorBytes() * 8 << " bits)" << "\n"; + char cpu100[100]; + if (hwy::platform::GetCpuString(cpu100)) { + std::cout << "CPU : " << cpu100 << "\n"; + } + std::cout << "Compiled config : " << CompiledConfig() << "\n" + << "Weight Type : " + << gcpp::StringFromType(loader.WeightType()) << "\n" + << "EmbedderInput Type : " + << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n"; + } +} + +void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, + gcpp::AppArgs& app) { + std::cerr + << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" + "==========================================================\n\n" + "To run gemma.cpp, you need to " + "specify 3 required model loading arguments:\n" + " --tokenizer\n" + " --weights\n" + " --model.\n"; + std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " + "--weights 2b-it-sfp.sbs --model 2b-it\n"; + std::cerr << "\n*Model Loading Arguments*\n\n"; + loader.Help(); + std::cerr << "\n*Inference Arguments*\n\n"; + inference.Help(); + std::cerr << "\n*Application Arguments*\n\n"; + app.Help(); + std::cerr << "\n"; +} } // namespace gcpp - diff --git a/gemma/benchmark_helper.h b/gemma/benchmark_helper.h index 1feac4a..ac6eef6 100644 --- a/gemma/benchmark_helper.h +++ b/gemma/benchmark_helper.h @@ -16,11 +16,15 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_ +#include + #include #include #include #include +#include +#include "gemma/common.h" #include "gemma/gemma.h" #include "util/app.h" #include "hwy/base.h" @@ -28,24 +32,60 @@ namespace gcpp { +void InitGenerator(const InferenceArgs& inference, std::mt19937& gen); + // Convenience class to load a model and run inference. class GemmaEnv { public: GemmaEnv(int argc, char** argv); + size_t MaxTokens() const { return inference_args_.max_tokens; } // Sets the maximum number of output tokens to generate. - void set_max_generated_tokens(int max_tokens) { + void SetMaxGeneratedTokens(size_t max_tokens) { inference_args_.max_generated_tokens = max_tokens; } + std::vector Tokenize(const std::string& input) const { + std::vector tokens; + HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens)); + return tokens; + } + + std::vector TokenizeAndPrependBOS(const std::string& input) const { + std::vector tokens = Tokenize(input); + tokens.insert(tokens.begin(), BOS_ID); + return tokens; + } + + std::string StringFromTokens(const std::vector& tokens) const { + std::string string; + HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string)); + return string; + } + // Runs inference on the given input and returns the top-1 result string and // the number of tokens that were generated. 
-  std::pair<std::string, size_t> QueryModel(const std::string& input);
+  std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
+  // Adds turn structure to input, tokenizes and calls the above overload.
+  std::pair<std::string, size_t> QueryModel(std::string& input);
+
+  // Runs inference on the given input and returns the cross entropy, a measure
+  // of how well the model predicts the correct output. It is the average
+  // number of bits per token.
+  float CrossEntropy(const std::string& input);
+
+  // Returns nullptr if the model failed to load.
+  Gemma* GetModel() const { return model_.get(); }
+  Model ModelType() const { return loader_.ModelType(); }
+  ModelTraining ModelTrainingType() const {
+    return loader_.ModelTrainingType();
+  }
+  int Verbosity() const { return app_.verbosity; }
+  gcpp::RuntimeConfig& MutableConfig() { return runtime_config_; }
+  std::mt19937& MutableGen() { return gen_; }
+  KVCache& MutableKVCache() { return kv_cache_; }

  private:
-  // Logs the inference speed in tokens/sec.
-  void LogSpeedStats(double time_start, size_t total_tokens) const;
-
   // Arguments to the model loader: file locations, etc.
   LoaderArgs loader_;
   // Arguments to the inference function: max tokens, etc.
@@ -60,10 +100,16 @@ class GemmaEnv {
   std::unique_ptr<Gemma> model_;
   // The KV cache to use for inference.
   KVCache kv_cache_;
+  gcpp::RuntimeConfig runtime_config_;
 };

+// Logs the inference speed in tokens/sec.
+void LogSpeedStats(double time_start, size_t total_tokens);
+
+void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
+void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
+              gcpp::AppArgs& app);
+
 }  // namespace gcpp
-
-
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
diff --git a/gemma/benchmarks.cc b/gemma/benchmarks.cc
index 630d5c8..283500c 100644
--- a/gemma/benchmarks.cc
+++ b/gemma/benchmarks.cc
@@ -13,108 +13,81 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include
-#include
-#include
-#include
-#include
+#include
+#include
+
 #include
-// Placeholder for internal header, do not modify.
 #include "benchmark/benchmark.h"
 #include "gemma/benchmark_helper.h"

-void run_gemma_prompt(const std::string& prompt_string,
-                      gcpp::GemmaEnv& env,
-                      benchmark::State& state) {
-  std::mt19937 gen;
+namespace gcpp {

-  if (prompt_string.empty()) return;
+// Shared state for benchmarks - unfortunately the library does not allow
+// passing context nor closures. Raw pointer because style guide forbids
+// non-local static objects with dtors.
+GemmaEnv* s_env = nullptr;

-  int token_counter = 0;
+void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
+  size_t total_tokens = 0;
   for (auto s : state) {
-    auto [response, n] = env.QueryModel(prompt_string);
-    std::cout << "response: " << response << "\n";
-    std::cout << "n: " << n << "\n";
-    token_counter += n;
+    std::string prompt = original_prompt;  // reset from original
+    auto [response, n] = s_env->QueryModel(prompt);
+    if (s_env->Verbosity() != 0) {
+      fprintf(stdout, "|%s|\n", response.c_str());
+    }
+    total_tokens += n;
   }
-  state.SetItemsProcessed(token_counter);
+  state.SetItemsProcessed(total_tokens);
 }

-// Awkward global because benchmarks don't support additional state, so it is
-// either this or cast to int64_t.
-gcpp::GemmaEnv* global_env = nullptr; +} // namespace gcpp static void BM_short_prompt(benchmark::State& state) { - run_gemma_prompt("What is the capital of Spain?", *global_env, - state); + gcpp::RunPrompt("What is the capital of Spain?", state); } static void BM_factuality_prompt(benchmark::State& state) { - run_gemma_prompt("How does an inkjet printer work?", - *global_env, state); + gcpp::RunPrompt("How does an inkjet printer work?", state); } static void BM_creative_prompt(benchmark::State& state) { - run_gemma_prompt( - "Tell me a story about a magical bunny and their TRS-80.", - *global_env, state); + gcpp::RunPrompt("Tell me a story about a magical bunny and their TRS-80.", + state); } static void BM_coding_prompt(benchmark::State& state) { - run_gemma_prompt( - "Write a python program to generate a fibonacci sequence.", - *global_env, state); + gcpp::RunPrompt("Write a python program to generate a fibonacci sequence.", + state); } -static void BM_long_coding_prompt(benchmark::State& state) { - std::ifstream t("benchmarks.cc", std::ios_base::in); - std::stringstream buffer; - buffer << t.rdbuf(); - std::string prompt_string = buffer.str(); - t.close(); +BENCHMARK(BM_short_prompt) + ->Iterations(3) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); - run_gemma_prompt("Make improvements to the following code:\n " + - prompt_string, *global_env, state); -} +BENCHMARK(BM_factuality_prompt) + ->Iterations(3) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +BENCHMARK(BM_creative_prompt) + ->Iterations(3) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +BENCHMARK(BM_coding_prompt) + ->Iterations(3) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); int main(int argc, char** argv) { - { - // Placeholder for internal init, do not modify. 
- } gcpp::GemmaEnv env(argc, argv); + env.SetMaxGeneratedTokens(256); + gcpp::s_env = &env; - env.set_max_generated_tokens(128); - global_env = &env; - BENCHMARK(BM_short_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - env.set_max_generated_tokens(256); - BENCHMARK(BM_factuality_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - BENCHMARK(BM_creative_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - BENCHMARK(BM_coding_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - env.set_max_generated_tokens(1024); - BENCHMARK(BM_long_coding_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - ::benchmark ::RunSpecifiedBenchmarks(); - ::benchmark ::Shutdown(); + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); return 0; } diff --git a/gemma/common.cc b/gemma/common.cc index 45f98d7..7fff656 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -28,20 +28,18 @@ namespace gcpp { -namespace { -constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it", - "7b-it", "gr2b-it", "tiny"}; -constexpr Model kModelTypes[] = { - Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B, - Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY}; -constexpr ModelTraining kModelTraining[] = { - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, - ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, - ModelTraining::GEMMA_IT}; -} // namespace - const char* ParseModelTypeAndTraining(const std::string& model_flag, Model& model, ModelTraining& training) { + constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it", + "7b-it", "gr2b-it", "tiny"}; + constexpr Model kModelTypes[] = { + Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B, + Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY}; + constexpr ModelTraining kModelTraining[] = { + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, + ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, + ModelTraining::GEMMA_IT}; + constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags); static char kErrorMessageBuffer[kNum * 8 + 1024] = "Invalid or missing model flag, need to specify one of "; @@ -51,37 +49,58 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag, } strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT strcat(kErrorMessageBuffer, "."); // NOLINT + std::string model_type_lc = model_flag; std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc), [](unsigned char c) { return std::tolower(c); }); + for (size_t i = 0; i < kNum; i++) { if (kModelFlags[i] == model_type_lc) { model = kModelTypes[i]; training = kModelTraining[i]; + HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc); return nullptr; } } return kErrorMessageBuffer; } +const char* ModelString(Model model, ModelTraining training) { + if (model == Model::GEMMA_TINY) return "tiny"; + static_assert(static_cast(ModelTraining::GEMMA_IT) == 0); + constexpr const char* k2B[] = {"2b-it", "2b-pt"}; + constexpr const char* k7B[] = {"7b-it", "7b-pt"}; + constexpr const char* kGr2B[] = {"gr2b-it", "gr2b-pt"}; + if (model == Model::GEMMA_2B) return k2B[static_cast(training)]; + if (model == Model::GEMMA_7B) return k7B[static_cast(training)]; + if (model == Model::GRIFFIN_2B) return kGr2B[static_cast(training)]; + 
HWY_ABORT("Unknown model %d training %d\n", static_cast(model), + static_cast(training)); +} + +constexpr const char* kTypeStrings[] = {"f32", "bf16", "sfp"}; + +const char* StringFromType(Type type) { + return kTypeStrings[static_cast(type)]; +} + const char* ParseType(const std::string& type_string, Type& type) { - constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP}; - constexpr const char* kStrings[] = {"f32", "bf16", "sfp"}; - constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings); + constexpr size_t kNum = std::end(kTypeStrings) - std::begin(kTypeStrings); static char kErrorMessageBuffer[kNum * 8 + 100] = "Invalid or missing type, need to specify one of "; for (size_t i = 0; i + 1 < kNum; i++) { - strcat(kErrorMessageBuffer, kStrings[i]); // NOLINT + strcat(kErrorMessageBuffer, kTypeStrings[i]); // NOLINT strcat(kErrorMessageBuffer, ", "); // NOLINT } - strcat(kErrorMessageBuffer, kStrings[kNum - 1]); // NOLINT + strcat(kErrorMessageBuffer, kTypeStrings[kNum - 1]); // NOLINT strcat(kErrorMessageBuffer, "."); // NOLINT std::string type_lc = type_string; std::transform(begin(type_lc), end(type_lc), begin(type_lc), [](unsigned char c) { return std::tolower(c); }); for (size_t i = 0; i < kNum; i++) { - if (kStrings[i] == type_lc) { - type = kTypes[i]; + if (kTypeStrings[i] == type_lc) { + type = static_cast(i); + HWY_ASSERT(std::string(StringFromType(type)) == type_lc); return nullptr; } } diff --git a/gemma/common.h b/gemma/common.h index f234786..e9d86f6 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -154,18 +154,12 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag, Model& model, ModelTraining& training); const char* ParseType(const std::string& type_string, Type& type); -static inline const char* StringFromType(Type type) { - switch (type) { - case Type::kF32: - return "f32"; - case Type::kBF16: - return "bf16"; - case Type::kSFP: - return "sfp"; - default: - return "?"; - } -} +// Inverse of ParseModelTypeAndTraining. +const char* ModelString(Model model, ModelTraining training); +const char* StringFromType(Type type); + +// ---------------------------------------------------------------------------- +// // __builtin_sqrt is not constexpr as of Clang 17. #if HWY_COMPILER_GCC_ACTUAL diff --git a/gemma/cross_entropy.cc b/gemma/cross_entropy.cc index 9d345e2..f5ae88a 100644 --- a/gemma/cross_entropy.cc +++ b/gemma/cross_entropy.cc @@ -111,7 +111,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, }; TimingInfo timing_info; - gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr); + gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info); const float scale = 1.0f / std::log(2.0f); return cross_entropy * scale; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index d0ed6a0..5de541b 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -17,7 +17,6 @@ // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. -#include #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT #include "hwy/foreach_target.h" // IWYU pragma: keep @@ -208,8 +207,8 @@ bool GemmaTokenizer::Encode(const std::string& input, } bool GemmaTokenizer::Encode(const std::string& input, - std::vector* pieces) const { - return impl_->Encode(input, pieces); + std::vector* ids) const { + return impl_->Encode(input, ids); } // Given a sequence of ids, decodes it into a detokenized output. 
@@ -649,16 +648,16 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, // Compute the transformer for a batch of input tokens. During generation, // we usually have num_tokens == 1 (and also kBatchSize == 1). template -HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos, +HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos, const WeightArrayT& weights, Activations& activations, KVCache& kv_cache, hwy::ThreadPool& pool, - LayersOutputT* layers_output) { + const LayersOutputFunc& layers_output) { HWY_ASSERT(num_tokens <= kBatchSize); - if (layers_output != nullptr) { + if (layers_output) { for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { float token_f = tokens[token_idx]; - (*layers_output)(pos + token_idx, "Tokens", &token_f, 1); + layers_output(pos + token_idx, "Tokens", &token_f, 1); } } static constexpr size_t kModelDim = TConfig::kModelDim; @@ -713,12 +712,11 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos, } AddFromBatched(num_tokens, activations.ffw_out.data(), activations.x.data(), kModelDim); - if (layers_output != nullptr) { + if (layers_output) { std::string block_name = "blocks." + std::to_string(layer); for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - (*layers_output)(pos + token_idx, block_name, - activations.x.data() + token_idx * kModelDim, - kModelDim); + layers_output(pos + token_idx, block_name, + activations.x.data() + token_idx * kModelDim, kModelDim); } } } @@ -727,10 +725,10 @@ HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos, RMSNormInplaceBatched(num_tokens, weights.final_norm_scale.data(), activations.x.data(), kModelDim); - if (layers_output != nullptr) { + if (layers_output) { for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - (*layers_output)(pos + token_idx, "final_norm", - activations.x.data() + token_idx * kModelDim, kModelDim); + layers_output(pos + token_idx, "final_norm", + activations.x.data() + token_idx * kModelDim, kModelDim); } } } @@ -782,8 +780,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, const RuntimeConfig& runtime_config, const std::vector& prompt, size_t pos, KVCache& kv_cache, - hwy::ThreadPool& pool, TimingInfo& timing_info, - LayersOutputT* layers_output) { + hwy::ThreadPool& pool, TimingInfo& timing_info) { const CompressedWeights& weights = GetWeights(weights_u8); auto& prefill_activations = GetActivations(prefill_u8); @@ -860,7 +857,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, pos < max_tokens && generate_pos < max_generated_tokens; ++pos, ++pos_offset, ++generate_pos) { Transformer(&token, kDecodeBatchSize, pos, weights, - activations, kv_cache, pool, layers_output); + activations, kv_cache, pool, + runtime_config.layers_output); float token_logit = 0.0f; // The condition below is always true if we are doing Prefill above. 
      // We keep it here for clarity so that the code is correct even if Prefill
@@ -953,17 +951,37 @@ Gemma::~Gemma() {

 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const std::vector<int>& prompt, size_t start_pos,
-                     KVCache& kv_cache, TimingInfo& timing_info,
-                     LayersOutputT* layers_output) {
+                     KVCache& kv_cache, TimingInfo& timing_info) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);

   GEMMA_EXPORT_AND_DISPATCH(
       model_type_, weight_type_, GenerateT,
       (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
-       kv_cache, pool_, timing_info, layers_output));
+       kv_cache, pool_, timing_info));
   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }

+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const ModelTraining training, size_t pos,
+                                 std::string& prompt) {
+  // Instruction-tuned models are trained to expect control tokens.
+  if (training == ModelTraining::GEMMA_IT) {
+    // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
+    const std::string start = (pos == 0)
+                                  ? "<start_of_turn>user\n"
+                                  : "<end_of_turn>\n<start_of_turn>user\n";
+    prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
+  }
+
+  std::vector<int> tokens;
+  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
+  // Both pre-trained and instruction-tuned require BOS as first token.
+  if (pos == 0) {
+    tokens.insert(tokens.begin(), gcpp::BOS_ID);
+  }
+  return tokens;
+}
+
 }  // namespace gcpp
 #endif  // HWY_ONCE
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 49fe40d..9a91870 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -74,10 +74,17 @@ class GemmaTokenizer {
 using StreamFunc = std::function<bool(int, float)>;
 // If not empty, AcceptFunc is called with token. It should return false for
 // tokens you don't want to generate and true for tokens you want to generate.
-using AcceptFunc = std::function<bool(int)>;
+using AcceptFunc = std::function<bool(int, float)>;
 // If not empty, SampleFunc is called with the probability distribution for the
 // next token, and its return value is used as the next generated token.
 using SampleFunc = std::function<int(const float*, size_t)>;
+// Will be called for layers output with:
+// - position in the tokens sequence
+// - name of the data, e.g. "tokens", "block.1", "final_norm"
+// - pointer to the data array
+// - size of the data array
+using LayersOutputFunc =
+    std::function<void(int, const std::string&, const float*, size_t)>;

 struct RuntimeConfig {
   size_t max_tokens;
@@ -88,6 +95,7 @@ struct RuntimeConfig {
   StreamFunc stream_token;
   AcceptFunc accept_token;  // if empty, accepts all tokens.
   SampleFunc sample_func;   // if empty, uses SampleTopK.
+  LayersOutputFunc layers_output;  // if not empty, called after each layer.
   int eos_id = EOS_ID;
 };

@@ -97,14 +105,6 @@ struct TimingInfo {
   double time_to_first_token = 0.0;
 };

-// Will be called for layers output with:
-// - position in the tokens sequence
-// - name of the data, e.g. "tokens", "block.1", "final_norm"
-// - pointer to the data array
-// - size of the data array
-using LayersOutputT =
-    std::function<void(int, const std::string&, const float*, size_t)>;
-
 class Gemma {
  public:
   Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@@ -121,12 +121,9 @@ class Gemma {
   const ByteStorageT& Prefill() const { return prefill_u8_; }
   const ByteStorageT& Decode() const { return decode_u8_; }
-  // layers_output is optional; if set - it will be called with the activations
-  // output after applying each layer.
void Generate(const RuntimeConfig& runtime_config, const std::vector& prompt, size_t start_pos, - KVCache& kv_cache, TimingInfo& timing_info, - LayersOutputT* layers_output = nullptr); + KVCache& kv_cache, TimingInfo& timing_info); private: hwy::ThreadPool& pool_; @@ -141,14 +138,19 @@ class Gemma { Type weight_type_; }; +// Adds BOS token and possibly 'turn' annotations, which depend on `training` +// and `pos`, the number of tokens decoded so far; returns the corresponding +// tokens. Asserts that tokenization is successful. +std::vector WrapAndTokenize(const GemmaTokenizer& tokenizer, + ModelTraining training, size_t pos, + std::string& prompt); + // DEPRECATED, call Gemma::Generate directly. HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& /*pool*/, - TimingInfo& timing_info, - LayersOutputT* layers_output) { - gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info, - layers_output); + TimingInfo& timing_info) { + gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info); } } // namespace gcpp diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc index 8817a0c..b885195 100644 --- a/gemma/gemma_test.cc +++ b/gemma/gemma_test.cc @@ -17,89 +17,35 @@ #include -#include -#include #include #include -// Placeholder for internal header, do not modify. +#include "gemma/benchmark_helper.h" #include "gemma/common.h" -#include "gemma/cross_entropy.h" -#include "gemma/ops.h" -#include "util/app.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/tests/test_util-inl.h" +#include "hwy/tests/hwy_gtest.h" namespace gcpp { namespace { -int s_argc = 0; -char** s_argv = nullptr; +// Shared state. Requires argc/argv, so construct in main and use the same raw +// pointer approach as in benchmarks.cc. Note that the style guide forbids +// non-local static variables with dtors. +GemmaEnv* s_env = nullptr; class GemmaTest : public ::testing::Test { protected: - static void SetUpTestSuite() { - gcpp::AppArgs app(s_argc, s_argv); - gcpp::LoaderArgs loader(s_argc, s_argv); - if (const char* err = loader.Validate()) { - fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n"); - } else { - fprintf(stderr, "Loading model..\n"); - s_pool = std::make_unique(app.num_threads); - s_gemma = AllocateGemma(loader, *s_pool); - s_kv_cache = KVCache::Create(loader.ModelType()); - s_model = loader.ModelType(); - } - } - - static void TearDownTestSuite() { - s_pool.reset(); - s_gemma.reset(); - } - - std::string GemmaReply(const std::string& prompt_string) { - std::mt19937 gen; - gen.seed(42); - - std::vector prompt; - HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt)); - // For both pre-trained and instruction-tuned models: prepend "" token - // if needed. 
- prompt.insert(prompt.begin(), BOS_ID); - - std::vector response; - auto stream_token = [&response](int token, float) { - response.push_back(token); - return true; - }; - gcpp::RuntimeConfig runtime_config = { - .max_tokens = 3072, - .max_generated_tokens = 2048, - .temperature = 1.0, - .verbosity = 0, - .gen = &gen, - .stream_token = stream_token, - }; - gcpp::TimingInfo timing_info; - s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache, - timing_info, /*layers_output=*/nullptr); - std::string response_text; - HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text)); - return response_text; - } - - float GemmaCrossEntropy(const std::string& prompt_string) { - std::vector prompt; - HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt)); - prompt.insert(prompt.begin(), BOS_ID); - return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt, - s_kv_cache, - /*verbosity=*/0) / - prompt_string.size(); + std::string GemmaReply(const std::string& prompt) { + s_env->SetMaxGeneratedTokens(2048); + s_env->MutableConfig().temperature = 0.0f; // deterministic + s_env->MutableConfig().verbosity = 0; + // Using the turn structure worsens results. + const std::vector tokens = s_env->TokenizeAndPrependBOS(prompt); + auto [response, n] = s_env->QueryModel(tokens); + return response; } void TestQuestions(const char* kQA[][2], size_t num_questions) { - if (!s_gemma) return; + if (!s_env->GetModel()) return; for (size_t i = 0; i < num_questions; ++i) { fprintf(stderr, "Question %zu\n\n", i + 1); std::string response = GemmaReply(kQA[i][0]); @@ -107,18 +53,8 @@ class GemmaTest : public ::testing::Test { EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT } } - - static std::unique_ptr s_pool; - static std::unique_ptr s_gemma; - static gcpp::KVCache s_kv_cache; - static gcpp::Model s_model; }; -/*static*/ std::unique_ptr GemmaTest::s_pool; -/*static*/ std::unique_ptr GemmaTest::s_gemma; -/*static*/ gcpp::KVCache GemmaTest::s_kv_cache; -/*static*/ gcpp::Model GemmaTest::s_model; - TEST_F(GemmaTest, Geography) { static const char* kQA[][2] = { {"What is the capital of Hungary?", "Budapest"}, @@ -130,7 +66,7 @@ TEST_F(GemmaTest, Geography) { TEST_F(GemmaTest, History) { static const char* kQA[][2] = { - {"When was the Battle of Hastings?", "1066"}, + {"When was the battle of Hastings?", "1066"}, }; static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); TestQuestions(kQA, kNum); @@ -181,42 +117,39 @@ static const char kGettysburg[] = { "people, for the people, shall not perish from the earth.\n"}; TEST_F(GemmaTest, CrossEntropySmall) { - if (!s_gemma) return; + if (!s_env->GetModel()) return; static const char kSmall[] = "The capital of Hungary is Budapest which is located in Europe."; - float entropy = GemmaCrossEntropy(kSmall); + float entropy = s_env->CrossEntropy(kSmall); fprintf(stderr, "per-byte entropy: %f\n", entropy); - EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f); + EXPECT_LT(entropy, + (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f); } TEST_F(GemmaTest, CrossEntropyJingleBells) { - if (!s_gemma) return; - float entropy = GemmaCrossEntropy(kJingleBells); + if (!s_env->GetModel()) return; + float entropy = s_env->CrossEntropy(kJingleBells); fprintf(stderr, "per-byte entropy: %f\n", entropy); - EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f); + EXPECT_LT(entropy, + (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 
0.9f : 1.8f); } TEST_F(GemmaTest, CrossEntropyGettysburg) { - if (!s_gemma) return; - float entropy = GemmaCrossEntropy(kGettysburg); + if (!s_env->GetModel()) return; + float entropy = s_env->CrossEntropy(kGettysburg); fprintf(stderr, "per-byte entropy: %f\n", entropy); - EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f); + EXPECT_LT(entropy, + (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f); } } // namespace } // namespace gcpp int main(int argc, char** argv) { - { - // Placeholder for internal init, do not modify. - } + gcpp::GemmaEnv env(argc, argv); + gcpp::s_env = &env; - // For later use by SetUp. - gcpp::s_argc = argc; - gcpp::s_argv = argv; - - // Probably should be called before SetUpTestSuite. - testing::InitGoogleTest(&gcpp::s_argc, argv); + testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } \ No newline at end of file diff --git a/gemma/ops.h b/gemma/ops.h index b29d8e9..dc00da6 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -1668,12 +1668,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( std::array indices{}; for (size_t i = 0; i < vocab_size; ++i) { if (probabilities[i] < top_k[k - 1] && - (!accept_token || accept_token(StaticCast(i)))) { + (!accept_token || accept_token(StaticCast(i), probabilities[i]))) { continue; } for (size_t j = 0; j < k; ++j) { if (probabilities[i] > top_k[j] && - (!accept_token || accept_token(StaticCast(i)))) { + (!accept_token || + accept_token(StaticCast(i), probabilities[i]))) { // shift elements by 1, insert the new value, move on to next value for (size_t idx = k - 1; idx > j; --idx) { top_k[idx] = top_k[idx - 1]; diff --git a/gemma/run.cc b/gemma/run.cc index ebd131e..acc3568 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -15,27 +15,22 @@ // Command line text interface to gemma. -#include #include #include #include #include -#include // NOLINT #include // Placeholder for internal header, do not modify. -#include "compression/compress.h" +#include "gemma/benchmark_helper.h" #include "gemma/common.h" -#include "gemma/configs.h" #include "gemma/gemma.h" // Gemma #include "util/app.h" #include "util/args.h" // HasHelp #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" -#include "hwy/per_target.h" #include "hwy/profiler.h" -#include "hwy/timer.h" #if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE #error "Please update to version 1.2 of github.com/google/highway." 
@@ -57,56 +52,6 @@ static constexpr std::string_view kAsciiArtBanner = R""( |___/ |_| |_| )""; -void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { - loader.Print(app.verbosity); - inference.Print(app.verbosity); - app.Print(app.verbosity); - - if (app.verbosity >= 2) { - time_t now = time(nullptr); - char* dt = ctime(&now); // NOLINT - std::cout << "Date & Time : " << dt - << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize - << "\n" - << "Hardware concurrency : " - << std::thread::hardware_concurrency() << "\n" - << "Instruction set : " - << hwy::TargetName(hwy::DispatchedTarget()) << " (" - << hwy::VectorBytes() * 8 << " bits)" << "\n"; - char cpu100[100]; - if (hwy::platform::GetCpuString(cpu100)) { - std::cout << "CPU : " << cpu100 << "\n"; - } - std::cout << "Compiled config : " << CompiledConfig() << "\n" - << "Weight Type : " - << gcpp::StringFromType(loader.WeightType()) << "\n" - << "EmbedderInput Type : " - << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n"; - } -} - -void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, - gcpp::AppArgs& app) { - std::cerr - << kAsciiArtBanner - << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" - "==========================================================\n\n" - "To run gemma.cpp, you need to " - "specify 3 required model loading arguments:\n" - " --tokenizer\n" - " --weights\n" - " --model.\n"; - std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " - "--weights 2b-it-sfp.sbs --model 2b-it\n"; - std::cerr << "\n*Model Loading Arguments*\n\n"; - loader.Help(); - std::cerr << "\n*Inference Arguments*\n\n"; - inference.Help(); - std::cerr << "\n*Application Arguments*\n\n"; - app.Help(); - std::cerr << "\n"; -} - // The main Read-Eval-Print Loop. void ReplGemma(gcpp::Gemma& model, ModelTraining training, gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, @@ -118,12 +63,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training, int prompt_size{}; std::mt19937 gen; - if (args.deterministic) { - gen.seed(42); - } else { - std::random_device rd; - gen.seed(rd()); - } + InitGenerator(args, gen); // callback function invoked for each generated token. auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size, @@ -162,7 +102,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training, while (abs_pos < args.max_tokens) { std::string prompt_string; - std::vector prompt; current_pos = 0; { PROFILER_ZONE("Gen.input"); @@ -192,30 +131,11 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training, continue; } - if (training == ModelTraining::GEMMA_IT) { - // For instruction-tuned models: add control tokens. - prompt_string = "user\n" + prompt_string + - "\nmodel\n"; - if (abs_pos != 0) { - // Prepend "" token if this is a multi-turn dialogue - // continuation. - prompt_string = "\n" + prompt_string; - } - } - - HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt)); - - // For both pre-trained and instruction-tuned models: prepend "" token - // if needed. 
-    if (abs_pos == 0) {
-      prompt.insert(prompt.begin(), gcpp::BOS_ID);
-    }
+
+    const std::vector<int> prompt =
+        WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
     prompt_size = prompt.size();
-    std::cerr << "\n" << "[ Reading prompt ] " << std::flush;
-
     if constexpr (kVerboseLogTokens) {
       for (int i = 0; i < prompt_size; ++i) {
         fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
@@ -301,17 +221,20 @@ int main(int argc, char** argv) {
   gcpp::AppArgs app(argc, argv);
 
   if (gcpp::HasHelp(argc, argv)) {
-    ShowHelp(loader, inference, app);
+    std::cerr << gcpp::kAsciiArtBanner;
+    gcpp::ShowHelp(loader, inference, app);
     return 0;
   }
 
   if (const char* error = loader.Validate()) {
-    ShowHelp(loader, inference, app);
+    std::cerr << gcpp::kAsciiArtBanner;
+    gcpp::ShowHelp(loader, inference, app);
     HWY_ABORT("\nInvalid args: %s", error);
   }
 
   if (const char* error = inference.Validate()) {
-    ShowHelp(loader, inference, app);
+    std::cerr << gcpp::kAsciiArtBanner;
+    gcpp::ShowHelp(loader, inference, app);
     HWY_ABORT("\nInvalid args: %s", error);
   }
diff --git a/gemma/run_mmlu.cc b/gemma/run_mmlu.cc
index 6de8b37..814ecd6 100644
--- a/gemma/run_mmlu.cc
+++ b/gemma/run_mmlu.cc
@@ -13,19 +13,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Command line text interface to gemma.
+#include <stdio.h>
 
-#include <fstream>
-#include <iostream>
-#include <random>
-#include <set>
-#include <sstream>
+#include <algorithm>
 #include <string>
 #include <vector>
 
-// Placeholder for internal header, do not modify.
+#include "compression/io.h"  // Path
+#include "gemma/benchmark_helper.h"
 #include "gemma/gemma.h"  // Gemma
-#include "util/app.h"
+#include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -34,164 +31,134 @@ namespace gcpp {
 
-void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
-               hwy::ThreadPool& pool,
-               const InferenceArgs& args, int verbosity,
-               std::string& eot_line) {
-  PROFILER_ZONE("Gen.misc");
-  // token index within the current turn
-  int max_tokens = 4096;
+struct JsonArgs : public ArgsBase<JsonArgs> {
+  JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
-  std::mt19937 gen;
-  if (args.deterministic) {
-    gen.seed(42);
-  } else {
-    std::random_device rd;
-    gen.seed(rd());
+  Path input;
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() const {
+    if (input.Empty()) return "Must specify --input";
+    if (!input.Exists()) return "--input file does not exist";
+    return nullptr;
   }
 
-  float answers = 0.0;
-  float correct_answers = 0.0;
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(input, "input", Path(), "Full pathname of mmlu.json.");
+  };
+};
 
-  std::ifstream fJson("/tmp/mmlu.json");
-  std::stringstream buffer;
-  buffer << fJson.rdbuf();
-  auto json = nlohmann::json::parse(buffer.str());
-
-  std::vector<std::string> accept_tokens = {"A", "B", "C", "D"};
-  std::set<int> accept_token_set{};
-  for (const std::string& accept_token : accept_tokens) {
-    std::vector<int> accept_token_ids;
-    HWY_ASSERT(model.Tokenizer().Encode(accept_token, &accept_token_ids));
-    accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
+// Linear search for a few tokens is faster than std::set.
+// TODO: instead of accepting for each vocab entry, filter the logits once.
+class TokenSet {
+ public:
+  TokenSet(const GemmaTokenizer& tokenizer,
+           const std::vector<std::string>& strings) {
+    all_tokens_.reserve(strings.size());
+    for (const std::string& str : strings) {
+      std::vector<int> tokens;
+      fprintf(stderr, "%s -> ", str.c_str());
+      HWY_ASSERT(tokenizer.Encode(str, &tokens));
+      for (int token : tokens) {
+        fprintf(stderr, "%d, ", token);
+        all_tokens_.push_back(token);
+      }
+      fprintf(stderr, "\n");
+    }
   }
 
-  for (auto sample : json["samples"]) {
-    int abs_pos = 0;  // absolute token index over all turns
-    int current_pos = 0;
-    int prompt_size{};
+  bool Contains(int token) const {
+    return std::find(all_tokens_.begin(), all_tokens_.end(), token) !=
+           all_tokens_.end();
+  }
 
-    // cout << "prompt:" << sample["prompt"] << endl;
-    const std::string& prompt_string = sample["prompt"];
-    std::vector<int> prompt;
+ private:
+  std::vector<int> all_tokens_;
+};
 
-    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
-    prompt_size = prompt.size();
+void Run(GemmaEnv& env, JsonArgs& json) {
+  PROFILER_ZONE("Run.all");
 
-    const std::string& correct_answer = accept_tokens[sample["input_label"]];
+  float answers = 0.0f;
+  float correct_answers = 0.0f;
 
-    // max_tokens = prompt_size + max_tokens;
+  auto json_data = nlohmann::json::parse(ReadFileToString(json.input));
+
+  const std::vector<std::string> accept_strings = {
+      "A", "B", "C", "D",  //
+      " A", " B", " C", " D",  //
+      "**", "**:", ":**", "The", "Answer", "is", ":", "."};
+  const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
+
+  for (auto sample : json_data["samples"]) {
+    const int id = sample["i"];
+    fprintf(stderr, "Processing question %d\n", id);
+    const std::string& correct_answer = accept_strings[sample["input_label"]];
+    std::string prompt_string = sample["prompt"];
+    // AcceptFunc restricts the output to one of these four tokens, so make an
+    // effort to steer the model towards that. See
+    // https://huggingface.co/blog/open-llm-leaderboard-mmlu
+    prompt_string +=
+        "What is start of the line with the correct answer? "
+        "Do not include any justifications or explanations. Reply only with a "
+        "letter.";
+    const std::vector<int> prompt =
+        WrapAndTokenize(env.GetModel()->Tokenizer(), env.ModelTrainingType(),
+                        /*pos=*/0, prompt_string);
+    const size_t prompt_size = prompt.size();
     std::vector<int> predicted_token_ids;
-    predicted_token_ids.reserve(max_tokens);
-    auto stream_token = [&current_pos, &prompt_size, &predicted_token_ids,
-                         &accept_token_set](int token, float proba) {
+    predicted_token_ids.reserve(4096);
+    size_t current_pos = 0;
+    const StreamFunc stream_token = [&current_pos, prompt_size,
+                                     &predicted_token_ids](int token,
+                                                           float proba) {
+      PROFILER_ZONE("Stream");
       ++current_pos;
       if (current_pos > prompt_size) {
        predicted_token_ids.push_back(token);
-
-        // If the generated token is in the accepted token set, return False.
-        // This will stop further generation.
-        return accept_token_set.find(token) == accept_token_set.end();
      }
-
      return true;
    };
-    const AcceptFunc accept_token = [&current_pos, &prompt_size,
-                                     &accept_token_set](int token) {
-      // i.e. we have no constraints on accepted tokens
-      if (accept_token_set.empty()) {
-        return true;
-      }
-
-      if (current_pos >= prompt_size) {
-        return accept_token_set.find(token) != accept_token_set.end();
-      } else {
-        // auto-accept early tokens
-        return true;
-      }
-    };
-
+    // Although " A" is a token, it is difficult to associate that with the
+    // correct answer. Only accepting certain tokens is risky: (A) is easily
+    // confused with the word "A".
     gcpp::TimingInfo timing_info;
     gcpp::RuntimeConfig runtime_config = {
-        .max_tokens = args.max_tokens,
-        .max_generated_tokens = args.max_generated_tokens,
-        .temperature = args.temperature,
-        .verbosity = verbosity,
-        .gen = &gen,
+        .max_tokens = env.MaxTokens(),
+        .max_generated_tokens = 30,
+        .temperature = 0.0f,
+        .verbosity = env.Verbosity(),
+        .gen = &env.MutableGen(),
         .stream_token = stream_token,
-        .accept_token = accept_token,
     };
-    model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
+    env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
+                             env.MutableKVCache(), timing_info);
 
-    std::string output_string;
-    HWY_ASSERT(model.Tokenizer().Decode(predicted_token_ids, &output_string));
-    std::cout << "QuestionId: " << sample["i"] << "; "
-              << "Predicted Answer: " << output_string << "; "
-              << "Correct Answer: " << correct_answer << std::endl;
+    std::string output_string = env.StringFromTokens(predicted_token_ids);
+    fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
+            output_string.c_str());
 
-    answers += 1.0;
+    answers += 1.0f;
     if (output_string == correct_answer) {
-      correct_answers += 1.0;
+      correct_answers += 1.0f;
     }
-    std::cout << "Running accuracy = " << "["
-              << static_cast<int>(correct_answers) << "/"
-              << static_cast<int>(answers) << "]" << " = "
-              << correct_answers / answers << std::endl;
+    fprintf(stderr, "%.0f/%.0f = %.2f%%\n", correct_answers, answers,
+            correct_answers / answers);
   }
 }
 
-void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
-  loader.Print(app.verbosity);
-  inference.Print(app.verbosity);
-  app.Print(app.verbosity);
-}
-
-void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
-  PROFILER_ZONE("Run.misc");
-
-  hwy::ThreadPool pool(app.num_threads);
-  // For many-core, pinning workers to cores helps.
-  if (app.num_threads > 10) {
-    PinWorkersToCores(pool);
-  }
-
-  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
-
-  JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);
-}
-
 }  // namespace gcpp
 
 int main(int argc, char** argv) {
   {
-    PROFILER_ZONE("Startup.misc");
-
-    // Placeholder for internal init, do not modify.
-
-    gcpp::LoaderArgs loader(argc, argv);
-    gcpp::InferenceArgs inference(argc, argv);
-    gcpp::AppArgs app(argc, argv);
-
-    if (const char* error = loader.Validate()) {
-      fprintf(stderr,
-              "\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
-              "specify 3 required model loading arguments: --tokenizer, "
-              "--compressed_weights, "
-              "and --model.\n\nModel Loading Arguments\n\n");
-
-      loader.Help();
-      fprintf(stderr, "\nInference Arguments\n\n");
-      inference.Help();
-      fprintf(stderr, "\nApplication Arguments\n\n");
-      app.Help();
-      fprintf(stderr, "\n\n");
-      HWY_ABORT("\nInvalid args: %s", error);
-    }
-
-    gcpp::Run(loader, inference, app);
+    PROFILER_ZONE("Startup.all");
+    gcpp::GemmaEnv env(argc, argv);
+    gcpp::JsonArgs json(argc, argv);
+    gcpp::AbortIfInvalidArgs(json);
+    gcpp::Run(env, json);
   }
   PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
   return 0;
diff --git a/util/args.h b/util/args.h
index b50cbaa..c49ed72 100644
--- a/util/args.h
+++ b/util/args.h
@@ -197,6 +197,14 @@ static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) {
   return false;
 }
 
+template <typename TArgs>
+static inline HWY_MAYBE_UNUSED void AbortIfInvalidArgs(TArgs& args) {
+  if (const char* err = args.Validate()) {
+    args.Help();
+    HWY_ABORT("Problem with args: %s\n", err);
+  }
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_
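Note (not part of the patch above): the gemma/ops.h hunk now forwards each candidate's probability to accept_token, so the callback shape changes from bool(int) to bool(int, float). Below is a minimal, self-contained sketch of the new callback shape; the AcceptFunc alias shown here is an assumption for illustration only, as the real alias is defined elsewhere in gemma.cpp and is not part of this diff.

// Sketch only: assumed alias mirroring the (token, probability) call sites
// in SampleTopK above. Compiles and runs on its own.
#include <functional>

using AcceptFunc = std::function<bool(int /*token*/, float /*probability*/)>;

int main() {
  // Old shape: the callback could only inspect the token id, e.g.
  //   [](int token) { return true; }
  // New shape: the probability is also available, so a caller can, for
  // example, reject extremely unlikely candidates regardless of their id.
  const AcceptFunc accept_token = [](int /*token*/, float prob) {
    return prob >= 1e-3f;
  };
  return accept_token(42, 0.5f) ? 0 : 1;
}

A caller that does not care about the new argument can simply accept and ignore it; as the SampleTopK hunk shows, the callback is only invoked when it is non-empty.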