Move code to gemma/ so we can remove error-prone copybara: comments.

Also fix includes and Lint warnings.

PiperOrigin-RevId: 623127487
This commit is contained in:
Jan Wassenberg 2024-04-09 04:45:02 -07:00 committed by Copybara-Service
parent 83dd08ac87
commit a982ec1287
25 changed files with 424 additions and 133 deletions

View File

@ -22,9 +22,7 @@ exports_files(["LICENSE"])
cc_library(
name = "ops",
hdrs = [
"ops.h",
],
hdrs = ["gemma/ops.h"],
deps = [
"//compression:compress",
"@hwy//:algo",
@ -41,7 +39,7 @@ cc_library(
cc_test(
name = "ops_test",
size = "small",
srcs = ["ops_test.cc"],
srcs = ["gemma/ops_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
@ -55,9 +53,7 @@ cc_test(
cc_library(
name = "args",
hdrs = [
"util/args.h",
],
hdrs = ["util/args.h"],
deps = [
"@hwy//:hwy",
],
@ -66,11 +62,11 @@ cc_library(
cc_library(
name = "gemma_lib",
srcs = [
"gemma.cc",
"gemma/gemma.cc",
],
hdrs = [
"configs.h",
"gemma.h",
"gemma/configs.h",
"gemma/gemma.h",
],
deps = [
":args",
@ -88,7 +84,7 @@ cc_library(
cc_test(
name = "gemma_test",
srcs = ["gemma_test.cc"],
srcs = ["gemma/gemma_test.cc"],
# Requires model files
tags = [
"local",
@ -107,9 +103,7 @@ cc_test(
cc_library(
name = "app",
hdrs = [
"util/app.h",
],
hdrs = ["util/app.h"],
deps = [
":args",
":gemma_lib",
@ -119,9 +113,7 @@ cc_library(
cc_binary(
name = "gemma",
srcs = [
"run.cc",
],
srcs = ["gemma/run.cc"],
deps = [
":app",
":args",
@ -137,9 +129,7 @@ cc_binary(
cc_binary(
name = "compress_weights",
srcs = [
"compress_weights.cc",
],
srcs = ["gemma/compress_weights.cc"],
deps = [
":args",
":gemma_lib",
@ -154,9 +144,7 @@ cc_binary(
cc_binary(
name = "benchmark",
srcs = [
"benchmark.cc",
],
srcs = ["gemma/benchmark.cc"],
deps = [
":app",
":args",

View File

@ -34,7 +34,6 @@ FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GI
FetchContent_MakeAvailable(json)
set(SOURCES
gemma.cc
compression/blob_store.cc
compression/blob_store.h
compression/compress.h
@ -44,6 +43,10 @@ set(SOURCES
compression/sfp.h
compression/sfp-inl.h
compression/test_util.h
gemma/configs.h
gemma/gemma.cc
gemma/gemma.h
gemma/ops.h
util/app.h
util/args.h
)
@ -79,10 +82,10 @@ target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated
# Executable Target
add_executable(gemma run.cc)
add_executable(gemma gemma/run.cc)
target_link_libraries(gemma libgemma hwy hwy_contrib)
add_executable(benchmark benchmark.cc)
add_executable(benchmark gemma/benchmark.cc)
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
## Tests
@ -90,8 +93,8 @@ set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
if (GEMMA_ENABLE_TESTS)
set(GEMMA_TEST_FILES
ops_test.cc
gemma_test.cc
gemma/ops_test.cc
gemma/gemma_test.cc
)
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
@ -112,5 +115,5 @@ endif() # GEMMA_ENABLE_TESTS
## Tools
add_executable(compress_weights compress_weights.cc)
add_executable(compress_weights gemma/compress_weights.cc)
target_link_libraries(compress_weights libgemma hwy hwy_contrib)

View File

@ -26,11 +26,8 @@
#include <cstdlib> // std::abs
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -46,9 +43,7 @@
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"

View File

@ -26,7 +26,6 @@
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
#include <fcntl.h> // open

View File

@ -23,11 +23,8 @@
#include <array>
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
@ -44,9 +41,7 @@
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
#include "hwy/highway.h"

View File

@ -27,19 +27,14 @@
#include <vector>
// IWYU pragma: begin_exports
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
// IWYU pragma: end_exports
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
#if COMPRESS_STATS
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#endif

View File

@ -20,9 +20,7 @@
#include <stddef.h>
#include <stdint.h>
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h"
@ -37,7 +35,6 @@
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"

View File

@ -35,11 +35,8 @@
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Other headers that include Highway must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/test_util.h"
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"

View File

@ -20,7 +20,6 @@
#include <stddef.h>
#include <stdint.h>
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h"

View File

@ -18,7 +18,6 @@
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include <stddef.h>
@ -37,9 +36,7 @@
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Any highway.h must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/test_util.h"
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"

View File

@ -13,7 +13,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include <stdio.h>

View File

@ -24,9 +24,7 @@
#include "hwy/base.h"
// IWYU pragma: begin_exports
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include "hwy/tests/test_util.h" // RandomState
// IWYU pragma: end_exports

View File

@ -15,12 +15,9 @@
#include <iostream>
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
// copybara:import_next_line:gemma_cpp
#include "third_party/gemma_cpp/gemma.h"
#include "util/app.h" // LoaderArgs
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
std::vector<int> tokenize(const std::string& prompt_string,

View File

@ -8,16 +8,13 @@
#include <vector>
#include "nlohmann/json.hpp"
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
using json = nlohmann::json;

137
gemma/benchmarks.cc Normal file
View File

@ -0,0 +1,137 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include "third_party/benchmark/include/benchmark/benchmark.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
gcpp::LoaderArgs* loader = nullptr;
gcpp::InferenceArgs* inference = nullptr;
gcpp::Gemma* model = nullptr;
hwy::ThreadPool* pool = nullptr;
hwy::ThreadPool* inner_pool = nullptr;
void run_gemma_prompt(const std::string& prompt_string,
benchmark::State& state) {
std::mt19937 gen;
std::vector<int> prompt;
if (prompt_string.empty()) return;
HWY_ASSERT(model->Tokenizer().Encode(prompt_string, &prompt).ok());
int token_counter = 0;
auto stream_token = [&token_counter](int, float) {
token_counter++;
return true;
};
for (auto s : state) {
GenerateGemma(
*model, *inference, prompt, /*start_token=*/0, *pool, *inner_pool,
stream_token,
/*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
}
state.SetItemsProcessed(token_counter);
}
static void BM_short_prompt(benchmark::State& state) {
run_gemma_prompt("What is the capital of Spain?<ctrl23> ", state);
}
static void BM_factuality_prompt(benchmark::State& state) {
run_gemma_prompt("How does an inkjet printer work?<ctrl23> ", state);
}
static void BM_creative_prompt(benchmark::State& state) {
run_gemma_prompt(
"Tell me a story about a magical bunny and their TRS-80.<ctrl23> ",
state);
}
static void BM_coding_prompt(benchmark::State& state) {
run_gemma_prompt(
"Write a python program to generate a fibonacci sequence.<ctrl23> ",
state);
}
static void BM_long_coding_prompt(benchmark::State& state) {
std::ifstream t("benchmarks.cc", std::ios_base::in);
std::stringstream buffer;
buffer << t.rdbuf();
std::string prompt_string = buffer.str();
t.close();
run_gemma_prompt("Make improvements to the following code:\n " +
prompt_string + "<ctrl23> ",
state);
}
int main(int argc, char** argv) {
loader = new gcpp::LoaderArgs(argc, argv);
inference = new gcpp::InferenceArgs(argc, argv);
gcpp::AppArgs app(argc, argv);
pool = new ::hwy::ThreadPool(app.num_threads);
inner_pool = new ::hwy::ThreadPool(0);
model = new gcpp::Gemma(*loader, *pool);
inference->max_tokens = 128;
BENCHMARK(BM_short_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
inference->max_tokens = 256;
BENCHMARK(BM_factuality_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_creative_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
inference->max_tokens = 1024;
BENCHMARK(BM_long_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
::benchmark ::RunSpecifiedBenchmarks();
::benchmark ::Shutdown();
delete loader;
delete inference;
delete model;
delete pool;
return 0;
}

View File

@ -18,12 +18,8 @@
#include <iostream>
#include <string>
// copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "gemma/gemma.h" // Gemma
#include "util/args.h"
// copybara:end
namespace gcpp {

View File

@ -15,8 +15,8 @@
// Model configurations
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
// Allow changing pre-allocated kv cache size as a compiler flag
#ifndef GEMMA_MAX_SEQLEN
@ -32,7 +32,6 @@
#include <array>
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h" // hwy::bfloat16_t
@ -164,4 +163,4 @@ struct ConfigGriffin2B {
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_

View File

@ -18,22 +18,18 @@
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors.
// copybara:import_next_line:gemma_cpp
#include "compression/compress-inl.h"
// copybara:import_next_line:gemma_cpp
#include "ops.h"
#include "gemma/ops.h"
#include "util/args.h" // Path
#include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// copybara:end
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
// compile pass, whereas we want this defined in the first.
@ -53,21 +49,16 @@
#include <iostream>
#include <memory>
#include <random>
#include <regex>
#include <regex> // NOLINT
#include <string>
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "configs.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// Setting this to true disables fread() calls that read the model file.
constexpr bool kDryRunFread = false;

View File

@ -13,25 +13,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#include <functional>
#include <memory>
#include <random>
#include <string>
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream
// copybara:import_next_line:gemma_cpp
#include "configs.h"
#include "gemma/configs.h"
#include "util/args.h" // Path
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
namespace gcpp {
@ -71,6 +66,7 @@ struct GemmaInterface;
class GemmaTokenizer {
public:
virtual ~GemmaTokenizer() = default;
virtual bool Encode(const std::string& input,
std::vector<std::string>* pieces) const = 0;
virtual bool Encode(const std::string& input,
@ -82,7 +78,7 @@ class GemmaTokenizer {
struct Gemma {
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
~Gemma(); // must be defined after the GemmaInterface dtor is defined.
const GemmaTokenizer* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;
};
@ -105,7 +101,7 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
// Convenience function for the common case:
// - Bundle runtime parameters as RuntimeConfig
// - No threadpools within threadpools (inner_pool = dummy)
// - No ThreadPool within ThreadPool (inner_pool = dummy)
// - All tokens accepted
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos,
@ -124,4 +120,4 @@ constexpr int EOS_ID = 1;
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

View File

@ -13,14 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
#include "gemma/gemma.h"
#include <thread>
#include <algorithm>
#include <iostream>
#include <random>
#include <string>
#include <thread> // NOLINT
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "ops.h"
// copybara:import_next_line:gemma_cpp
#include "gemma/ops.h"
#include "util/args.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/test_util-inl.h"
@ -79,7 +81,7 @@ class GemmaTest : public ::testing::Test {
std::cout << "Question " << i + 1 << "\n\n";
std::string response = GemmaReply(kQA[i][0]);
std::cout << response << "\n\n";
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
}

View File

@ -14,8 +14,9 @@
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
#include <stddef.h>
#include <stdint.h>
@ -43,7 +44,7 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_OPS_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
@ -53,7 +54,6 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/compress-inl.h"
#include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/dot/dot-inl.h"

View File

@ -25,14 +25,13 @@
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
#define HWY_TARGET_INCLUDE "gemma/ops_test.cc" //NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
// copybara:import_next_line:gemma_cpp
#include "ops.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {

View File

@ -23,20 +23,16 @@
#include <vector>
// Placeholder for internal header, do not modify.
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
#include "util/args.h" // HasHelp
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/per_target.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp
static constexpr bool kVerboseLogTokens = false;

223
gemma/run_csv.cc Normal file
View File

@ -0,0 +1,223 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Command line text interface to gemma.
#include <stdio.h>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h" // ArgsBase
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "third_party/riegeli/bytes/file_reader.h"
#include "third_party/riegeli/bytes/file_writer.h"
#include "third_party/riegeli/csv/csv_reader.h"
#include "third_party/riegeli/csv/csv_writer.h"
namespace gcpp {
struct CsvArgs : public ArgsBase<CsvArgs> {
CsvArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
Path input_csv;
Path output_csv;
int prompt_column;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(input_csv, "input_csv", Path(),
"When set, prompts will be read from this CSV.");
visitor(output_csv, "output_csv", Path("/tmp/output.csv"),
"When --input_csv is set, prompts will be written to this CSV.");
visitor(prompt_column, "prompt_column", 0, "Prompt column index");
};
};
void FileGemma(gcpp::Gemma& model, InferenceArgs& inference, AppArgs& app,
CsvArgs& csv, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const gcpp::AcceptFunc& accept_token) {
int abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
int prompt_size{};
std::mt19937 gen;
if (inference.deterministic) {
gen.seed(42);
} else {
std::random_device rd;
gen.seed(rd());
}
std::stringstream response_stream;
// callback function invoked for each generated token.
auto stream_token = [&inference, &abs_pos, &current_pos, &gen, &prompt_size,
tokenizer = &model.Tokenizer(),
&response_stream](int token, float) {
++abs_pos;
++current_pos;
if (current_pos < prompt_size) {
// pass
} else if (token == gcpp::EOS_ID) {
if (!inference.multiturn) {
abs_pos = 0;
if (inference.deterministic) {
gen.seed(42);
}
}
// end of stream
} else {
std::string token_text;
HWY_ASSERT(tokenizer->Decode({token}, &token_text).ok());
// +1 since position is incremented above
if (current_pos == prompt_size + 1) {
// first token of response
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
}
if (token_text != "\n")
response_stream << token_text;
else
response_stream << "\\n";
}
return true;
};
riegeli::CsvReader csv_reader(
riegeli::FileReader(csv.input_csv.path),
riegeli::CsvReaderBase::Options().set_comment('#').set_recovery(
[](absl::Status status, riegeli::CsvReaderBase& csv_reader) {
fprintf(stderr, "Invalid entry: %s", status.message().data());
return true;
}));
riegeli::CsvWriter csv_writer(
riegeli::FileWriter(csv.output_csv.path),
riegeli::CsvWriterBase::Options().set_header({"prompt", "response"}));
if (!csv_reader.ok()) {
HWY_ABORT("Invalid input CSV path %s", csv.input_csv.path.c_str());
}
if (!csv_writer.ok()) {
HWY_ABORT("Invalid output CSV path %s", csv.output_csv.path.c_str());
}
while (abs_pos < inference.max_tokens) {
std::string prompt_string;
std::vector<int> prompt;
current_pos = 0;
std::vector<std::string> record;
csv_reader.ReadRecord(record);
if (record.empty()) {
break;
}
prompt_string = record[csv.prompt_column];
fprintf(stdout, "Prompt: %s\n", prompt_string.c_str());
prompt_string =
"<ctrl99>user\n" + prompt_string + "<ctrl100>\n<ctrl99>model\n";
if (abs_pos > 0) {
// multi-turn dialogue continuation.
prompt_string = "<ctrl100>\n" + prompt_string;
} else {
HWY_DASSERT(abs_pos == 0);
if (gcpp::kSystemPrompt) {
prompt_string =
"<ctrl99>system\nYou are a large language model built by "
"Google.<ctrl100>\n" +
prompt_string;
}
}
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
prompt_size = prompt.size();
// generate prompt
GenerateGemma(model, inference, prompt, abs_pos, pool, inner_pool,
stream_token, accept_token, gen, app.verbosity);
std::string response_string = response_stream.str();
if (!csv_writer.WriteRecord({record[csv.prompt_column], response_string})) {
fprintf(stderr, "Failed to write CSV: %s\n",
csv_writer.status().message().data());
}
response_stream.str(std::string()); // reset stream
response_stream.clear();
abs_pos = 0;
}
if (!csv_reader.Close()) {
fprintf(stderr, "Failed to close the CSV reader\n");
}
if (!csv_writer.Close()) {
fprintf(stderr, "Failed to close the CSV writer\n");
}
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
CsvArgs& csv) {
PROFILER_ZONE("Run.misc");
hwy::ThreadPool inner_pool(0);
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps.
if (app.num_threads > 10) {
pool.Run(0, pool.NumThreads(),
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
}
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
loader.ModelType(), loader.ModelTraining(), pool);
if (csv.input_csv.path.empty()) {
HWY_ABORT("Need to specify csv file.");
}
FileGemma(model, inference, app, csv, pool, inner_pool,
[](int) { return true; });
}
} // namespace gcpp
int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.misc");
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
gcpp::CsvArgs csv(argc, argv);
if (const char* error = loader.Validate()) {
loader.Help();
HWY_ABORT("Invalid args: %s", error);
}
gcpp::Run(loader, inference, app, csv);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;
}

View File

@ -18,7 +18,6 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#include <iterator>
#if HWY_OS_LINUX
#include <sched.h>
@ -32,13 +31,10 @@
#include <algorithm> // std::clamp
#include <thread> // NOLINT>
// copybara:import_next_line:gemma_cpp
#include "configs.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
#include "hwy/base.h" // HWY_ASSERT
// copybara:import_next_line:gemma_cpp
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "util/args.h"
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp {