Declutter gemma/ directory, move binaries to evals/ and util/.

PiperOrigin-RevId: 648400795
This commit is contained in:
Jan Wassenberg 2024-07-01 09:50:27 -07:00 committed by Copybara-Service
parent e588a7f45d
commit af8eb2fde3
11 changed files with 178 additions and 23 deletions

View File

@ -128,8 +128,8 @@ cc_library(
cc_library( cc_library(
name = "cross_entropy", name = "cross_entropy",
srcs = ["gemma/cross_entropy.cc"], srcs = ["evals/cross_entropy.cc"],
hdrs = ["gemma/cross_entropy.h"], hdrs = ["evals/cross_entropy.h"],
deps = [ deps = [
":common", ":common",
":gemma_lib", ":gemma_lib",
@ -224,7 +224,7 @@ cc_binary(
cc_binary( cc_binary(
name = "compress_weights", name = "compress_weights",
srcs = ["gemma/compress_weights.cc"], srcs = ["util/compress_weights.cc"],
deps = [ deps = [
":args", ":args",
":common", ":common",
@ -242,7 +242,7 @@ cc_binary(
cc_binary( cc_binary(
name = "single_benchmark", name = "single_benchmark",
srcs = ["gemma/benchmark.cc"], srcs = ["evals/benchmark.cc"],
deps = [ deps = [
":app", ":app",
":args", ":args",
@ -260,7 +260,7 @@ cc_binary(
cc_binary( cc_binary(
name = "benchmarks", name = "benchmarks",
srcs = ["gemma/benchmarks.cc"], srcs = ["evals/benchmarks.cc"],
deps = [ deps = [
":benchmark_helper", ":benchmark_helper",
"@benchmark//:benchmark", "@benchmark//:benchmark",
@ -270,7 +270,7 @@ cc_binary(
cc_binary( cc_binary(
name = "debug_prompt", name = "debug_prompt",
srcs = [ srcs = [
"debug_prompt.cc", "evals/debug_prompt.cc",
], ],
deps = [ deps = [
":app", ":app",
@ -286,7 +286,7 @@ cc_binary(
cc_binary( cc_binary(
name = "gemma_mmlu", name = "gemma_mmlu",
srcs = ["gemma/run_mmlu.cc"], srcs = ["evals/run_mmlu.cc"],
deps = [ deps = [
":app", ":app",
":args", ":args",

View File

@ -60,14 +60,14 @@ set(SOURCES
backprop/forward_scalar.h backprop/forward_scalar.h
backprop/optimizer.cc backprop/optimizer.cc
backprop/optimizer.h backprop/optimizer.h
evals/cross_entropy.cc
evals/cross_entropy.h
gemma/configs.h gemma/configs.h
gemma/activations.h gemma/activations.h
gemma/benchmark_helper.cc gemma/benchmark_helper.cc
gemma/benchmark_helper.h gemma/benchmark_helper.h
gemma/common.cc gemma/common.cc
gemma/common.h gemma/common.h
gemma/cross_entropy.cc
gemma/cross_entropy.h
gemma/gemma.cc gemma/gemma.cc
gemma/gemma.h gemma/gemma.h
gemma/ops.h gemma/ops.h
@ -103,13 +103,13 @@ add_executable(gemma gemma/run.cc)
target_link_libraries(gemma libgemma hwy hwy_contrib) target_link_libraries(gemma libgemma hwy hwy_contrib)
install(TARGETS gemma DESTINATION bin) install(TARGETS gemma DESTINATION bin)
add_executable(single_benchmark gemma/benchmark.cc) add_executable(single_benchmark evals/benchmark.cc)
target_link_libraries(single_benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json) target_link_libraries(single_benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
add_executable(benchmarks gemma/benchmarks.cc) add_executable(benchmarks evals/benchmarks.cc)
target_link_libraries(benchmarks libgemma hwy hwy_contrib nlohmann_json::nlohmann_json benchmark) target_link_libraries(benchmarks libgemma hwy hwy_contrib nlohmann_json::nlohmann_json benchmark)
add_executable(debug_prompt debug_prompt.cc) add_executable(debug_prompt evals/debug_prompt.cc)
target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json) target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
## Tests ## Tests
@ -124,7 +124,7 @@ set(GEMMA_TEST_FILES
backprop/backward_scalar_test.cc backprop/backward_scalar_test.cc
backprop/optimize_test.cc backprop/optimize_test.cc
gemma/ops_test.cc gemma/ops_test.cc
gemma/gemma_test.cc evals/gemma_test.cc
) )
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES) foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
@ -145,5 +145,5 @@ endif() # GEMMA_ENABLE_TESTS
## Tools ## Tools
add_executable(compress_weights gemma/compress_weights.cc) add_executable(compress_weights util/compress_weights.cc)
target_link_libraries(compress_weights libgemma hwy hwy_contrib) target_link_libraries(compress_weights libgemma hwy hwy_contrib)

View File

@ -10,9 +10,9 @@
#include <vector> #include <vector>
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "evals/cross_entropy.h"
#include "gemma/benchmark_helper.h" #include "gemma/benchmark_helper.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/cross_entropy.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"

View File

@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gemma/cross_entropy.h" #include "evals/cross_entropy.h"
#include <stddef.h> #include <stddef.h>
#include <stdio.h> #include <stdio.h>

View File

@ -13,8 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_ #ifndef THIRD_PARTY_GEMMA_CPP_EVALS_CROSS_ENTROPY_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_ #define THIRD_PARTY_GEMMA_CPP_EVALS_CROSS_ENTROPY_H_
#include <stddef.h> #include <stddef.h>
@ -30,4 +30,4 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_ #endif // THIRD_PARTY_GEMMA_CPP_EVALS_CROSS_ENTROPY_H_

155
evals/gemma_test.cc Normal file
View File

@ -0,0 +1,155 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/gemma.h"
#include <stdio.h>
#include <string>
#include <vector>
#include "gemma/benchmark_helper.h"
#include "gemma/common.h"
#include "hwy/tests/hwy_gtest.h"
namespace gcpp {
namespace {
// Shared state. Requires argc/argv, so construct in main and use the same raw
// pointer approach as in benchmarks.cc. Note that the style guide forbids
// non-local static variables with dtors.
// Null until main() points it at its local GemmaEnv; the fixture and tests
// below dereference it, so it must be set before RUN_ALL_TESTS() runs.
GemmaEnv* s_env = nullptr;
// Test fixture whose helpers send prompts to the model owned by `s_env`.
class GemmaTest : public ::testing::Test {
 protected:
  // Returns the reply produced for `prompt`. The prompt is tokenized via
  // TokenizeAndPrependBOS and passed to QueryModel without any turn structure.
  std::string GemmaReply(const std::string& prompt) {
    s_env->SetMaxGeneratedTokens(2048);
    auto&& config = s_env->MutableConfig();
    config.temperature = 0.0f;  // deterministic
    config.verbosity = 0;
    // Using the turn structure worsens results.
    const std::vector<int> prompt_tokens = s_env->TokenizeAndPrependBOS(prompt);
    // Only the first element (the reply text) is needed here.
    auto [reply, unused] = s_env->QueryModel(prompt_tokens);
    return reply;
  }

  // For each of the first `num_questions` rows of `kQA`, asks kQA[i][0] and
  // expects the reply to contain kQA[i][1] as a substring. Question number and
  // reply are logged to stderr. Returns immediately if GetModel() is null/false.
  void TestQuestions(const char* kQA[][2], size_t num_questions) {
    if (!s_env->GetModel()) return;

    for (size_t i = 0; i != num_questions; ++i) {
      const size_t question_number = i + 1;  // 1-based for the log output
      fprintf(stderr, "Question %zu\n\n", question_number);

      const char* const question = kQA[i][0];
      const std::string response = GemmaReply(question);
      fprintf(stderr, "'%s'\n\n", response.c_str());

      EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);  // NOLINT
    }
  }
};
// Geography questions: each reply must contain the expected answer.
TEST_F(GemmaTest, Geography) {
  // Each row is {question, substring expected in the reply}.
  static const char* kQA[][2] = {
      {"What is the capital of Hungary?", "Budapest"},
      {"How many states does the US have?", "50"},
  };
  // Row count is derived from the table so it cannot go stale.
  TestQuestions(kQA, sizeof(kQA) / sizeof(kQA[0]));
}
// History question: the reply must contain the expected year.
TEST_F(GemmaTest, History) {
  // Each row is {question, substring expected in the reply}.
  static const char* kQA[][2] = {
      {"When was the battle of Hastings?", "1066"},
  };
  // Row count is derived from the table so it cannot go stale.
  TestQuestions(kQA, sizeof(kQA) / sizeof(kQA[0]));
}
// Arithmetic questions: each reply must contain the expected result.
TEST_F(GemmaTest, Arithmetic) {
  // Each row is {question, substring expected in the reply}.
  static const char* kQA[][2] = {
      {"what is 13 + 14?", "27"},
      {"what is 7 * 8?", "56"},
  };
  // Row count is derived from the table so it cannot go stale.
  TestQuestions(kQA, sizeof(kQA) / sizeof(kQA[0]));
}
// Lyrics used as input by the CrossEntropyJingleBells test. One lyric line per
// literal (adjacent literals are concatenated); the text both starts and ends
// with a newline.
static const char kJingleBells[] =
    "\n"
    "Dashing through the snow\n"
    "In a one-horse open sleigh\n"
    "O'er the fields we go\n"
    "Laughing all the way\n"
    "Bells on bobtails ring\n"
    "Making spirits bright\n"
    "What fun it is to ride and sing\n"
    "A sleighing song tonight\n";
// The "Hay Draft" of the Gettysburg Address.
// Input text for the CrossEntropyGettysburg test. Built from adjacent string
// literals, which the compiler concatenates into one string; "\n\n" separates
// the three paragraphs and the text ends with a single "\n".
static const char kGettysburg[] = {
"Four score and seven years ago our fathers brought forth, upon this "
"continent, a new nation, conceived in Liberty, and dedicated to the "
"proposition that all men are created equal.\n\nNow we are engaged in a "
"great civil war, testing whether that nation, or any nation, so "
"conceived, and so dedicated, can long endure. We are met here on a great "
"battlefield of that war. We have come to dedicate a portion of it as a "
"final resting place for those who here gave their lives that that nation "
"might live. It is altogether fitting and proper that we should do "
"this.\n\nBut in a larger sense we can not dedicate -- we can not "
"consecrate -- we can not hallow this ground. The brave men, living and "
"dead, who struggled, here, have consecrated it far above our poor power "
"to add or detract. The world will little note, nor long remember, what we "
"say here, but can never forget what they did here. It is for us, the "
"living, rather to be dedicated here to the unfinished work which they "
"have, thus far, so nobly carried on. It is rather for us to be here "
"dedicated to the great task remaining before us -- that from these "
"honored dead we take increased devotion to that cause for which they here "
"gave the last full measure of devotion -- that we here highly resolve "
"that these dead shall not have died in vain; that this nation shall have "
"a new birth of freedom; and that this government of the people, by the "
"people, for the people, shall not perish from the earth.\n"};
// Expects the cross-entropy reported for one short sentence to stay below a
// model-dependent upper bound.
TEST_F(GemmaTest, CrossEntropySmall) {
  // Nothing to evaluate unless a model is available.
  if (!s_env->GetModel()) return;

  static const char kSmall[] =
      "The capital of Hungary is Budapest which is located in Europe.";
  const float entropy = s_env->CrossEntropy(kSmall);
  fprintf(stderr, "per-byte entropy: %f\n", entropy);

  // The bound is 2.1 when the model type is GEMMA_7B, otherwise 2.0.
  EXPECT_LT(entropy,
            (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
}
// Expects the cross-entropy reported for kJingleBells to stay below a
// model-dependent upper bound.
TEST_F(GemmaTest, CrossEntropyJingleBells) {
  // Nothing to evaluate unless a model is available.
  if (!s_env->GetModel()) return;

  const float entropy = s_env->CrossEntropy(kJingleBells);
  fprintf(stderr, "per-byte entropy: %f\n", entropy);

  // The bound is 0.9 when the model type is GEMMA_7B, otherwise 1.8.
  EXPECT_LT(entropy,
            (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
}
// Expects the cross-entropy reported for kGettysburg to stay below a
// model-dependent upper bound.
TEST_F(GemmaTest, CrossEntropyGettysburg) {
  // Nothing to evaluate unless a model is available.
  if (!s_env->GetModel()) return;

  const float entropy = s_env->CrossEntropy(kGettysburg);
  fprintf(stderr, "per-byte entropy: %f\n", entropy);

  // The bound is 0.8 when the model type is GEMMA_7B, otherwise 1.2.
  EXPECT_LT(entropy,
            (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
}
} // namespace
} // namespace gcpp
// Test entry point: builds the shared GemmaEnv from the command line, publishes
// it through gcpp::s_env, then runs all registered tests.
int main(int argc, char** argv) {
  // `env` lives on main's stack, so it outlives every test run below.
  gcpp::GemmaEnv env(argc, argv);
  gcpp::s_env = &env;

  testing::InitGoogleTest(&argc, argv);

  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}

View File

@ -25,14 +25,14 @@
#include <ostream> #include <ostream>
#include <random> #include <random>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> // std::pair #include <utility> // std::pair
#include <vector> #include <vector>
// Placeholder for internal header, do not modify. // Placeholder for internal header, do not modify.
#include "compression/compress.h" // TypeName #include "compression/compress.h" // TypeName
#include "gemma/common.h" // StringFromType #include "evals/cross_entropy.h"
#include "gemma/cross_entropy.h" #include "gemma/common.h" // StringFromType
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/app.h" #include "util/app.h"
#include "util/args.h" #include "util/args.h"

View File

@ -19,7 +19,7 @@
// which we pass the filename via macro 'argument'. // which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE \
"gemma/compress_weights.cc" // NOLINT "util/compress_weights.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors. // Must come after foreach_target.h to avoid redefinition errors.
#include "compression/compress-inl.h" #include "compression/compress-inl.h"