diff --git a/BUILD.bazel b/BUILD.bazel index a0aef79..aa75c90 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -165,7 +165,7 @@ cc_test( ":cross_entropy", ":gemma_lib", ":ops", - # "//base", + # Placeholder for internal dep, do not remove., "@googletest//:gtest_main", "//compression:io", "@hwy//:hwy_test_util", @@ -181,7 +181,7 @@ cc_binary( ":args", ":common", ":gemma_lib", - # "//base", + # Placeholder for internal dep, do not remove., "//compression:compress", "@hwy//:hwy", "@hwy//:nanobenchmark", @@ -198,7 +198,7 @@ cc_binary( ":common", ":gemma_lib", ":weights", - # "//base", + # Placeholder for internal dep, do not remove., "//compression:compress", "@hwy//:hwy", "@hwy//:nanobenchmark", @@ -208,7 +208,7 @@ cc_binary( ) cc_binary( - name = "benchmark", + name = "single_benchmark", srcs = ["gemma/benchmark.cc"], deps = [ ":app", @@ -216,7 +216,7 @@ cc_binary( ":common", ":cross_entropy", ":gemma_lib", - # "//base", + # Placeholder for internal dep, do not remove., "//compression:io", "@hwy//:hwy", "@hwy//:nanobenchmark", @@ -225,6 +225,16 @@ cc_binary( ], ) +cc_binary( + name = "benchmarks", + srcs = ["gemma/benchmarks.cc"], + deps = [ + ":benchmark_helper", + # Placeholder for internal dep, do not remove., + "@benchmark//:benchmark", + ], +) + cc_binary( name = "debug_prompt", srcs = [ @@ -234,7 +244,7 @@ cc_binary( ":app", ":args", ":gemma_lib", - # "//base", + # Placeholder for internal dep, do not remove., "//compression:io", "@hwy//:hwy", "@hwy//:thread_pool", @@ -248,7 +258,7 @@ cc_binary( deps = [ ":app", ":gemma_lib", - # "//base", + # Placeholder for internal dep, do not remove., "@hwy//:hwy", "@hwy//:profiler", "@hwy//:thread_pool", @@ -308,6 +318,25 @@ cc_library( ], ) +cc_library( + name = "benchmark_helper", + srcs = [ + "gemma/benchmark_helper.cc", + ], + hdrs = [ + "gemma/benchmark_helper.h", + ], + deps = [ + ":app", + ":common", + ":gemma_lib", + "@benchmark//:benchmark", + "@hwy//:hwy", + "@hwy//:nanobenchmark", + "@hwy//:thread_pool", + ], +) + 
cc_test( name = "backward_scalar_test", size = "large", diff --git a/CMakeLists.txt b/CMakeLists.txt index 5739561..1ff655c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -97,8 +97,11 @@ add_executable(gemma gemma/run.cc) target_link_libraries(gemma libgemma hwy hwy_contrib) install(TARGETS gemma DESTINATION bin) -add_executable(benchmark gemma/benchmark.cc) -target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json) +add_executable(single_benchmark gemma/benchmark.cc) +target_link_libraries(single_benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json) + +add_executable(benchmarks gemma/benchmarks.cc) +target_link_libraries(benchmarks libgemma hwy hwy_contrib nlohmann_json::nlohmann_json benchmark) add_executable(debug_prompt debug_prompt.cc) target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json) diff --git a/MODULE.bazel b/MODULE.bazel index 7c6336c..43b33a5 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -63,3 +63,10 @@ http_archive( strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3", urls = ["https://github.com/abseil/abseil-cpp/archive//9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"], ) +# Benchmark +http_archive( + name = "benchmark", + urls = ["https://github.com/google/benchmark/archive/refs/tags/v1.8.2.tar.gz"], + integrity = "sha256-KqspgNA3YTf5adkoSPu2gharsHYzA0U0/IxlzE56DpM=", + strip_prefix = "benchmark-1.8.2", +) diff --git a/gemma/benchmark_helper.cc b/gemma/benchmark_helper.cc new file mode 100644 index 0000000..d48914b --- /dev/null +++ b/gemma/benchmark_helper.cc @@ -0,0 +1,118 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gemma/benchmark_helper.h"
+
+#include <cstdlib>  // EXIT_FAILURE
+#include <iostream>
+#include <memory>
+#include <ostream>
+#include <random>
+#include <string>
+#include <utility>  // std::pair
+#include <vector>
+
+#include "gemma/common.h"
+#include "gemma/gemma.h"
+#include "util/app.h"
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/highway.h"
+#include "hwy/timer.h"
+
+namespace gcpp {
+
+GemmaEnv::GemmaEnv(int argc, char** argv)
+    : loader_(argc, argv), inference_args_(argc, argv), app_(argc, argv),
+      pool_(app_.num_threads) {
+  if (const char* error = loader_.Validate()) {
+    HWY_ABORT("\nInvalid loader args: %s", error);
+  }
+  if (const char* error = inference_args_.Validate()) {
+    HWY_ABORT("\nInvalid inference args: %s", error);
+  }
+  // For many-core, pinning workers to cores helps.
+  if (app_.num_threads > 10) {
+    gcpp::PinWorkersToCores(pool_);
+  }
+  model_ = AllocateGemma(loader_, pool_);
+  kv_cache_ = KVCache::Create(loader_.ModelType());
+  gen_.seed(42);
+}
+
+std::pair<std::string, size_t> GemmaEnv::QueryModel(const std::string& input) {
+  std::string prompt_string = input;
+  if (loader_.ModelTrainingType() == ModelTraining::GEMMA_IT) {
+    // For instruction-tuned models: add control tokens.
+    prompt_string = "<start_of_turn>user\n" + input +
+                    "<end_of_turn>\n<start_of_turn>model\n";
+  }
+  std::vector<int> prompt;
+  // NOTE(review): tokenize the wrapped prompt_string, not the raw input —
+  // otherwise the GEMMA_IT control tokens built above are never used.
+  HWY_ASSERT(model_->Tokenizer().Encode(prompt_string, &prompt));
+
+  // For both pre-trained and instruction-tuned models: prepend "<bos>" token
+  // if needed.
+  prompt.insert(prompt.begin(), gcpp::BOS_ID);
+  std::string res;
+  size_t total_tokens = 0;
+  auto accept_token = [](int) { return true; };
+  std::mt19937 gen;
+  gen.seed(42);
+
+  const double time_start = hwy::platform::Now();
+  auto stream_token = [&res, &total_tokens, &time_start, this](
+                          int token, float) {
+    ++total_tokens;
+    std::string token_text;
+    HWY_ASSERT(model_->Tokenizer().Decode(std::vector<int>{token},
+                                          &token_text));
+    res += token_text;
+    if (app_.verbosity >= 1 && total_tokens % 100 == 0) {
+      LogSpeedStats(time_start, total_tokens);
+    }
+    return true;
+  };
+  if (app_.verbosity >= 2) {
+    std::cout << inference_args_.max_tokens << " "
+              << inference_args_.max_generated_tokens << " "
+              << inference_args_.temperature;
+  }
+  gcpp::TimingInfo timing_info;
+  gcpp::RuntimeConfig runtime_config = {
+      .max_tokens = inference_args_.max_tokens,
+      .max_generated_tokens = inference_args_.max_generated_tokens,
+      .temperature = inference_args_.temperature,
+      .verbosity = app_.verbosity,
+      .gen = &gen,
+      .stream_token = stream_token,
+      .accept_token = accept_token,
+  };
+  model_->Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache_,
+                   timing_info, /*layers_output=*/nullptr);
+  if (app_.verbosity >= 1) {
+    LogSpeedStats(time_start, total_tokens);
+  }
+  return {res, total_tokens};
+}
+
+void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const {
+  const double time_end = hwy::platform::Now();
+  const double time_elapsed = time_end - time_start;
+  const double tok_sec = total_tokens / time_elapsed;
+  std::cout << total_tokens << " tokens in " << time_elapsed << " seconds"
+            << " [" << tok_sec << " tokens / sec" << "]\n";
+}
+
+
+}  // namespace gcpp
+
diff --git a/gemma/benchmark_helper.h b/gemma/benchmark_helper.h
new file mode 100644
index 0000000..1feac4a
--- /dev/null
+++ b/gemma/benchmark_helper.h
@@ -0,0 +1,69 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+
+#include "gemma/gemma.h"
+#include "util/app.h"
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+namespace gcpp {
+
+// Convenience class to load a model and run inference.
+class GemmaEnv {
+ public:
+  GemmaEnv(int argc, char** argv);
+
+  // Sets the maximum number of output tokens to generate.
+  void set_max_generated_tokens(int max_tokens) {
+    inference_args_.max_generated_tokens = max_tokens;
+  }
+
+  // Runs inference on the given input and returns the top-1 result string and
+  // the number of tokens that were generated.
+  std::pair<std::string, size_t> QueryModel(const std::string& input);
+
+ private:
+  // Logs the inference speed in tokens/sec.
+  void LogSpeedStats(double time_start, size_t total_tokens) const;
+
+  // Arguments to the model loader: file locations, etc.
+  LoaderArgs loader_;
+  // Arguments to the inference function: max tokens, etc.
+  InferenceArgs inference_args_;
+  // Controls overall behavior of the app.
+  AppArgs app_;
+  // Thread pool for running inference.
+  hwy::ThreadPool pool_;
+  // Random number generator.
+  std::mt19937 gen_;
+  // The model to run inference on.
+  std::unique_ptr<Gemma> model_;
+  // The KV cache to use for inference.
+  KVCache kv_cache_;
+};
+
+}  // namespace gcpp
+
+
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
diff --git a/gemma/benchmarks.cc b/gemma/benchmarks.cc
new file mode 100644
index 0000000..630d5c8
--- /dev/null
+++ b/gemma/benchmarks.cc
@@ -0,0 +1,120 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fstream>
+#include <iostream>
+#include <random>
+#include <sstream>
+#include <string>
+#include <utility>
+
+// Placeholder for internal header, do not modify.
+#include "benchmark/benchmark.h"
+#include "gemma/benchmark_helper.h"
+
+void run_gemma_prompt(const std::string& prompt_string,
+                      gcpp::GemmaEnv& env,
+                      benchmark::State& state) {
+  std::mt19937 gen;
+
+  if (prompt_string.empty()) return;
+
+  int token_counter = 0;
+  for (auto s : state) {
+    auto [response, n] = env.QueryModel(prompt_string);
+    std::cout << "response: " << response << "\n";
+    std::cout << "n: " << n << "\n";
+    token_counter += n;
+  }
+
+  state.SetItemsProcessed(token_counter);
+}
+
+// Awkward global because benchmarks don't support additional state, so it is
+// either this or cast to int64_t.
+gcpp::GemmaEnv* global_env = nullptr;
+
+static void BM_short_prompt(benchmark::State& state) {
+  run_gemma_prompt("What is the capital of Spain?", *global_env,
+                   state);
+}
+
+static void BM_factuality_prompt(benchmark::State& state) {
+  run_gemma_prompt("How does an inkjet printer work?",
+                   *global_env, state);
+}
+
+static void BM_creative_prompt(benchmark::State& state) {
+  run_gemma_prompt(
+      "Tell me a story about a magical bunny and their TRS-80.",
+      *global_env, state);
+}
+
+static void BM_coding_prompt(benchmark::State& state) {
+  run_gemma_prompt(
+      "Write a python program to generate a fibonacci sequence.",
+      *global_env, state);
+}
+
+static void BM_long_coding_prompt(benchmark::State& state) {
+  std::ifstream t("benchmarks.cc", std::ios_base::in);
+  std::stringstream buffer;
+  buffer << t.rdbuf();
+  std::string prompt_string = buffer.str();
+  t.close();
+
+  run_gemma_prompt("Make improvements to the following code:\n " +
+                       prompt_string, *global_env, state);
+}
+
+int main(int argc, char** argv) {
+  {
+    // Placeholder for internal init, do not modify.
+  }
+  gcpp::GemmaEnv env(argc, argv);
+
+  // NOTE(review): these set_max_generated_tokens() calls run at registration
+  // time, before RunSpecifiedBenchmarks() executes anything, so all
+  // benchmarks actually run with the last value set (1024) — confirm intent.
+  env.set_max_generated_tokens(128);
+  global_env = &env;
+  BENCHMARK(BM_short_prompt)
+      ->Iterations(3)
+      ->Unit(benchmark::kMillisecond)
+      ->UseRealTime();
+
+  env.set_max_generated_tokens(256);
+  BENCHMARK(BM_factuality_prompt)
+      ->Iterations(3)
+      ->Unit(benchmark::kMillisecond)
+      ->UseRealTime();
+
+  BENCHMARK(BM_creative_prompt)
+      ->Iterations(3)
+      ->Unit(benchmark::kMillisecond)
+      ->UseRealTime();
+
+  BENCHMARK(BM_coding_prompt)
+      ->Iterations(3)
+      ->Unit(benchmark::kMillisecond)
+      ->UseRealTime();
+
+  env.set_max_generated_tokens(1024);
+  BENCHMARK(BM_long_coding_prompt)
+      ->Iterations(3)
+      ->Unit(benchmark::kMillisecond)
+      ->UseRealTime();
+
+  ::benchmark::RunSpecifiedBenchmarks();
+  ::benchmark::Shutdown();
+  return 0;
+}