mirror of https://github.com/google/gemma.cpp.git
Updated benchmarks.cc to recent changes to Gemma API.
PiperOrigin-RevId: 642285902
This commit is contained in:
parent
b6565e3bf6
commit
bdf33c7008
43
BUILD.bazel
43
BUILD.bazel
|
|
@ -165,7 +165,7 @@ cc_test(
|
||||||
":cross_entropy",
|
":cross_entropy",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":ops",
|
":ops",
|
||||||
# "//base",
|
# Placeholder for internal dep, do not remove.,
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
|
|
@ -181,7 +181,7 @@ cc_binary(
|
||||||
":args",
|
":args",
|
||||||
":common",
|
":common",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
# "//base",
|
# Placeholder for internal dep, do not remove.,
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:nanobenchmark",
|
"@hwy//:nanobenchmark",
|
||||||
|
|
@ -198,7 +198,7 @@ cc_binary(
|
||||||
":common",
|
":common",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":weights",
|
":weights",
|
||||||
# "//base",
|
# Placeholder for internal dep, do not remove.,
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:nanobenchmark",
|
"@hwy//:nanobenchmark",
|
||||||
|
|
@ -208,7 +208,7 @@ cc_binary(
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "benchmark",
|
name = "single_benchmark",
|
||||||
srcs = ["gemma/benchmark.cc"],
|
srcs = ["gemma/benchmark.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":app",
|
":app",
|
||||||
|
|
@ -216,7 +216,7 @@ cc_binary(
|
||||||
":common",
|
":common",
|
||||||
":cross_entropy",
|
":cross_entropy",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
# "//base",
|
# Placeholder for internal dep, do not remove.,
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:nanobenchmark",
|
"@hwy//:nanobenchmark",
|
||||||
|
|
@ -225,6 +225,16 @@ cc_binary(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Benchmark binary built from gemma/benchmarks.cc; links the shared
# benchmark_helper library and the Google Benchmark library.
cc_binary(
    name = "benchmarks",
    srcs = ["gemma/benchmarks.cc"],
    deps = [
        ":benchmark_helper",
        # Placeholder for internal dep, do not remove.,
        "@benchmark//:benchmark",
    ],
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "debug_prompt",
|
name = "debug_prompt",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
@ -234,7 +244,7 @@ cc_binary(
|
||||||
":app",
|
":app",
|
||||||
":args",
|
":args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
# "//base",
|
# Placeholder for internal dep, do not remove.,
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
|
|
@ -248,7 +258,7 @@ cc_binary(
|
||||||
deps = [
|
deps = [
|
||||||
":app",
|
":app",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
# "//base",
|
# Placeholder for internal dep, do not remove.,
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:profiler",
|
"@hwy//:profiler",
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
|
|
@ -308,6 +318,25 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Library providing gcpp::GemmaEnv (gemma/benchmark_helper.{h,cc}), which the
# "benchmarks" binary depends on.
cc_library(
    name = "benchmark_helper",
    srcs = ["gemma/benchmark_helper.cc"],
    hdrs = ["gemma/benchmark_helper.h"],
    deps = [
        ":app",
        ":common",
        ":gemma_lib",
        "@benchmark//:benchmark",
        "@hwy//:hwy",
        "@hwy//:nanobenchmark",
        "@hwy//:thread_pool",
    ],
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "backward_scalar_test",
|
name = "backward_scalar_test",
|
||||||
size = "large",
|
size = "large",
|
||||||
|
|
|
||||||
|
|
@ -97,8 +97,11 @@ add_executable(gemma gemma/run.cc)
|
||||||
target_link_libraries(gemma libgemma hwy hwy_contrib)
|
target_link_libraries(gemma libgemma hwy hwy_contrib)
|
||||||
install(TARGETS gemma DESTINATION bin)
|
install(TARGETS gemma DESTINATION bin)
|
||||||
|
|
||||||
add_executable(benchmark gemma/benchmark.cc)
|
add_executable(single_benchmark gemma/benchmark.cc)
|
||||||
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
target_link_libraries(single_benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||||
|
|
||||||
|
# gemma/benchmarks.cc uses gcpp::GemmaEnv, which is defined in
# gemma/benchmark_helper.cc. Compile the helper into this target so the
# executable links even if libgemma does not already contain it.
# NOTE(review): the `benchmark` link target must be provided elsewhere in this
# CMakeLists (e.g. via FetchContent) - confirm.
add_executable(benchmarks gemma/benchmarks.cc gemma/benchmark_helper.cc)
target_link_libraries(benchmarks libgemma hwy hwy_contrib nlohmann_json::nlohmann_json benchmark)
|
||||||
|
|
||||||
add_executable(debug_prompt debug_prompt.cc)
|
add_executable(debug_prompt debug_prompt.cc)
|
||||||
target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||||
|
|
|
||||||
|
|
@ -63,3 +63,10 @@ http_archive(
|
||||||
strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3",
|
strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3",
|
||||||
urls = ["https://github.com/abseil/abseil-cpp/archive//9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"],
|
urls = ["https://github.com/abseil/abseil-cpp/archive//9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"],
|
||||||
)
|
)
|
||||||
|
# Google Benchmark, release v1.8.2 (provides the @benchmark//:benchmark target).
http_archive(
    name = "benchmark",
    integrity = "sha256-KqspgNA3YTf5adkoSPu2gharsHYzA0U0/IxlzE56DpM=",
    strip_prefix = "benchmark-1.8.2",
    urls = ["https://github.com/google/benchmark/archive/refs/tags/v1.8.2.tar.gz"],
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,118 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "gemma/benchmark_helper.h"
|
||||||
|
#include <cstdlib> // EXIT_FAILURE
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include <ostream>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <utility> // std::pair
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "gemma/common.h"
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "util/app.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
// Parses loader, inference and app arguments from the command line, aborts if
// the loader or inference arguments fail validation, then allocates the model
// and its KV cache and seeds the member random number generator.
GemmaEnv::GemmaEnv(int argc, char** argv)
    : loader_(argc, argv),
      inference_args_(argc, argv),
      app_(argc, argv),
      pool_(app_.num_threads) {
  const char* loader_error = loader_.Validate();
  if (loader_error != nullptr) {
    HWY_ABORT("\nInvalid loader args: %s", loader_error);
  }

  const char* inference_error = inference_args_.Validate();
  if (inference_error != nullptr) {
    HWY_ABORT("\nInvalid inference args: %s", inference_error);
  }

  // Pinning workers to cores helps on many-core machines.
  const bool many_threads = app_.num_threads > 10;
  if (many_threads) {
    gcpp::PinWorkersToCores(pool_);
  }

  model_ = AllocateGemma(loader_, pool_);
  kv_cache_ = KVCache::Create(loader_.ModelType());
  gen_.seed(42);
}
|
||||||
|
|
||||||
|
std::pair<std::string, int> GemmaEnv::QueryModel(const std::string& input) {
|
||||||
|
std::string prompt_string = input;
|
||||||
|
if (loader_.ModelTrainingType() == ModelTraining::GEMMA_IT) {
|
||||||
|
// For instruction-tuned models: add control tokens.
|
||||||
|
prompt_string = "<start_of_turn>user\n" + input +
|
||||||
|
"<end_of_turn>\n<start_of_turn>model\n";
|
||||||
|
}
|
||||||
|
std::vector<int> prompt;
|
||||||
|
HWY_ASSERT(model_->Tokenizer().Encode(input, &prompt));
|
||||||
|
|
||||||
|
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||||
|
// if needed.
|
||||||
|
prompt.insert(prompt.begin(), gcpp::BOS_ID);
|
||||||
|
std::string res;
|
||||||
|
size_t total_tokens = 0;
|
||||||
|
auto accept_token = [](int) { return true; };
|
||||||
|
std::mt19937 gen;
|
||||||
|
gen.seed(42);
|
||||||
|
|
||||||
|
const double time_start = hwy::platform::Now();
|
||||||
|
auto stream_token = [&res, &total_tokens, &time_start, this](
|
||||||
|
int token, float) {
|
||||||
|
++total_tokens;
|
||||||
|
std::string token_text;
|
||||||
|
HWY_ASSERT(model_->Tokenizer().Decode(std::vector<int>{token},
|
||||||
|
&token_text));
|
||||||
|
res += token_text;
|
||||||
|
if (app_.verbosity >= 1 && total_tokens % 100 == 0) {
|
||||||
|
LogSpeedStats(time_start, total_tokens);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
if (app_.verbosity >= 2) {
|
||||||
|
std::cout << inference_args_.max_tokens << " "
|
||||||
|
<< inference_args_.max_generated_tokens << " "
|
||||||
|
<< inference_args_.temperature;
|
||||||
|
}
|
||||||
|
gcpp::TimingInfo timing_info;
|
||||||
|
gcpp::RuntimeConfig runtime_config = {
|
||||||
|
.max_tokens = inference_args_.max_tokens,
|
||||||
|
.max_generated_tokens = inference_args_.max_generated_tokens,
|
||||||
|
.temperature = inference_args_.temperature,
|
||||||
|
.verbosity = app_.verbosity,
|
||||||
|
.gen = &gen,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.accept_token = accept_token,
|
||||||
|
};
|
||||||
|
model_->Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache_,
|
||||||
|
timing_info, /*layers_output=*/nullptr);
|
||||||
|
if (app_.verbosity >= 1) {
|
||||||
|
LogSpeedStats(time_start, total_tokens);
|
||||||
|
}
|
||||||
|
return {res, total_tokens};
|
||||||
|
}
|
||||||
|
|
||||||
|
void GemmaEnv::LogSpeedStats(double time_start, size_t total_tokens) const {
|
||||||
|
const double time_end = hwy::platform::Now();
|
||||||
|
const double time_elapsed = time_end - time_start;
|
||||||
|
const double tok_sec = total_tokens / time_elapsed;
|
||||||
|
std::cout << total_tokens << " tokens in " << time_elapsed << " seconds"
|
||||||
|
<< " [" << tok_sec << " tokens / sec" << "]\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "util/app.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Convenience class to load a model and run inference.
|
||||||
|
class GemmaEnv {
|
||||||
|
public:
|
||||||
|
GemmaEnv(int argc, char** argv);
|
||||||
|
|
||||||
|
// Sets the maximum number of output tokens to generate.
|
||||||
|
void set_max_generated_tokens(int max_tokens) {
|
||||||
|
inference_args_.max_generated_tokens = max_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs inference on the given input and returns the top-1 result string and
|
||||||
|
// the number of tokens that were generated.
|
||||||
|
std::pair<std::string, int> QueryModel(const std::string& input);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Logs the inference speed in tokens/sec.
|
||||||
|
void LogSpeedStats(double time_start, size_t total_tokens) const;
|
||||||
|
|
||||||
|
// Arguments to the model loader: file locations, etc.
|
||||||
|
LoaderArgs loader_;
|
||||||
|
// Arguments to the inference function: max tokens, etc.
|
||||||
|
InferenceArgs inference_args_;
|
||||||
|
// Controls overall behavior of the app.
|
||||||
|
AppArgs app_;
|
||||||
|
// Thread pool for running inference.
|
||||||
|
hwy::ThreadPool pool_;
|
||||||
|
// Random number generator.
|
||||||
|
std::mt19937 gen_;
|
||||||
|
// The model to run inference on.
|
||||||
|
std::unique_ptr<Gemma> model_;
|
||||||
|
// The KV cache to use for inference.
|
||||||
|
KVCache kv_cache_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BENCHMARK_HELPER_H_
|
||||||
|
|
@ -0,0 +1,120 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <ostream>
|
||||||
|
#include <random>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// Placeholder for internal header, do not modify.
|
||||||
|
#include "benchmark/benchmark.h"
|
||||||
|
#include "gemma/benchmark_helper.h"
|
||||||
|
|
||||||
|
void run_gemma_prompt(const std::string& prompt_string,
|
||||||
|
gcpp::GemmaEnv& env,
|
||||||
|
benchmark::State& state) {
|
||||||
|
std::mt19937 gen;
|
||||||
|
|
||||||
|
if (prompt_string.empty()) return;
|
||||||
|
|
||||||
|
int token_counter = 0;
|
||||||
|
for (auto s : state) {
|
||||||
|
auto [response, n] = env.QueryModel(prompt_string);
|
||||||
|
std::cout << "response: " << response << "\n";
|
||||||
|
std::cout << "n: " << n << "\n";
|
||||||
|
token_counter += n;
|
||||||
|
}
|
||||||
|
|
||||||
|
state.SetItemsProcessed(token_counter);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Awkward global because benchmarks don't support additional state, so it is
|
||||||
|
// either this or cast to int64_t.
|
||||||
|
// Non-owning pointer to the environment constructed in main(); main() assigns
// it before running the benchmarks and the BM_* functions dereference it.
gcpp::GemmaEnv* global_env = nullptr;
|
||||||
|
|
||||||
|
// Benchmarks a short factual question.
static void BM_short_prompt(benchmark::State& state) {
  static constexpr char kPrompt[] = "What is the capital of Spain?";
  run_gemma_prompt(kPrompt, *global_env, state);
}
|
||||||
|
|
||||||
|
// Benchmarks an explanatory, fact-based question.
static void BM_factuality_prompt(benchmark::State& state) {
  static constexpr char kPrompt[] = "How does an inkjet printer work?";
  run_gemma_prompt(kPrompt, *global_env, state);
}
|
||||||
|
|
||||||
|
// Benchmarks a creative-writing request.
static void BM_creative_prompt(benchmark::State& state) {
  static constexpr char kPrompt[] =
      "Tell me a story about a magical bunny and their TRS-80.";
  run_gemma_prompt(kPrompt, *global_env, state);
}
|
||||||
|
|
||||||
|
// Benchmarks a short code-generation request.
static void BM_coding_prompt(benchmark::State& state) {
  static constexpr char kPrompt[] =
      "Write a python program to generate a fibonacci sequence.";
  run_gemma_prompt(kPrompt, *global_env, state);
}
|
||||||
|
|
||||||
|
// Benchmarks a long prompt: the contents of "benchmarks.cc" (resolved against
// the current working directory) prefixed with an instruction to improve it.
//
// If the file cannot be opened the benchmark is reported as skipped, instead
// of silently measuring a prompt that contains only the instruction prefix.
static void BM_long_coding_prompt(benchmark::State& state) {
  std::ifstream source_file("benchmarks.cc", std::ios_base::in);
  if (!source_file.is_open()) {
    state.SkipWithError("could not open benchmarks.cc");
    return;
  }

  std::stringstream buffer;
  buffer << source_file.rdbuf();
  const std::string prompt_string = buffer.str();
  // source_file is closed by its destructor.

  run_gemma_prompt("Make improvements to the following code:\n " +
                   prompt_string, *global_env, state);
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
{
|
||||||
|
// Placeholder for internal init, do not modify.
|
||||||
|
}
|
||||||
|
gcpp::GemmaEnv env(argc, argv);
|
||||||
|
|
||||||
|
env.set_max_generated_tokens(128);
|
||||||
|
global_env = &env;
|
||||||
|
BENCHMARK(BM_short_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
env.set_max_generated_tokens(256);
|
||||||
|
BENCHMARK(BM_factuality_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
BENCHMARK(BM_creative_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
BENCHMARK(BM_coding_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
env.set_max_generated_tokens(1024);
|
||||||
|
BENCHMARK(BM_long_coding_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
::benchmark ::RunSpecifiedBenchmarks();
|
||||||
|
::benchmark ::Shutdown();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue