mirror of https://github.com/google/gemma.cpp.git
parent a3f7bf0991
commit c4a75abe43
@@ -605,11 +605,11 @@ cc_test(
     ],
     deps = [
         ":benchmark_helper",
-        ":configs",
         ":gemma_lib",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "@highway//:hwy",
         "@highway//:hwy_test_util",
+        "@highway//:nanobenchmark",
         "@highway//:profiler",
     ],
 )
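For reference, the `+` side of this hunk assembles into the deps list below. `:configs` leaves together with the `gemma/configs.h` include (next hunk), and `@highway//:nanobenchmark` arrives for the `hwy::Unpredictable1()` call used by the rewritten test:

    deps = [
        ":benchmark_helper",
        ":gemma_lib",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:nanobenchmark",
        "@highway//:profiler",
    ],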
@@ -19,18 +19,12 @@
 #include <vector>
 
 #include "evals/benchmark_helper.h"
-#include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "hwy/base.h"
+#include "hwy/nanobenchmark.h"
 #include "hwy/profiler.h"
 #include "hwy/tests/hwy_gtest.h"
 
-// This test can be run manually with the downloaded gemma weights.
-// To run the test, pass the following flags:
-// --tokenizer <tokenizer_path> --weights <weights_path>
-// It should pass for the following models:
-// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
 
 namespace gcpp {
 namespace {
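The new `hwy/nanobenchmark.h` include declares `hwy::Unpredictable1()`, which returns 1 at runtime but is opaque to the optimizer, so loops bounded by it cannot be folded away at compile time. A minimal standalone sketch of the idiom (illustration only, not part of the commit):

#include <cstdio>

#include "hwy/nanobenchmark.h"

int main() {
  // The bound evaluates to 1, so the body runs exactly once, but the
  // compiler cannot prove this and must keep the loop and its inputs.
  for (int i = 0; i < hwy::Unpredictable1(); ++i) {
    std::printf("kept alive: %d\n", i);
  }
  return 0;
}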
@@ -43,59 +37,19 @@ class GemmaTest : public ::testing::Test {
  protected:
   std::vector<std::string> BatchGemmaReply(
       const std::vector<std::string>& inputs) {
-    s_env->SetMaxGeneratedTokens(64);
+    s_env->SetMaxGeneratedTokens(32);
     s_env->MutableConfig().temperature = 0.0f;  // deterministic
-    s_env->MutableConfig().verbosity = 5;
-    const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
+    s_env->MutableConfig().verbosity = 2;
     std::vector<std::string> replies;
-    // Using the turn structure worsens results sometimes.
-    // However, some models need the turn structure to work.
-    // It would be good to make these tests more consistent.
-    if (config.model == Model::GEMMA2_27B ||
-        config.model == Model::GRIFFIN_2B) {
-      for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
-        replies.push_back(result.response);
-      }
-      return replies;
-    }
-    // Otherwise, do not use turn structure.
-    std::vector<std::vector<int>> prompts_vector;
-    prompts_vector.reserve(inputs.size());
-    for (const auto& input_string : inputs) {
-      prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
-    }
-    std::vector<PromptTokens> prompt_spans;
-    prompt_spans.reserve(prompts_vector.size());
-    for (const auto& prompt : prompts_vector) {
-      prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
-    }
-    QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
+    for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
       replies.push_back(result.response);
     }
     return replies;
   }
-
-  void GenerateTokens(const std::vector<std::string>& questions) {
-    ASSERT_NE(s_env->GetGemma(), nullptr);
-
-    // Fills prompts round robin from `questions` until the desired batch size.
-    std::vector<std::string> inputs;
-    inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
-    size_t qpos = 0;
-    for (size_t i = 0; i < inputs.capacity(); ++i) {
-      inputs.push_back(questions[qpos++]);
-      if (qpos == questions.size()) qpos = 0;
-    }
-    std::vector<std::string> responses = BatchGemmaReply(inputs);
-    for (size_t i = 0; i < inputs.size(); ++i) {
-      fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
-    }
-  }
 };
 
 TEST_F(GemmaTest, RandomQuestionsBatched) {
-  static std::vector<std::string> kQA = {
+  const std::vector<std::string> questions = {
     {"Write me a poem about Australia?"},
     {"What's the history of Denmark?"},
     {"Write me a comedy story about the USA."},
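Assembled from the `+` lines, the post-patch helper reads as follows: every model now goes through the string overload of `BatchQueryModel` (the path the old code reserved for Gemma2-27B and Griffin), and the manual TokenizeAndPrependBOS branch is gone.

std::vector<std::string> BatchGemmaReply(
    const std::vector<std::string>& inputs) {
  s_env->SetMaxGeneratedTokens(32);
  s_env->MutableConfig().temperature = 0.0f;  // deterministic
  s_env->MutableConfig().verbosity = 2;
  std::vector<std::string> replies;
  for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
    replies.push_back(result.response);
  }
  return replies;
}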
@@ -129,8 +83,20 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
     {"Tell me about space travel."},
     {"Explain to me how electric cars work."},
   };
-  s_env->MutableConfig().verbosity = 5;
-  GenerateTokens(kQA);
+  // Fills prompts round robin from `questions` until the desired batch size.
+  std::vector<std::string> inputs;
+  inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
+  size_t qpos = 0;
+  for (size_t i = 0; i < inputs.capacity(); ++i) {
+    inputs.push_back(questions[qpos++]);
+    if (qpos == questions.size()) qpos = 0;
+  }
+  std::vector<std::string> responses = BatchGemmaReply(inputs);
+  for (size_t i = 0; i < hwy::Unpredictable1(); ++i) {
+    fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
+  }
+
   PROFILER_PRINT_RESULTS();
 }
 } // namespace
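The round-robin fill inlined into the test is equivalent to this standalone sketch (hypothetical helper name `FillRoundRobin`; `batch_size` stands in for `decode_qbatch_size`):

#include <string>
#include <vector>

// Repeats `questions` cyclically until `batch_size` prompts are collected,
// so a small question set still exercises the full decode batch.
std::vector<std::string> FillRoundRobin(
    const std::vector<std::string>& questions, size_t batch_size) {
  std::vector<std::string> inputs;
  inputs.reserve(batch_size);
  for (size_t i = 0, qpos = 0; i < batch_size; ++i) {
    inputs.push_back(questions[qpos]);
    qpos = (qpos + 1) % questions.size();  // wrap around
  }
  return inputs;
}

Note that the print loop's bound is `hwy::Unpredictable1()` rather than `responses.size()`: at runtime it executes once, printing only the first answer, presumably to keep logs short while still preventing the compiler from eliding the loop.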