mirror of https://github.com/google/gemma.cpp.git
Move code to gemma/ so we can remove error-prone copybara: comments.
Also fix includes and Lint warnings. PiperOrigin-RevId: 623127487
This commit is contained in:
parent
83dd08ac87
commit
a982ec1287
34
BUILD.bazel
34
BUILD.bazel
|
|
@ -22,9 +22,7 @@ exports_files(["LICENSE"])
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "ops",
|
name = "ops",
|
||||||
hdrs = [
|
hdrs = ["gemma/ops.h"],
|
||||||
"ops.h",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:algo",
|
"@hwy//:algo",
|
||||||
|
|
@ -41,7 +39,7 @@ cc_library(
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "ops_test",
|
name = "ops_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["ops_test.cc"],
|
srcs = ["gemma/ops_test.cc"],
|
||||||
local_defines = ["HWY_IS_TEST"],
|
local_defines = ["HWY_IS_TEST"],
|
||||||
# for test_suite.
|
# for test_suite.
|
||||||
tags = ["hwy_ops_test"],
|
tags = ["hwy_ops_test"],
|
||||||
|
|
@ -55,9 +53,7 @@ cc_test(
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "args",
|
name = "args",
|
||||||
hdrs = [
|
hdrs = ["util/args.h"],
|
||||||
"util/args.h",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
],
|
],
|
||||||
|
|
@ -66,11 +62,11 @@ cc_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gemma_lib",
|
name = "gemma_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
"gemma.cc",
|
"gemma/gemma.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"configs.h",
|
"gemma/configs.h",
|
||||||
"gemma.h",
|
"gemma/gemma.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":args",
|
":args",
|
||||||
|
|
@ -88,7 +84,7 @@ cc_library(
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "gemma_test",
|
name = "gemma_test",
|
||||||
srcs = ["gemma_test.cc"],
|
srcs = ["gemma/gemma_test.cc"],
|
||||||
# Requires model files
|
# Requires model files
|
||||||
tags = [
|
tags = [
|
||||||
"local",
|
"local",
|
||||||
|
|
@ -107,9 +103,7 @@ cc_test(
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "app",
|
name = "app",
|
||||||
hdrs = [
|
hdrs = ["util/app.h"],
|
||||||
"util/app.h",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":args",
|
":args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
|
|
@ -119,9 +113,7 @@ cc_library(
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "gemma",
|
name = "gemma",
|
||||||
srcs = [
|
srcs = ["gemma/run.cc"],
|
||||||
"run.cc",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":app",
|
":app",
|
||||||
":args",
|
":args",
|
||||||
|
|
@ -137,9 +129,7 @@ cc_binary(
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "compress_weights",
|
name = "compress_weights",
|
||||||
srcs = [
|
srcs = ["gemma/compress_weights.cc"],
|
||||||
"compress_weights.cc",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":args",
|
":args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
|
|
@ -154,9 +144,7 @@ cc_binary(
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "benchmark",
|
name = "benchmark",
|
||||||
srcs = [
|
srcs = ["gemma/benchmark.cc"],
|
||||||
"benchmark.cc",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":app",
|
":app",
|
||||||
":args",
|
":args",
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,6 @@ FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GI
|
||||||
FetchContent_MakeAvailable(json)
|
FetchContent_MakeAvailable(json)
|
||||||
|
|
||||||
set(SOURCES
|
set(SOURCES
|
||||||
gemma.cc
|
|
||||||
compression/blob_store.cc
|
compression/blob_store.cc
|
||||||
compression/blob_store.h
|
compression/blob_store.h
|
||||||
compression/compress.h
|
compression/compress.h
|
||||||
|
|
@ -44,6 +43,10 @@ set(SOURCES
|
||||||
compression/sfp.h
|
compression/sfp.h
|
||||||
compression/sfp-inl.h
|
compression/sfp-inl.h
|
||||||
compression/test_util.h
|
compression/test_util.h
|
||||||
|
gemma/configs.h
|
||||||
|
gemma/gemma.cc
|
||||||
|
gemma/gemma.h
|
||||||
|
gemma/ops.h
|
||||||
util/app.h
|
util/app.h
|
||||||
util/args.h
|
util/args.h
|
||||||
)
|
)
|
||||||
|
|
@ -79,10 +82,10 @@ target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated
|
||||||
|
|
||||||
# Executable Target
|
# Executable Target
|
||||||
|
|
||||||
add_executable(gemma run.cc)
|
add_executable(gemma gemma/run.cc)
|
||||||
target_link_libraries(gemma libgemma hwy hwy_contrib)
|
target_link_libraries(gemma libgemma hwy hwy_contrib)
|
||||||
|
|
||||||
add_executable(benchmark benchmark.cc)
|
add_executable(benchmark gemma/benchmark.cc)
|
||||||
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||||
|
|
||||||
## Tests
|
## Tests
|
||||||
|
|
@ -90,8 +93,8 @@ set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
|
||||||
if (GEMMA_ENABLE_TESTS)
|
if (GEMMA_ENABLE_TESTS)
|
||||||
|
|
||||||
set(GEMMA_TEST_FILES
|
set(GEMMA_TEST_FILES
|
||||||
ops_test.cc
|
gemma/ops_test.cc
|
||||||
gemma_test.cc
|
gemma/gemma_test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
||||||
|
|
@ -112,5 +115,5 @@ endif() # GEMMA_ENABLE_TESTS
|
||||||
|
|
||||||
## Tools
|
## Tools
|
||||||
|
|
||||||
add_executable(compress_weights compress_weights.cc)
|
add_executable(compress_weights gemma/compress_weights.cc)
|
||||||
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
|
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
|
||||||
|
|
|
||||||
|
|
@ -26,11 +26,8 @@
|
||||||
#include <cstdlib> // std::abs
|
#include <cstdlib> // std::abs
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq.h"
|
#include "compression/nuq.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/stats.h"
|
#include "compression/stats.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -46,9 +43,7 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq-inl.h"
|
#include "compression/nuq-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp-inl.h"
|
#include "compression/sfp-inl.h"
|
||||||
#include "hwy/contrib/sort/vqsort-inl.h"
|
#include "hwy/contrib/sort/vqsort-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@
|
||||||
#undef _FILE_OFFSET_BITS
|
#undef _FILE_OFFSET_BITS
|
||||||
#define _FILE_OFFSET_BITS 64
|
#define _FILE_OFFSET_BITS 64
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
|
|
||||||
#include <fcntl.h> // open
|
#include <fcntl.h> // open
|
||||||
|
|
|
||||||
|
|
@ -23,11 +23,8 @@
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
@ -44,9 +41,7 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq-inl.h"
|
#include "compression/nuq-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp-inl.h"
|
#include "compression/sfp-inl.h"
|
||||||
#include "hwy/contrib/dot/dot-inl.h"
|
#include "hwy/contrib/dot/dot-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
|
||||||
|
|
@ -27,19 +27,14 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq.h"
|
#include "compression/nuq.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#if COMPRESS_STATS
|
#if COMPRESS_STATS
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/stats.h"
|
#include "compression/stats.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,7 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq.h"
|
#include "compression/nuq.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
|
@ -37,7 +35,6 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp-inl.h"
|
#include "compression/sfp-inl.h"
|
||||||
#include "hwy/contrib/sort/vqsort-inl.h"
|
#include "hwy/contrib/sort/vqsort-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
|
||||||
|
|
@ -35,11 +35,8 @@
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
// Other headers that include Highway must come after foreach_target.h
|
// Other headers that include Highway must come after foreach_target.h
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq-inl.h"
|
#include "compression/nuq-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/nuq.h"
|
#include "compression/nuq.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/test_util.h"
|
#include "compression/test_util.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@
|
||||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
@ -37,9 +36,7 @@
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
// Any highway.h must come after foreach_target.h
|
// Any highway.h must come after foreach_target.h
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp-inl.h"
|
#include "compression/sfp-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/test_util.h"
|
#include "compression/test_util.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/stats.h"
|
#include "compression/stats.h"
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,7 @@
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/stats.h"
|
#include "compression/stats.h"
|
||||||
#include "hwy/tests/test_util.h" // RandomState
|
#include "hwy/tests/test_util.h" // RandomState
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
|
|
|
||||||
|
|
@ -15,12 +15,9 @@
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "third_party/gemma_cpp/gemma.h"
|
||||||
#include "gemma.h"
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/app.h" // LoaderArgs
|
#include "util/app.h" // LoaderArgs
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
|
|
||||||
std::vector<int> tokenize(const std::string& prompt_string,
|
std::vector<int> tokenize(const std::string& prompt_string,
|
||||||
|
|
|
||||||
|
|
@ -8,16 +8,13 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/gemma.h"
|
||||||
#include "gemma.h"
|
#include "util/app.h"
|
||||||
|
#include "util/args.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/app.h"
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
|
||||||
|
|
@ -0,0 +1,137 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <random>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "third_party/benchmark/include/benchmark/benchmark.h"
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "util/app.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
// Process-wide benchmark state. All five objects are allocated in main()
// before the benchmarks run; run_gemma_prompt() dereferences `model`,
// `inference`, `pool` and `inner_pool` on every call, and `loader` is only
// used in main() to construct the model. They are globals because the
// BM_* benchmark functions only receive a benchmark::State&.
gcpp::LoaderArgs* loader = nullptr;
|
||||||
|
gcpp::InferenceArgs* inference = nullptr;
|
||||||
|
gcpp::Gemma* model = nullptr;
|
||||||
|
hwy::ThreadPool* pool = nullptr;
|
||||||
|
hwy::ThreadPool* inner_pool = nullptr;
|
||||||
|
|
||||||
|
void run_gemma_prompt(const std::string& prompt_string,
|
||||||
|
benchmark::State& state) {
|
||||||
|
std::mt19937 gen;
|
||||||
|
std::vector<int> prompt;
|
||||||
|
|
||||||
|
if (prompt_string.empty()) return;
|
||||||
|
HWY_ASSERT(model->Tokenizer().Encode(prompt_string, &prompt).ok());
|
||||||
|
|
||||||
|
int token_counter = 0;
|
||||||
|
auto stream_token = [&token_counter](int, float) {
|
||||||
|
token_counter++;
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto s : state) {
|
||||||
|
GenerateGemma(
|
||||||
|
*model, *inference, prompt, /*start_token=*/0, *pool, *inner_pool,
|
||||||
|
stream_token,
|
||||||
|
/*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
|
||||||
|
}
|
||||||
|
|
||||||
|
state.SetItemsProcessed(token_counter);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark: a short factual question.
static void BM_short_prompt(benchmark::State& state) {
  const std::string prompt = "What is the capital of Spain?<ctrl23> ";
  run_gemma_prompt(prompt, state);
}
|
||||||
|
|
||||||
|
// Benchmark: an explanatory, fact-oriented question.
static void BM_factuality_prompt(benchmark::State& state) {
  const std::string prompt = "How does an inkjet printer work?<ctrl23> ";
  run_gemma_prompt(prompt, state);
}
|
||||||
|
|
||||||
|
// Benchmark: an open-ended creative-writing request.
static void BM_creative_prompt(benchmark::State& state) {
  const std::string prompt =
      "Tell me a story about a magical bunny and their TRS-80.<ctrl23> ";
  run_gemma_prompt(prompt, state);
}
|
||||||
|
|
||||||
|
// Benchmark: a short code-generation request.
static void BM_coding_prompt(benchmark::State& state) {
  const std::string prompt =
      "Write a python program to generate a fibonacci sequence.<ctrl23> ";
  run_gemma_prompt(prompt, state);
}
|
||||||
|
|
||||||
|
// Benchmark: a long code-improvement request.
//
// The prompt embeds the full text of "benchmarks.cc", opened relative to the
// current working directory. If the file cannot be opened the benchmark is
// skipped with an error; previously the failure went unnoticed and the run
// silently measured a prompt with no code in it.
static void BM_long_coding_prompt(benchmark::State& state) {
  std::ifstream t("benchmarks.cc", std::ios_base::in);
  if (!t.is_open()) {
    state.SkipWithError("could not open benchmarks.cc");
    return;
  }

  // Slurp the whole file into a string.
  std::stringstream buffer;
  buffer << t.rdbuf();
  std::string prompt_string = buffer.str();
  t.close();

  run_gemma_prompt("Make improvements to the following code:\n " +
                       prompt_string + "<ctrl23> ",
                   state);
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
loader = new gcpp::LoaderArgs(argc, argv);
|
||||||
|
inference = new gcpp::InferenceArgs(argc, argv);
|
||||||
|
gcpp::AppArgs app(argc, argv);
|
||||||
|
|
||||||
|
pool = new ::hwy::ThreadPool(app.num_threads);
|
||||||
|
inner_pool = new ::hwy::ThreadPool(0);
|
||||||
|
model = new gcpp::Gemma(*loader, *pool);
|
||||||
|
|
||||||
|
inference->max_tokens = 128;
|
||||||
|
BENCHMARK(BM_short_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
inference->max_tokens = 256;
|
||||||
|
BENCHMARK(BM_factuality_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
BENCHMARK(BM_creative_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
BENCHMARK(BM_coding_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
inference->max_tokens = 1024;
|
||||||
|
BENCHMARK(BM_long_coding_prompt)
|
||||||
|
->Iterations(3)
|
||||||
|
->Unit(benchmark::kMillisecond)
|
||||||
|
->UseRealTime();
|
||||||
|
|
||||||
|
::benchmark ::RunSpecifiedBenchmarks();
|
||||||
|
::benchmark ::Shutdown();
|
||||||
|
|
||||||
|
delete loader;
|
||||||
|
delete inference;
|
||||||
|
delete model;
|
||||||
|
delete pool;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
@ -18,12 +18,8 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/gemma.h" // Gemma
|
||||||
#include "gemma.h" // Gemma
|
|
||||||
// copybara:end
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
// copybara:end
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -15,8 +15,8 @@
|
||||||
|
|
||||||
// Model configurations
|
// Model configurations
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||||
|
|
||||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
// Allow changing pre-allocated kv cache size as a compiler flag
|
||||||
#ifndef GEMMA_MAX_SEQLEN
|
#ifndef GEMMA_MAX_SEQLEN
|
||||||
|
|
@ -32,7 +32,6 @@
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
|
|
||||||
|
|
@ -164,4 +163,4 @@ struct ConfigGriffin2B {
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||||
|
|
@ -18,22 +18,18 @@
|
||||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||||
// which we pass the filename via macro 'argument'.
|
// which we pass the filename via macro 'argument'.
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
|
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
// Must come after foreach_target.h to avoid redefinition errors.
|
// Must come after foreach_target.h to avoid redefinition errors.
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/ops.h"
|
||||||
#include "ops.h"
|
#include "util/args.h" // Path
|
||||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h" // Path
|
|
||||||
// copybara:import_next_line:sentencepiece
|
// copybara:import_next_line:sentencepiece
|
||||||
#include "src/sentencepiece_processor.h"
|
#include "src/sentencepiece_processor.h"
|
||||||
// copybara:end
|
|
||||||
|
|
||||||
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
||||||
// compile pass, whereas we want this defined in the first.
|
// compile pass, whereas we want this defined in the first.
|
||||||
|
|
@ -53,21 +49,16 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <regex>
|
#include <regex> // NOLINT
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/configs.h"
|
||||||
#include "configs.h"
|
#include "gemma/gemma.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "gemma.h"
|
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
// copybara:import_next_line:sentencepiece
|
|
||||||
#include "src/sentencepiece_processor.h"
|
|
||||||
|
|
||||||
// Setting this to true disables fread() calls that read the model file.
|
// Setting this to true disables fread() calls that read the model file.
|
||||||
constexpr bool kDryRunFread = false;
|
constexpr bool kDryRunFread = false;
|
||||||
|
|
@ -13,25 +13,20 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/configs.h"
|
||||||
#include "compression/compress.h" // SfpStream/NuqStream
|
#include "util/args.h" // Path
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "configs.h"
|
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h" // Path
|
|
||||||
// copybara:import_next_line:sentencepiece
|
|
||||||
#include "src/sentencepiece_processor.h"
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -71,6 +66,7 @@ struct GemmaInterface;
|
||||||
|
|
||||||
class GemmaTokenizer {
|
class GemmaTokenizer {
|
||||||
public:
|
public:
|
||||||
|
virtual ~GemmaTokenizer() = default;
|
||||||
virtual bool Encode(const std::string& input,
|
virtual bool Encode(const std::string& input,
|
||||||
std::vector<std::string>* pieces) const = 0;
|
std::vector<std::string>* pieces) const = 0;
|
||||||
virtual bool Encode(const std::string& input,
|
virtual bool Encode(const std::string& input,
|
||||||
|
|
@ -82,7 +78,7 @@ class GemmaTokenizer {
|
||||||
struct Gemma {
|
struct Gemma {
|
||||||
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||||
hwy::ThreadPool& pool);
|
hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after the GemmaInterface dtor is defined.
|
||||||
const GemmaTokenizer* Tokenizer() const;
|
const GemmaTokenizer* Tokenizer() const;
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
};
|
};
|
||||||
|
|
@ -105,7 +101,7 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||||
|
|
||||||
// Convenience function for the common case:
|
// Convenience function for the common case:
|
||||||
// - Bundle runtime parameters as RuntimeConfig
|
// - Bundle runtime parameters as RuntimeConfig
|
||||||
// - No threadpools within threadpools (inner_pool = dummy)
|
// - No ThreadPool within ThreadPool (inner_pool = dummy)
|
||||||
// - All tokens accepted
|
// - All tokens accepted
|
||||||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
|
|
@ -124,4 +120,4 @@ constexpr int EOS_ID = 1;
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||||
|
|
@ -13,14 +13,16 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/gemma.h"
|
||||||
#include "gemma.h"
|
|
||||||
|
|
||||||
#include <thread>
|
#include <algorithm>
|
||||||
|
#include <iostream>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <thread> // NOLINT
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/ops.h"
|
||||||
#include "ops.h"
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
|
@ -79,7 +81,7 @@ class GemmaTest : public ::testing::Test {
|
||||||
std::cout << "Question " << i + 1 << "\n\n";
|
std::cout << "Question " << i + 1 << "\n\n";
|
||||||
std::string response = GemmaReply(kQA[i][0]);
|
std::string response = GemmaReply(kQA[i][0]);
|
||||||
std::cout << response << "\n\n";
|
std::cout << response << "\n\n";
|
||||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);
|
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -14,8 +14,9 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// Include guard for non-SIMD code.
|
// Include guard for non-SIMD code.
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_OPS_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
|
|
@ -43,7 +44,7 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||||
|
|
||||||
// Include guard for (potentially) SIMD code.
|
// Include guard for (potentially) SIMD code.
|
||||||
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||||
|
|
@ -53,7 +54,6 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||||
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
#include "hwy/contrib/algo/transform-inl.h"
|
#include "hwy/contrib/algo/transform-inl.h"
|
||||||
#include "hwy/contrib/dot/dot-inl.h"
|
#include "hwy/contrib/dot/dot-inl.h"
|
||||||
|
|
@ -25,14 +25,13 @@
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
#define HWY_TARGET_INCLUDE "ops_test.cc" //NOLINT
|
#define HWY_TARGET_INCLUDE "gemma/ops_test.cc" //NOLINT
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
// After highway.h
|
// After highway.h
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/ops.h"
|
||||||
#include "ops.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -23,20 +23,16 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Placeholder for internal header, do not modify.
|
// Placeholder for internal header, do not modify.
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/gemma.h" // Gemma
|
||||||
#include "gemma.h" // Gemma
|
#include "util/app.h"
|
||||||
|
#include "util/args.h" // HasHelp
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/per_target.h"
|
#include "hwy/per_target.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/app.h"
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h" // HasHelp
|
|
||||||
|
|
||||||
static constexpr bool kVerboseLogTokens = false;
|
static constexpr bool kVerboseLogTokens = false;
|
||||||
|
|
||||||
|
|
@ -0,0 +1,223 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Command line text interface to gemma.
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <random>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "gemma/configs.h"
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "util/app.h"
|
||||||
|
#include "util/args.h" // ArgsBase
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
#include "third_party/riegeli/bytes/file_reader.h"
|
||||||
|
#include "third_party/riegeli/bytes/file_writer.h"
|
||||||
|
#include "third_party/riegeli/csv/csv_reader.h"
|
||||||
|
#include "third_party/riegeli/csv/csv_writer.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Command-line flags for the CSV front end: which CSV file to read prompts
// from, which CSV file to write results to, and which column holds the prompt.
// The constructor hands argc/argv to InitAndParse (presumably provided by
// ArgsBase), which is expected to call ForEach to discover the flags.
struct CsvArgs : ArgsBase<CsvArgs> {
  CsvArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  // Presents each flag to `visitor`. Judging by the call shape the arguments
  // are (member to fill, flag name, default value, help text).
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(input_csv, "input_csv", Path(),
            "When set, prompts will be read from this CSV.");
    visitor(output_csv, "output_csv", Path("/tmp/output.csv"),
            "When --input_csv is set, prompts will be written to this CSV.");
    visitor(prompt_column, "prompt_column", 0, "Prompt column index");
  }

  Path input_csv;     // Source of prompts; default-constructed Path by default.
  Path output_csv;    // Destination of results; defaults to /tmp/output.csv.
  int prompt_column;  // Index of the prompt field in a record; default 0.
};
|
||||||
|
|
||||||
|
// Batch front end: reads one prompt per record from the CSV file at
// `csv.input_csv`, runs GenerateGemma on it, and appends a {prompt, response}
// row to the CSV file at `csv.output_csv`.
//
// model        - model used for tokenization and generation.
// inference    - generation settings; `deterministic`, `multiturn` and
//                `max_tokens` are read here.
// app          - only `verbosity` is read here.
// csv          - input/output paths and the index of the prompt column.
// pool,
// inner_pool   - thread pools forwarded unchanged to GenerateGemma.
// accept_token - token filter forwarded unchanged to GenerateGemma.
//
// Aborts if the prompt column is negative or either file cannot be opened.
// Records that are too short to contain the prompt column are reported on
// stderr and skipped.
void FileGemma(gcpp::Gemma& model, InferenceArgs& inference, AppArgs& app,
               CsvArgs& csv, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
               const gcpp::AcceptFunc& accept_token) {
  // Reject a negative column up front; it would otherwise be used as an
  // index into every record.
  if (csv.prompt_column < 0) {
    HWY_ABORT("Invalid prompt column %d", csv.prompt_column);
  }
  const size_t prompt_column = static_cast<size_t>(csv.prompt_column);

  int abs_pos = 0;      // absolute token index over all turns
  int current_pos = 0;  // token index within the current turn
  int prompt_size{};

  std::mt19937 gen;
  if (inference.deterministic) {
    gen.seed(42);
  } else {
    std::random_device rd;
    gen.seed(rd());
  }

  std::stringstream response_stream;

  // Callback invoked for each token; accumulates the decoded response text
  // into `response_stream`. Always returns true.
  auto stream_token = [&inference, &abs_pos, &current_pos, &gen, &prompt_size,
                       tokenizer = &model.Tokenizer(),
                       &response_stream](int token, float) {
    ++abs_pos;
    ++current_pos;
    if (current_pos < prompt_size) {
      // Still within the prompt: nothing to record.
    } else if (token == gcpp::EOS_ID) {
      // End of stream.
      if (!inference.multiturn) {
        abs_pos = 0;
        if (inference.deterministic) {
          gen.seed(42);
        }
      }
    } else {
      std::string token_text;
      HWY_ASSERT(tokenizer->Decode({token}, &token_text).ok());
      // +1 since position is incremented above
      if (current_pos == prompt_size + 1) {
        // First token of the response: strip leading whitespace.
        token_text.erase(0, token_text.find_first_not_of(" \t\n"));
      }
      // A token that is exactly a newline is written as the two characters
      // backslash + 'n'.
      if (token_text != "\n") {
        response_stream << token_text;
      } else {
        response_stream << "\\n";
      }
    }
    return true;
  };

  riegeli::CsvReader csv_reader(
      riegeli::FileReader(csv.input_csv.path),
      riegeli::CsvReaderBase::Options().set_comment('#').set_recovery(
          [](absl::Status status, riegeli::CsvReaderBase& /*reader*/) {
            // message() is a string view and need not be NUL-terminated, so
            // print it with an explicit length.
            const auto message = status.message();
            fprintf(stderr, "Invalid entry: %.*s\n",
                    static_cast<int>(message.size()), message.data());
            return true;
          }));

  // Validate the input before constructing the writer, so that a bad input
  // path is reported before the output file is opened for writing.
  if (!csv_reader.ok()) {
    HWY_ABORT("Invalid input CSV path %s", csv.input_csv.path.c_str());
  }

  riegeli::CsvWriter csv_writer(
      riegeli::FileWriter(csv.output_csv.path),
      riegeli::CsvWriterBase::Options().set_header({"prompt", "response"}));

  if (!csv_writer.ok()) {
    HWY_ABORT("Invalid output CSV path %s", csv.output_csv.path.c_str());
  }

  // NOTE: abs_pos is reset to 0 at the end of every iteration, so each record
  // starts a fresh conversation and this condition only guards against
  // max_tokens <= 0.
  while (abs_pos < inference.max_tokens) {
    std::string prompt_string;
    std::vector<int> prompt;
    current_pos = 0;

    std::vector<std::string> record;
    // Stop when no further record can be read; the empty() check preserves
    // the original termination condition as well.
    if (!csv_reader.ReadRecord(record) || record.empty()) {
      break;
    }

    // A short row cannot supply the prompt: skip it instead of indexing out
    // of bounds.
    if (prompt_column >= record.size()) {
      fprintf(stderr, "Skipping record with %zu fields: no column %d\n",
              record.size(), csv.prompt_column);
      continue;
    }

    const std::string& raw_prompt = record[prompt_column];
    fprintf(stdout, "Prompt: %s\n", raw_prompt.c_str());

    prompt_string =
        "<ctrl99>user\n" + raw_prompt + "<ctrl100>\n<ctrl99>model\n";
    if (abs_pos > 0) {
      // multi-turn dialogue continuation.
      prompt_string = "<ctrl100>\n" + prompt_string;
    } else {
      HWY_DASSERT(abs_pos == 0);
      if (gcpp::kSystemPrompt) {
        prompt_string =
            "<ctrl99>system\nYou are a large language model built by "
            "Google.<ctrl100>\n" +
            prompt_string;
      }
    }
    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
    prompt_size = static_cast<int>(prompt.size());

    // Run the model; stream_token fills response_stream as tokens arrive.
    GenerateGemma(model, inference, prompt, abs_pos, pool, inner_pool,
                  stream_token, accept_token, gen, app.verbosity);

    const std::string response_string = response_stream.str();
    if (!csv_writer.WriteRecord({raw_prompt, response_string})) {
      const auto message = csv_writer.status().message();
      fprintf(stderr, "Failed to write CSV: %.*s\n",
              static_cast<int>(message.size()), message.data());
    }

    response_stream.str(std::string());  // reset stream
    response_stream.clear();
    abs_pos = 0;
  }

  if (!csv_reader.Close()) {
    fprintf(stderr, "Failed to close the CSV reader\n");
  }
  if (!csv_writer.Close()) {
    fprintf(stderr, "Failed to close the CSV writer\n");
  }
}
|
||||||
|
|
||||||
|
// Creates the thread pools, constructs the model from `loader`, and processes
// the CSV file named by `csv` via FileGemma, accepting every token.
//
// Aborts if no input CSV path was given. That check runs first, so a missing
// flag is reported before the thread pools and the model are constructed.
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
         CsvArgs& csv) {
  PROFILER_ZONE("Run.misc");

  // Fail fast on a missing argument.
  if (csv.input_csv.path.empty()) {
    HWY_ABORT("Need to specify csv file.");
  }

  hwy::ThreadPool inner_pool(0);
  hwy::ThreadPool pool(app.num_threads);
  // For many-core, pinning threads to cores helps.
  if (app.num_threads > 10) {
    pool.Run(0, pool.NumThreads(),
             [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
  }

  gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
                    loader.ModelType(), loader.ModelTraining(), pool);

  FileGemma(model, inference, app, csv, pool, inner_pool,
            [](int) { return true; });
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
{
|
||||||
|
PROFILER_ZONE("Startup.misc");
|
||||||
|
gcpp::LoaderArgs loader(argc, argv);
|
||||||
|
gcpp::InferenceArgs inference(argc, argv);
|
||||||
|
gcpp::AppArgs app(argc, argv);
|
||||||
|
gcpp::CsvArgs csv(argc, argv);
|
||||||
|
|
||||||
|
if (const char* error = loader.Validate()) {
|
||||||
|
loader.Help();
|
||||||
|
HWY_ABORT("Invalid args: %s", error);
|
||||||
|
}
|
||||||
|
|
||||||
|
gcpp::Run(loader, inference, app, csv);
|
||||||
|
}
|
||||||
|
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
10
util/app.h
10
util/app.h
|
|
@ -18,7 +18,6 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||||
|
|
||||||
#include <iterator>
|
|
||||||
#if HWY_OS_LINUX
|
#if HWY_OS_LINUX
|
||||||
#include <sched.h>
|
#include <sched.h>
|
||||||
|
|
||||||
|
|
@ -32,13 +31,10 @@
|
||||||
#include <algorithm> // std::clamp
|
#include <algorithm> // std::clamp
|
||||||
#include <thread> // NOLINT>
|
#include <thread> // NOLINT>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
#include "gemma/configs.h"
|
||||||
#include "configs.h"
|
#include "gemma/gemma.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "gemma.h"
|
|
||||||
#include "hwy/base.h" // HWY_ASSERT
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
|
#include "hwy/base.h" // HWY_ASSERT
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue