mirror of https://github.com/google/gemma.cpp.git
parent
a3f7bf0991
commit
c4a75abe43
|
|
@ -605,11 +605,11 @@ cc_test(
|
|||
],
|
||||
deps = [
|
||||
":benchmark_helper",
|
||||
":configs",
|
||||
":gemma_lib",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,18 +19,12 @@
|
|||
#include <vector>
|
||||
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/nanobenchmark.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
// This test can be run manually with the downloaded gemma weights.
|
||||
// To run the test, pass the following flags:
|
||||
// --tokenizer <tokenizer_path> --weights <weights_path>
|
||||
// It should pass for the following models:
|
||||
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
|
|
@ -43,59 +37,19 @@ class GemmaTest : public ::testing::Test {
|
|||
protected:
|
||||
std::vector<std::string> BatchGemmaReply(
|
||||
const std::vector<std::string>& inputs) {
|
||||
s_env->SetMaxGeneratedTokens(64);
|
||||
s_env->SetMaxGeneratedTokens(32);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 5;
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
s_env->MutableConfig().verbosity = 2;
|
||||
std::vector<std::string> replies;
|
||||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (config.model == Model::GEMMA2_27B ||
|
||||
config.model == Model::GRIFFIN_2B) {
|
||||
for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
return replies;
|
||||
}
|
||||
// Otherwise, do not use turn structure.
|
||||
std::vector<std::vector<int>> prompts_vector;
|
||||
prompts_vector.reserve(inputs.size());
|
||||
for (const auto& input_string : inputs) {
|
||||
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
|
||||
}
|
||||
std::vector<PromptTokens> prompt_spans;
|
||||
prompt_spans.reserve(prompts_vector.size());
|
||||
for (const auto& prompt : prompts_vector) {
|
||||
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||
}
|
||||
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
|
||||
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
|
||||
for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
return replies;
|
||||
}
|
||||
|
||||
void GenerateTokens(const std::vector<std::string>& questions) {
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
|
||||
// Fills prompts round robin from `questions` until the desired batch size.
|
||||
std::vector<std::string> inputs;
|
||||
inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
|
||||
size_t qpos = 0;
|
||||
for (size_t i = 0; i < inputs.capacity(); ++i) {
|
||||
inputs.push_back(questions[qpos++]);
|
||||
if (qpos == questions.size()) qpos = 0;
|
||||
}
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(GemmaTest, RandomQuestionsBatched) {
|
||||
static std::vector<std::string> kQA = {
|
||||
const std::vector<std::string> questions = {
|
||||
{"Write me a poem about Australia?"},
|
||||
{"What's the history of Denmark?"},
|
||||
{"Write me a comedy story about the USA."},
|
||||
|
|
@ -129,8 +83,20 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
|
|||
{"Tell me about space travel."},
|
||||
{"Explain to me how electric cars work."},
|
||||
};
|
||||
s_env->MutableConfig().verbosity = 5;
|
||||
GenerateTokens(kQA);
|
||||
|
||||
// Fills prompts round robin from `questions` until the desired batch size.
|
||||
std::vector<std::string> inputs;
|
||||
inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
|
||||
size_t qpos = 0;
|
||||
for (size_t i = 0; i < inputs.capacity(); ++i) {
|
||||
inputs.push_back(questions[qpos++]);
|
||||
if (qpos == questions.size()) qpos = 0;
|
||||
}
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
for (size_t i = 0; i < hwy::Unpredictable1(); ++i) {
|
||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
|
||||
}
|
||||
|
||||
PROFILER_PRINT_RESULTS();
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
|||
Loading…
Reference in New Issue