From d7b23d532a98b2e9b9b3a3f13603879e2da08206 Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Tue, 10 Jun 2025 01:24:52 -0700 Subject: [PATCH] Restructure internal initialization. PiperOrigin-RevId: 769507096 --- BUILD.bazel | 2 +- evals/gemma_test.cc | 2 ++ gemma/gemma.cc | 5 ----- gemma/gemma.h | 1 + gemma/run.cc | 1 + io/BUILD.bazel | 1 + io/io.cc | 4 ++++ io/io.h | 4 ++++ paligemma/BUILD.bazel | 1 + paligemma/paligemma_test.cc | 6 ++++-- 10 files changed, 19 insertions(+), 8 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index eab1476..82e8878 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -501,7 +501,6 @@ cc_library( ":threading", ":threading_context", ":weights", - # Placeholder for internal dep, do not remove., "//io:blob_store", "//io", "//paligemma:image", @@ -588,6 +587,7 @@ cc_test( ":configs", ":gemma_lib", "@googletest//:gtest_main", # buildcleaner: keep + "//io", "@highway//:hwy", "@highway//:hwy_test_util", ], diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 5d29f4f..22958d9 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -22,6 +22,7 @@ #include "evals/benchmark_helper.h" #include "gemma/configs.h" +#include "io/io.h" #include "hwy/base.h" #include "hwy/tests/hwy_gtest.h" @@ -175,6 +176,7 @@ TEST_F(GemmaTest, CrossEntropySmall) { int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); + gcpp::InternalInit(); gcpp::GemmaTest::InitEnv(argc, argv); int ret = RUN_ALL_TESTS(); gcpp::GemmaTest::DeleteEnv(); diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 3f7c149..2bebb11 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -43,7 +43,6 @@ #include -// Placeholder for internal header, do not modify. #include "gemma/configs.h" #include "gemma/model_store.h" #include "gemma/tokenizer.h" @@ -636,11 +635,7 @@ HWY_EXPORT(GenerateSingleT); HWY_EXPORT(GenerateBatchT); HWY_EXPORT(GenerateImageTokensT); -// Internal init must run before I/O. This helper function takes care of that, -// plus calling `SetArgs`. MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) { - // Placeholder for internal init, do not modify. - ThreadingContext::SetArgs(threading_args); return MatMulEnv(ThreadingContext::Get()); } diff --git a/gemma/gemma.h b/gemma/gemma.h index 99936f5..21e9619 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -97,6 +97,7 @@ struct TimingInfo { size_t tokens_generated = 0; }; +// Returns the `MatMulEnv` after calling `SetArgs`. MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args); using KVCaches = hwy::Span; diff --git a/gemma/run.cc b/gemma/run.cc index bacae8f..49432dd 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -294,6 +294,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading, } // namespace gcpp int main(int argc, char** argv) { + gcpp::InternalInit(); { PROFILER_ZONE("Startup.misc"); diff --git a/io/BUILD.bazel b/io/BUILD.bazel index cd02c78..1ca42a4 100644 --- a/io/BUILD.bazel +++ b/io/BUILD.bazel @@ -40,6 +40,7 @@ cc_library( "//conditions:default": [], }), deps = [ + # Placeholder for internal dep, do not remove., "//:allocator", "@highway//:hwy", ] + FILE_DEPS, diff --git a/io/io.cc b/io/io.cc index eea3bd6..1f7eda6 100644 --- a/io/io.cc +++ b/io/io.cc @@ -72,6 +72,7 @@ #include // O_RDONLY #include // read, write, close +// Placeholder for internal header, do not modify. #include "util/allocator.h" namespace gcpp { @@ -218,6 +219,9 @@ bool IOBatch::Add(void* mem, size_t bytes) { return true; } +void InternalInit() { +} + uint64_t IOBatch::Read(const File& file) const { #if GEMMA_IO_PREADV HWY_ASSERT(!spans_.empty()); diff --git a/io/io.h b/io/io.h index 027a4ba..a14889b 100644 --- a/io/io.h +++ b/io/io.h @@ -144,6 +144,10 @@ struct Path { // Aborts on error. std::string ReadFileToString(const Path& path); +// No-op in open-source. Must be called at the beginning of a binary, before +// any I/O or flag usage. +void InternalInit(); + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_ diff --git a/paligemma/BUILD.bazel b/paligemma/BUILD.bazel index 0a93bb0..d60c745 100644 --- a/paligemma/BUILD.bazel +++ b/paligemma/BUILD.bazel @@ -45,6 +45,7 @@ cc_test( "//:configs", "//:gemma_lib", "//compression:types", + "//io", "@highway//:hwy", "@highway//:hwy_test_util", ], diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index f5aebef..e681e00 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -22,6 +22,7 @@ #include "evals/benchmark_helper.h" #include "gemma/configs.h" #include "gemma/gemma.h" +#include "io/io.h" #include "util/allocator.h" #include "hwy/base.h" #include "hwy/tests/hwy_gtest.h" @@ -120,10 +121,11 @@ TEST_F(PaliGemmaTest, QueryObjects) { } // namespace gcpp int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + gcpp::InternalInit(); + gcpp::GemmaEnv env(argc, argv); gcpp::s_env = &env; - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); }