From a982ec1287131ac080205a32bdd71548331a2cb7 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 9 Apr 2024 04:45:02 -0700 Subject: [PATCH] Move code to gemma/ so we can remove error-prone copybara: comments. Also fix includes and Lint warnings. PiperOrigin-RevId: 623127487 --- BUILD.bazel | 34 +-- CMakeLists.txt | 15 +- compression/analyze.h | 5 - compression/blob_store.cc | 1 - compression/compress-inl.h | 5 - compression/compress.h | 5 - compression/nuq-inl.h | 3 - compression/nuq_test.cc | 3 - compression/sfp-inl.h | 1 - compression/sfp_test.cc | 3 - compression/stats.cc | 1 - compression/test_util.h | 2 - examples/hello_world/run.cc | 5 +- benchmark.cc => gemma/benchmark.cc | 9 +- gemma/benchmarks.cc | 137 +++++++++++ .../compress_weights.cc | 6 +- configs.h => gemma/configs.h | 7 +- gemma.cc => gemma/gemma.cc | 21 +- gemma.h => gemma/gemma.h | 22 +- gemma_test.cc => gemma/gemma_test.cc | 16 +- ops.h => gemma/ops.h | 8 +- ops_test.cc => gemma/ops_test.cc | 5 +- run.cc => gemma/run.cc | 10 +- gemma/run_csv.cc | 223 ++++++++++++++++++ util/app.h | 10 +- 25 files changed, 424 insertions(+), 133 deletions(-) rename benchmark.cc => gemma/benchmark.cc (98%) create mode 100644 gemma/benchmarks.cc rename compress_weights.cc => gemma/compress_weights.cc (96%) rename configs.h => gemma/configs.h (97%) rename gemma.cc => gemma/gemma.cc (99%) rename gemma.h => gemma/gemma.h (88%) rename gemma_test.cc => gemma/gemma_test.cc (98%) rename ops.h => gemma/ops.h (99%) rename ops_test.cc => gemma/ops_test.cc (99%) rename run.cc => gemma/run.cc (98%) create mode 100644 gemma/run_csv.cc diff --git a/BUILD.bazel b/BUILD.bazel index bbcaa2f..68a2026 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -22,9 +22,7 @@ exports_files(["LICENSE"]) cc_library( name = "ops", - hdrs = [ - "ops.h", - ], + hdrs = ["gemma/ops.h"], deps = [ "//compression:compress", "@hwy//:algo", @@ -41,7 +39,7 @@ cc_library( cc_test( name = "ops_test", size = "small", - srcs = ["ops_test.cc"], + srcs = 
["gemma/ops_test.cc"], local_defines = ["HWY_IS_TEST"], # for test_suite. tags = ["hwy_ops_test"], @@ -55,9 +53,7 @@ cc_test( cc_library( name = "args", - hdrs = [ - "util/args.h", - ], + hdrs = ["util/args.h"], deps = [ "@hwy//:hwy", ], @@ -66,11 +62,11 @@ cc_library( cc_library( name = "gemma_lib", srcs = [ - "gemma.cc", + "gemma/gemma.cc", ], hdrs = [ - "configs.h", - "gemma.h", + "gemma/configs.h", + "gemma/gemma.h", ], deps = [ ":args", @@ -88,7 +84,7 @@ cc_library( cc_test( name = "gemma_test", - srcs = ["gemma_test.cc"], + srcs = ["gemma/gemma_test.cc"], # Requires model files tags = [ "local", @@ -107,9 +103,7 @@ cc_test( cc_library( name = "app", - hdrs = [ - "util/app.h", - ], + hdrs = ["util/app.h"], deps = [ ":args", ":gemma_lib", @@ -119,9 +113,7 @@ cc_library( cc_binary( name = "gemma", - srcs = [ - "run.cc", - ], + srcs = ["gemma/run.cc"], deps = [ ":app", ":args", @@ -137,9 +129,7 @@ cc_binary( cc_binary( name = "compress_weights", - srcs = [ - "compress_weights.cc", - ], + srcs = ["gemma/compress_weights.cc"], deps = [ ":args", ":gemma_lib", @@ -154,9 +144,7 @@ cc_binary( cc_binary( name = "benchmark", - srcs = [ - "benchmark.cc", - ], + srcs = ["gemma/benchmark.cc"], deps = [ ":app", ":args", diff --git a/CMakeLists.txt b/CMakeLists.txt index 00c47f9..203b6c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,7 +34,6 @@ FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GI FetchContent_MakeAvailable(json) set(SOURCES - gemma.cc compression/blob_store.cc compression/blob_store.h compression/compress.h @@ -44,6 +43,10 @@ set(SOURCES compression/sfp.h compression/sfp-inl.h compression/test_util.h + gemma/configs.h + gemma/gemma.cc + gemma/gemma.h + gemma/ops.h util/app.h util/args.h ) @@ -79,10 +82,10 @@ target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated # Executable Target -add_executable(gemma run.cc) +add_executable(gemma gemma/run.cc) target_link_libraries(gemma libgemma hwy hwy_contrib) 
-add_executable(benchmark benchmark.cc) +add_executable(benchmark gemma/benchmark.cc) target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json) ## Tests @@ -90,8 +93,8 @@ set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests") if (GEMMA_ENABLE_TESTS) set(GEMMA_TEST_FILES - ops_test.cc - gemma_test.cc + gemma/ops_test.cc + gemma/gemma_test.cc ) foreach (TESTFILE IN LISTS GEMMA_TEST_FILES) @@ -112,5 +115,5 @@ endif() # GEMMA_ENABLE_TESTS ## Tools -add_executable(compress_weights compress_weights.cc) +add_executable(compress_weights gemma/compress_weights.cc) target_link_libraries(compress_weights libgemma hwy hwy_contrib) diff --git a/compression/analyze.h b/compression/analyze.h index d719aee..0859c6a 100644 --- a/compression/analyze.h +++ b/compression/analyze.h @@ -26,11 +26,8 @@ #include // std::abs #include -// copybara:import_next_line:gemma_cpp #include "compression/distortion.h" -// copybara:import_next_line:gemma_cpp #include "compression/nuq.h" -// copybara:import_next_line:gemma_cpp #include "compression/stats.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -46,9 +43,7 @@ #define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE #endif -// copybara:import_next_line:gemma_cpp #include "compression/nuq-inl.h" -// copybara:import_next_line:gemma_cpp #include "compression/sfp-inl.h" #include "hwy/contrib/sort/vqsort-inl.h" #include "hwy/highway.h" diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 2458fb9..7f03833 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -26,7 +26,6 @@ #undef _FILE_OFFSET_BITS #define _FILE_OFFSET_BITS 64 -// copybara:import_next_line:gemma_cpp #include "compression/blob_store.h" #include // open diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 5717545..516e9e3 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -23,11 +23,8 @@ #include -// copybara:import_next_line:gemma_cpp #include 
"compression/blob_store.h" -// copybara:import_next_line:gemma_cpp #include "compression/compress.h" -// copybara:import_next_line:gemma_cpp #include "compression/distortion.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" @@ -44,9 +41,7 @@ #define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE #endif -// copybara:import_next_line:gemma_cpp #include "compression/nuq-inl.h" -// copybara:import_next_line:gemma_cpp #include "compression/sfp-inl.h" #include "hwy/contrib/dot/dot-inl.h" #include "hwy/highway.h" diff --git a/compression/compress.h b/compression/compress.h index 118ded2..5c9a3b2 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -27,19 +27,14 @@ #include // IWYU pragma: begin_exports -// copybara:import_next_line:gemma_cpp #include "compression/blob_store.h" -// copybara:import_next_line:gemma_cpp #include "compression/nuq.h" -// copybara:import_next_line:gemma_cpp #include "compression/sfp.h" // IWYU pragma: end_exports -// copybara:import_next_line:gemma_cpp #include "compression/distortion.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" #if COMPRESS_STATS -// copybara:import_next_line:gemma_cpp #include "compression/stats.h" #endif diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h index 1c8bdf1..28b2eb9 100644 --- a/compression/nuq-inl.h +++ b/compression/nuq-inl.h @@ -20,9 +20,7 @@ #include #include -// copybara:import_next_line:gemma_cpp #include "compression/nuq.h" -// copybara:import_next_line:gemma_cpp #include "compression/sfp.h" #include "hwy/base.h" @@ -37,7 +35,6 @@ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE #endif -// copybara:import_next_line:gemma_cpp #include "compression/sfp-inl.h" #include "hwy/contrib/sort/vqsort-inl.h" #include "hwy/highway.h" diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index 8e34b4d..1f8de40 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -35,11 +35,8 @@ // clang-format on #include 
"hwy/foreach_target.h" // IWYU pragma: keep // Other headers that include Highway must come after foreach_target.h -// copybara:import_next_line:gemma_cpp #include "compression/nuq-inl.h" -// copybara:import_next_line:gemma_cpp #include "compression/nuq.h" -// copybara:import_next_line:gemma_cpp #include "compression/test_util.h" #include "hwy/highway.h" #include "hwy/tests/hwy_gtest.h" diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h index 77f7ede..7152960 100644 --- a/compression/sfp-inl.h +++ b/compression/sfp-inl.h @@ -20,7 +20,6 @@ #include #include -// copybara:import_next_line:gemma_cpp #include "compression/sfp.h" #include "hwy/base.h" diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index 1a4e4ec..f7936a3 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -18,7 +18,6 @@ #define HWY_DISABLED_TARGETS HWY_SCALAR #endif -// copybara:import_next_line:gemma_cpp #include "compression/sfp.h" #include @@ -37,9 +36,7 @@ // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep // Any highway.h must come after foreach_target.h -// copybara:import_next_line:gemma_cpp #include "compression/sfp-inl.h" -// copybara:import_next_line:gemma_cpp #include "compression/test_util.h" #include "hwy/highway.h" #include "hwy/tests/hwy_gtest.h" diff --git a/compression/stats.cc b/compression/stats.cc index 8e66119..0f4bf2d 100644 --- a/compression/stats.cc +++ b/compression/stats.cc @@ -13,7 +13,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-// copybara:import_next_line:gemma_cpp #include "compression/stats.h" #include diff --git a/compression/test_util.h b/compression/test_util.h index b1e4026..0db9b2e 100644 --- a/compression/test_util.h +++ b/compression/test_util.h @@ -24,9 +24,7 @@ #include "hwy/base.h" // IWYU pragma: begin_exports -// copybara:import_next_line:gemma_cpp #include "compression/distortion.h" -// copybara:import_next_line:gemma_cpp #include "compression/stats.h" #include "hwy/tests/test_util.h" // RandomState // IWYU pragma: end_exports diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 60de786..9bb0d9f 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -15,12 +15,9 @@ #include -// copybara:import_next_line:gemma_cpp -#include "gemma.h" -// copybara:import_next_line:gemma_cpp +#include "third_party/gemma_cpp/gemma.h" #include "util/app.h" // LoaderArgs #include "hwy/contrib/thread_pool/thread_pool.h" -// copybara:import_next_line:gemma_cpp #include "util/args.h" std::vector tokenize(const std::string& prompt_string, diff --git a/benchmark.cc b/gemma/benchmark.cc similarity index 98% rename from benchmark.cc rename to gemma/benchmark.cc index 995b0b8..b5fb375 100644 --- a/benchmark.cc +++ b/gemma/benchmark.cc @@ -8,16 +8,13 @@ #include #include "nlohmann/json.hpp" -// copybara:import_next_line:gemma_cpp -#include "gemma.h" +#include "gemma/gemma.h" +#include "util/app.h" +#include "util/args.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" #include "hwy/timer.h" -// copybara:import_next_line:gemma_cpp -#include "util/app.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" using json = nlohmann::json; diff --git a/gemma/benchmarks.cc b/gemma/benchmarks.cc new file mode 100644 index 0000000..fb2c6e3 --- /dev/null +++ b/gemma/benchmarks.cc @@ -0,0 +1,137 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, 
Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "third_party/benchmark/include/benchmark/benchmark.h" +#include "gemma/gemma.h" +#include "util/app.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +gcpp::LoaderArgs* loader = nullptr; +gcpp::InferenceArgs* inference = nullptr; +gcpp::Gemma* model = nullptr; +hwy::ThreadPool* pool = nullptr; +hwy::ThreadPool* inner_pool = nullptr; + +void run_gemma_prompt(const std::string& prompt_string, + benchmark::State& state) { + std::mt19937 gen; + std::vector prompt; + + if (prompt_string.empty()) return; + HWY_ASSERT(model->Tokenizer().Encode(prompt_string, &prompt).ok()); + + int token_counter = 0; + auto stream_token = [&token_counter](int, float) { + token_counter++; + return true; + }; + + for (auto s : state) { + GenerateGemma( + *model, *inference, prompt, /*start_token=*/0, *pool, *inner_pool, + stream_token, + /*accept=*/[](int) { return true; }, gen, /*verbosity=*/0); + } + + state.SetItemsProcessed(token_counter); +} + +static void BM_short_prompt(benchmark::State& state) { + run_gemma_prompt("What is the capital of Spain? ", state); +} + +static void BM_factuality_prompt(benchmark::State& state) { + run_gemma_prompt("How does an inkjet printer work? ", state); +} + +static void BM_creative_prompt(benchmark::State& state) { + run_gemma_prompt( + "Tell me a story about a magical bunny and their TRS-80. 
", + state); +} + +static void BM_coding_prompt(benchmark::State& state) { + run_gemma_prompt( + "Write a python program to generate a fibonacci sequence. ", + state); +} + +static void BM_long_coding_prompt(benchmark::State& state) { + std::ifstream t("benchmarks.cc", std::ios_base::in); + std::stringstream buffer; + buffer << t.rdbuf(); + std::string prompt_string = buffer.str(); + t.close(); + + run_gemma_prompt("Make improvements to the following code:\n " + + prompt_string + " ", + state); +} + +int main(int argc, char** argv) { + loader = new gcpp::LoaderArgs(argc, argv); + inference = new gcpp::InferenceArgs(argc, argv); + gcpp::AppArgs app(argc, argv); + + pool = new ::hwy::ThreadPool(app.num_threads); + inner_pool = new ::hwy::ThreadPool(0); + model = new gcpp::Gemma(*loader, *pool); + + inference->max_tokens = 128; + BENCHMARK(BM_short_prompt) + ->Iterations(3) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + + inference->max_tokens = 256; + BENCHMARK(BM_factuality_prompt) + ->Iterations(3) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + + BENCHMARK(BM_creative_prompt) + ->Iterations(3) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + + BENCHMARK(BM_coding_prompt) + ->Iterations(3) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + + inference->max_tokens = 1024; + BENCHMARK(BM_long_coding_prompt) + ->Iterations(3) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + + ::benchmark ::RunSpecifiedBenchmarks(); + ::benchmark ::Shutdown(); + + delete loader; + delete inference; + delete model; + delete pool; + + return 0; +} diff --git a/compress_weights.cc b/gemma/compress_weights.cc similarity index 96% rename from compress_weights.cc rename to gemma/compress_weights.cc index ae8b088..776da50 100644 --- a/compress_weights.cc +++ b/gemma/compress_weights.cc @@ -18,12 +18,8 @@ #include #include -// copybara:import_next_line:gemma_cpp -#include "gemma.h" // Gemma -// copybara:end -// copybara:import_next_line:gemma_cpp +#include 
"gemma/gemma.h" // Gemma #include "util/args.h" -// copybara:end namespace gcpp { diff --git a/configs.h b/gemma/configs.h similarity index 97% rename from configs.h rename to gemma/configs.h index 98bcf12..9b82880 100644 --- a/configs.h +++ b/gemma/configs.h @@ -15,8 +15,8 @@ // Model configurations -#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ -#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_ // Allow changing pre-allocated kv cache size as a compiler flag #ifndef GEMMA_MAX_SEQLEN @@ -32,7 +32,6 @@ #include -// copybara:import_next_line:gemma_cpp #include "compression/sfp.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -164,4 +163,4 @@ struct ConfigGriffin2B { } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_ diff --git a/gemma.cc b/gemma/gemma.cc similarity index 99% rename from gemma.cc rename to gemma/gemma.cc index 80f7c17..9b8fd93 100644 --- a/gemma.cc +++ b/gemma/gemma.cc @@ -18,22 +18,18 @@ // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. #undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT +#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT #include "hwy/foreach_target.h" // IWYU pragma: keep // Must come after foreach_target.h to avoid redefinition errors. -// copybara:import_next_line:gemma_cpp #include "compression/compress-inl.h" -// copybara:import_next_line:gemma_cpp -#include "ops.h" +#include "gemma/ops.h" +#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" #include "hwy/timer.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" -// copybara:end // Non-SIMD includes and types. 
Note that HWY_ONCE is only true on the last // compile pass, whereas we want this defined in the first. @@ -53,21 +49,16 @@ #include #include #include -#include +#include // NOLINT #include #include -// copybara:import_next_line:gemma_cpp #include "compression/compress.h" -// copybara:import_next_line:gemma_cpp -#include "configs.h" -// copybara:import_next_line:gemma_cpp -#include "gemma.h" +#include "gemma/configs.h" +#include "gemma/gemma.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" -// copybara:import_next_line:sentencepiece -#include "src/sentencepiece_processor.h" // Setting this to true disables fread() calls that read the model file. constexpr bool kDryRunFread = false; diff --git a/gemma.h b/gemma/gemma.h similarity index 88% rename from gemma.h rename to gemma/gemma.h index 8ae3577..abe402c 100644 --- a/gemma.h +++ b/gemma/gemma.h @@ -13,25 +13,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ #include #include #include +#include #include -// copybara:import_next_line:gemma_cpp -#include "compression/compress.h" // SfpStream/NuqStream -// copybara:import_next_line:gemma_cpp -#include "configs.h" +#include "gemma/configs.h" +#include "util/args.h" // Path #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path -// copybara:import_next_line:sentencepiece -#include "src/sentencepiece_processor.h" namespace gcpp { @@ -71,6 +66,7 @@ struct GemmaInterface; class GemmaTokenizer { public: + virtual ~GemmaTokenizer() = default; virtual bool Encode(const std::string& input, std::vector* pieces) const = 0; virtual bool Encode(const std::string& input, @@ -82,7 +78,7 @@ class GemmaTokenizer { struct Gemma { Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, hwy::ThreadPool& pool); - ~Gemma(); // must be defined after GemmaInterface's dtor is defined. + ~Gemma(); // must be defined after the GemmaInterface dtor is defined. 
const GemmaTokenizer* Tokenizer() const; std::unique_ptr impl_; }; @@ -105,7 +101,7 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, // Convenience function for the common case: // - Bundle runtime parameters as RuntimeConfig -// - No threadpools within threadpools (inner_pool = dummy) +// - No ThreadPool within ThreadPool (inner_pool = dummy) // - All tokens accepted void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, const std::vector& prompt, size_t start_pos, @@ -124,4 +120,4 @@ constexpr int EOS_ID = 1; } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ diff --git a/gemma_test.cc b/gemma/gemma_test.cc similarity index 98% rename from gemma_test.cc rename to gemma/gemma_test.cc index 4601cc8..a842a9c 100644 --- a/gemma_test.cc +++ b/gemma/gemma_test.cc @@ -13,14 +13,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// copybara:import_next_line:gemma_cpp -#include "gemma.h" +#include "gemma/gemma.h" -#include +#include +#include +#include +#include +#include // NOLINT +#include -// copybara:import_next_line:gemma_cpp -#include "ops.h" -// copybara:import_next_line:gemma_cpp +#include "gemma/ops.h" #include "util/args.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/tests/test_util-inl.h" @@ -79,7 +81,7 @@ class GemmaTest : public ::testing::Test { std::cout << "Question " << i + 1 << "\n\n"; std::string response = GemmaReply(kQA[i][0]); std::cout << response << "\n\n"; - EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); + EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT } } diff --git a/ops.h b/gemma/ops.h similarity index 99% rename from ops.h rename to gemma/ops.h index eff3c81..da6a38e 100644 --- a/ops.h +++ b/gemma/ops.h @@ -14,8 +14,9 @@ // limitations under the License. // Include guard for non-SIMD code. 
-#ifndef THIRD_PARTY_GEMMA_CPP_OPS_H_ -#define THIRD_PARTY_GEMMA_CPP_OPS_H_ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_ + #include #include @@ -43,7 +44,7 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); } } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_OPS_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_ // Include guard for (potentially) SIMD code. #if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE) @@ -53,7 +54,6 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); } #define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE #endif -// copybara:import_next_line:gemma_cpp #include "compression/compress-inl.h" #include "hwy/contrib/algo/transform-inl.h" #include "hwy/contrib/dot/dot-inl.h" diff --git a/ops_test.cc b/gemma/ops_test.cc similarity index 99% rename from ops_test.cc rename to gemma/ops_test.cc index d74ceb8..06ef6ef 100644 --- a/ops_test.cc +++ b/gemma/ops_test.cc @@ -25,14 +25,13 @@ // clang-format off #undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT +#define HWY_TARGET_INCLUDE "gemma/ops_test.cc" //NOLINT // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" #include "hwy/tests/test_util-inl.h" // After highway.h -// copybara:import_next_line:gemma_cpp -#include "ops.h" +#include "gemma/ops.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/run.cc b/gemma/run.cc similarity index 98% rename from run.cc rename to gemma/run.cc index 1a3fa0a..633e37c 100644 --- a/run.cc +++ b/gemma/run.cc @@ -23,20 +23,16 @@ #include // Placeholder for internal header, do not modify. 
-// copybara:import_next_line:gemma_cpp #include "compression/compress.h" -// copybara:import_next_line:gemma_cpp -#include "gemma.h" // Gemma +#include "gemma/gemma.h" // Gemma +#include "util/app.h" +#include "util/args.h" // HasHelp #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" #include "hwy/per_target.h" #include "hwy/profiler.h" #include "hwy/timer.h" -// copybara:import_next_line:gemma_cpp -#include "util/app.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // HasHelp static constexpr bool kVerboseLogTokens = false; diff --git a/gemma/run_csv.cc b/gemma/run_csv.cc new file mode 100644 index 0000000..cea3826 --- /dev/null +++ b/gemma/run_csv.cc @@ -0,0 +1,223 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Command line text interface to gemma. 
+ +#include + +#include +#include +#include +#include + +#include "gemma/configs.h" +#include "gemma/gemma.h" +#include "util/app.h" +#include "util/args.h" // ArgsBase +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/highway.h" +#include "hwy/profiler.h" +#include "third_party/riegeli/bytes/file_reader.h" +#include "third_party/riegeli/bytes/file_writer.h" +#include "third_party/riegeli/csv/csv_reader.h" +#include "third_party/riegeli/csv/csv_writer.h" + +namespace gcpp { + +struct CsvArgs : public ArgsBase { + CsvArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + + Path input_csv; + Path output_csv; + int prompt_column; + + template + void ForEach(const Visitor& visitor) { + visitor(input_csv, "input_csv", Path(), + "When set, prompts will be read from this CSV."); + visitor(output_csv, "output_csv", Path("/tmp/output.csv"), + "When --input_csv is set, prompts will be written to this CSV."); + visitor(prompt_column, "prompt_column", 0, "Prompt column index"); + }; +}; + +void FileGemma(gcpp::Gemma& model, InferenceArgs& inference, AppArgs& app, + CsvArgs& csv, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const gcpp::AcceptFunc& accept_token) { + int abs_pos = 0; // absolute token index over all turns + int current_pos = 0; // token index within the current turn + int prompt_size{}; + + std::mt19937 gen; + if (inference.deterministic) { + gen.seed(42); + } else { + std::random_device rd; + gen.seed(rd()); + } + + std::stringstream response_stream; + + // callback function invoked for each generated token. 
+ auto stream_token = [&inference, &abs_pos, &current_pos, &gen, &prompt_size, + tokenizer = &model.Tokenizer(), + &response_stream](int token, float) { + ++abs_pos; + ++current_pos; + if (current_pos < prompt_size) { + // pass + } else if (token == gcpp::EOS_ID) { + if (!inference.multiturn) { + abs_pos = 0; + if (inference.deterministic) { + gen.seed(42); + } + } + // end of stream + } else { + std::string token_text; + HWY_ASSERT(tokenizer->Decode({token}, &token_text).ok()); + // +1 since position is incremented above + if (current_pos == prompt_size + 1) { + // first token of response + token_text.erase(0, token_text.find_first_not_of(" \t\n")); + } + if (token_text != "\n") + response_stream << token_text; + else + response_stream << "\\n"; + } + return true; + }; + + riegeli::CsvReader csv_reader( + riegeli::FileReader(csv.input_csv.path), + riegeli::CsvReaderBase::Options().set_comment('#').set_recovery( + [](absl::Status status, riegeli::CsvReaderBase& csv_reader) { + fprintf(stderr, "Invalid entry: %s", status.message().data()); + return true; + })); + + riegeli::CsvWriter csv_writer( + riegeli::FileWriter(csv.output_csv.path), + riegeli::CsvWriterBase::Options().set_header({"prompt", "response"})); + + if (!csv_reader.ok()) { + HWY_ABORT("Invalid input CSV path %s", csv.input_csv.path.c_str()); + } + + if (!csv_writer.ok()) { + HWY_ABORT("Invalid output CSV path %s", csv.output_csv.path.c_str()); + } + + while (abs_pos < inference.max_tokens) { + std::string prompt_string; + std::vector prompt; + current_pos = 0; + + std::vector record; + csv_reader.ReadRecord(record); + + if (record.empty()) { + break; + } + + prompt_string = record[csv.prompt_column]; + fprintf(stdout, "Prompt: %s\n", prompt_string.c_str()); + + prompt_string = + "user\n" + prompt_string + "\nmodel\n"; + if (abs_pos > 0) { + // multi-turn dialogue continuation. 
+ prompt_string = "\n" + prompt_string; + } else { + HWY_DASSERT(abs_pos == 0); + if (gcpp::kSystemPrompt) { + prompt_string = + "system\nYou are a large language model built by " + "Google.\n" + + prompt_string; + } + } + HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok()); + prompt_size = prompt.size(); + + // generate prompt + GenerateGemma(model, inference, prompt, abs_pos, pool, inner_pool, + stream_token, accept_token, gen, app.verbosity); + + std::string response_string = response_stream.str(); + if (!csv_writer.WriteRecord({record[csv.prompt_column], response_string})) { + fprintf(stderr, "Failed to write CSV: %s\n", + csv_writer.status().message().data()); + } + + response_stream.str(std::string()); // reset stream + response_stream.clear(); + abs_pos = 0; + } + + if (!csv_reader.Close()) { + fprintf(stderr, "Failed to close the CSV reader\n"); + } + if (!csv_writer.Close()) { + fprintf(stderr, "Failed to close the CSV writer\n"); + } +} + +void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, + CsvArgs& csv) { + PROFILER_ZONE("Run.misc"); + + hwy::ThreadPool inner_pool(0); + hwy::ThreadPool pool(app.num_threads); + // For many-core, pinning threads to cores helps. 
+ if (app.num_threads > 10) { + pool.Run(0, pool.NumThreads(), + [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); + } + + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), loader.ModelTraining(), pool); + + if (csv.input_csv.path.empty()) { + HWY_ABORT("Need to specify csv file."); + } + + FileGemma(model, inference, app, csv, pool, inner_pool, + [](int) { return true; }); +} + +} // namespace gcpp + +int main(int argc, char** argv) { + { + PROFILER_ZONE("Startup.misc"); + gcpp::LoaderArgs loader(argc, argv); + gcpp::InferenceArgs inference(argc, argv); + gcpp::AppArgs app(argc, argv); + gcpp::CsvArgs csv(argc, argv); + + if (const char* error = loader.Validate()) { + loader.Help(); + HWY_ABORT("Invalid args: %s", error); + } + + gcpp::Run(loader, inference, app, csv); + } + PROFILER_PRINT_RESULTS(); // Must call outside the zone above. + return 0; +} diff --git a/util/app.h b/util/app.h index af26712..296ec9a 100644 --- a/util/app.h +++ b/util/app.h @@ -18,7 +18,6 @@ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ -#include #if HWY_OS_LINUX #include @@ -32,13 +31,10 @@ #include // std::clamp #include // NOLINT> -// copybara:import_next_line:gemma_cpp -#include "configs.h" -// copybara:import_next_line:gemma_cpp -#include "gemma.h" -#include "hwy/base.h" // HWY_ASSERT -// copybara:import_next_line:gemma_cpp +#include "gemma/configs.h" +#include "gemma/gemma.h" #include "util/args.h" +#include "hwy/base.h" // HWY_ASSERT namespace gcpp {