diff --git a/.clang-tidy b/.clang-tidy index abcd9d7..497c2e3 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,4 +1,5 @@ FormatStyle: file +WarningsAsErrors: "*" Checks: "-*,\ abseil-*,\ -abseil-string-find-startswith,\ @@ -204,3 +205,6 @@ Checks: "-*,\ -readability-uppercase-literal-suffix,\ -readability-use-anyofallof " +CheckOptions: + - { key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase } + - { key: readability-identifier-naming.ConstexprVariablePrefix, value: k } diff --git a/BUILD.bazel b/BUILD.bazel index 3019030..7f9dfce 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -46,17 +46,6 @@ cc_library( ], ) -cc_library( - name = "app", - hdrs = [ - "util/app.h", - ], - deps = [ - ":args", - "@hwy//:hwy", - ], -) - cc_library( name = "gemma_lib", srcs = [ @@ -69,6 +58,7 @@ cc_library( deps = [ ":args", ":transformer_ops", + # "//base", "//compression:compress", "@hwy//:hwy", "@hwy//:matvec", @@ -79,6 +69,18 @@ cc_library( ], ) +cc_library( + name = "app", + hdrs = [ + "util/app.h", + ], + deps = [ + ":args", + ":gemma_lib", + "@hwy//:hwy", + ], +) + cc_binary( name = "gemma", srcs = [ @@ -88,6 +90,7 @@ cc_binary( ":app", ":args", ":gemma_lib", + # "//base", "//compression:compress", "@hwy//:hwy", "@hwy//:nanobenchmark", diff --git a/compression/blob_store.cc b/compression/blob_store.cc index e088fc6..050dfbd 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -341,7 +341,7 @@ BlobError BlobReader::Open(const char* filename) { #endif if (fd_ < 0) return __LINE__; -#if HWY_OS_LINUX +#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21) // Doubles the readahead window, which seems slightly faster when cached. (void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL); #endif diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt new file mode 100644 index 0000000..292c80c --- /dev/null +++ b/examples/hello_world/CMakeLists.txt @@ -0,0 +1,49 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.11) +project(hello_world) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +include(FetchContent) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) +FetchContent_MakeAvailable(highway) +FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) +FetchContent_MakeAvailable(sentencepiece) + + + +# Allow for both local and remote building) +option(BUILD_MODE "'local' or 'remote' git fetch for builds") +if (NOT BUILD_MODE) + set(BUILD_MODE "remote") +endif() +if (BUILD_MODE STREQUAL "local") + # Relative path to gemma.cpp from examples/hello_world/build/ + FetchContent_Declare(gemma SOURCE_DIR ../../..) +else() + FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e) +endif() +FetchContent_MakeAvailable(gemma) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release") +endif() + +add_executable(hello_world run.cc) +target_link_libraries(hello_world hwy hwy_contrib sentencepiece libgemma) +FetchContent_GetProperties(sentencepiece) +target_include_directories(hello_world PRIVATE ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(hello_world PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) +target_compile_options(hello_world PRIVATE $<$:-Wno-deprecated-declarations>) diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md new file mode 100644 index 0000000..63c319e --- /dev/null +++ b/examples/hello_world/README.md @@ -0,0 +1,51 @@ +# Hello World Example + +This is a minimal/template project for using `gemma.cpp` as a library. Instead +of an interactive interface, it sets up the model state and generates text for a +single hard coded prompt. + +Build steps are similar to the main `gemma` executable. For now only +`cmake`/`make` is available for builds (PRs welcome for other build options). + +First use `cmake` to configure the project, starting from the `hello_world` +example directory (`gemma.cpp/examples/hello_world`): + +```sh +cmake -B build +``` + +This sets up a build configuration in `gemma.cpp/examples/hello_world/build`. +Note that this fetches `libgemma` from a git commit hash on github. +Alternatively if you want to build using the local version of `gemma.cpp` use: + +```sh +cmake -B build -DBUILD_MODE=local +``` + +Make sure you delete the contents of the build directory before changing +configurations. + +Then use `make` to build the project: + +```sh +cd build +make hello_world +``` + +As with the top-level `gemma.cpp` project you can use the `make` commands `-j` +flag to use parallel threads for faster builds. + +From inside the `gemma.cpp/examples/hello_world/build` directory, there should +be a `hello_world` executable. You can run it with the same 3 model arguments as +gemma.cpp specifying the tokenizer, compressed weights file, and model type, for +example: + +```sh +./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it +``` + +Should print a greeting to the terminal: + +``` +"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar. +``` diff --git a/examples/hello_world/build/.gitignore b/examples/hello_world/build/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/examples/hello_world/build/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc new file mode 100644 index 0000000..a994f31 --- /dev/null +++ b/examples/hello_world/run.cc @@ -0,0 +1,86 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +// copybara:import_next_line:gemma_cpp +#include "gemma.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/args.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" // LoaderArgs +// copybara:end +#include "hwy/contrib/thread_pool/thread_pool.h" + +std::vector tokenize( + const std::string& prompt_string, + const sentencepiece::SentencePieceProcessor* tokenizer) { + std::string formatted = "user\n" + prompt_string + + "\nmodel\n"; + std::vector tokens; + HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok()); + tokens.insert(tokens.begin(), 2); // BOS token + return tokens; +} + +int main(int argc, char** argv) { + gcpp::LoaderArgs loader(argc, argv); + + // Rough heuristic for the number of threads to use + size_t num_threads = static_cast(std::clamp( + static_cast(std::thread::hardware_concurrency()) - 2, 1, 18)); + hwy::ThreadPool pool(num_threads); + + // Instantiate model and KV Cache + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); + auto kv_cache = CreateKVCache(loader.ModelType()); + size_t pos = 0; // KV Cache position + + // Initialize random number generator + std::mt19937 gen; + std::random_device rd; + gen.seed(rd()); + + // Tokenize instruction + std::vector tokens = + tokenize("Write a greeting to the world.", model.Tokenizer()); + size_t ntokens = tokens.size(); + + // This callback function gets invoked everytime a token is generated + auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()]( + int token, float) { + ++pos; + if (pos < ntokens) { + // print feedback + } else if (token != gcpp::EOS_ID) { + std::string token_text; + HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); + std::cout << token_text << std::flush; + } + return true; + }; + + GenerateGemma(model, + {.max_tokens = 2048, + .max_generated_tokens = 1024, + .temperature = 1.0, + .verbosity = 0}, + tokens, /*KV cache position = */ 0, kv_cache, pool, + stream_token, gen); + std::cout << std::endl; +} diff --git a/gemma.cc b/gemma.cc index eb43e81..4fe2782 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,8 +25,6 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -231,20 +229,39 @@ struct Activations { struct GemmaInterface { virtual ~GemmaInterface() = default; - virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0; + virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0; - // TODO: group pool/callbacks into struct - virtual void Generate(const InferenceArgs& args, - const std::vector& prompt, size_t start_pos, + virtual void Generate(size_t max_tokens, size_t max_generated_tokens, + float temperature, const std::vector& prompt, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) = 0; }; +template +KVCache CreateKVCache() { + return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim, + Config::kSeqLen); +} + +KVCache CreateKVCache(Model type) { + switch (type) { + case Model::GEMMA_2B: + return CreateKVCache(); + case Model::GEMMA_7B: + return CreateKVCache(); + default: + HWY_ABORT("Model type %d unknown.", static_cast(type)); + } +} + template struct GemmaImpl : public GemmaInterface { - GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool); + GemmaImpl(std::unique_ptr& tokenizer, + hwy::AlignedFreeUniquePtr& compressed_weights, + hwy::ThreadPool& pool); ~GemmaImpl() { using CWeights = CompressedWeights; @@ -252,22 +269,21 @@ struct GemmaImpl : public GemmaInterface { c_weights->c_layer_ptrs.~CompressedLayerPointers(); } - const sentencepiece::SentencePieceProcessor& Tokenizer() const { - return tokenizer; + const sentencepiece::SentencePieceProcessor* Tokenizer() const override { + return tokenizer.get(); } - void Generate(const InferenceArgs& args, const std::vector& prompt, - size_t start_pos, hwy::ThreadPool& pool, + void Generate(size_t max_tokens, size_t max_generated_tokens, + float temperature, const std::vector& prompt, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, - const AcceptFunc& accept_token, std::mt19937&, int verbosity); + const AcceptFunc& accept_token, std::mt19937&, + int verbosity) override; - sentencepiece::SentencePieceProcessor tokenizer; - - // CompressedWeights + std::unique_ptr tokenizer; hwy::AlignedFreeUniquePtr compressed_weights; hwy::AlignedUniquePtr> prefill; hwy::AlignedUniquePtr> state; - KVCache kv_cache; }; } // namespace gcpp @@ -294,7 +310,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, static constexpr size_t kModelDim = gcpp::Activations::kModelDim; static constexpr size_t kHeads = TConfig::kHeads; - const float kQueryScale = 1.0 / sqrtf(static_cast(kQKVDim)); + static const float kQueryScale = + static_cast(1.0 / sqrt(static_cast(kQKVDim))); pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // linear projections to QKV @@ -417,7 +434,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, hwy::ThreadPool& inner_pool) { PROFILER_ZONE("Gen.Prefill\\Att\\FFW"); static constexpr size_t kModelDim = TConfig::kModelDim; - static const float kEmbScaling = sqrtf(static_cast(kModelDim)); + static const float kEmbScaling = + static_cast(sqrt(static_cast(kModelDim))); pool.Run( 0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { @@ -472,7 +490,8 @@ void Transformer(int token, size_t pos, static constexpr size_t kLayers = TConfig::kLayers; static constexpr size_t kModelDim = TConfig::kModelDim; - static const float kEmbScaling = sqrtf(static_cast(kModelDim)); + static const float kEmbScaling = + static_cast(sqrt(static_cast(kModelDim))); Decompress(c_weights.c_embedder_input_embedding, token * kModelDim, activations.x.data(), kModelDim); @@ -495,8 +514,9 @@ void Transformer(int token, size_t pos, } template -void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, - const std::vector& prompt, size_t pos, +void GenerateImpl(GemmaImpl& gemma, size_t max_tokens, + size_t max_generated_tokens, float temperature, + const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, @@ -510,7 +530,6 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, const CompressedWeights& c_weights = *reinterpret_cast*>( gemma.compressed_weights.get()); - KVCache& kv_cache = gemma.kv_cache; int token; // pos indexes the KV cache. In the first turn of a chat, pos = 0. @@ -548,8 +567,9 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, // in the future this output should not occur in GenerateImpl but instead // should be available as observable state for frontend code to handle I/O. const double prefill_end = hwy::platform::Now(); - const double prefill_tok_sec = static_cast(pos_offset) / (prefill_end - prefill_start); - std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n"; + const double prefill_tok_sec = + static_cast(pos_offset) / (prefill_end - prefill_start); + std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]"; } const double gen_start = hwy::platform::Now(); @@ -558,10 +578,10 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, if (verbosity >= 2) { // Provide usage warnings if max_new_tokens is out of range. - if (args.max_generated_tokens > args.max_tokens) { + if (max_generated_tokens > max_tokens) { std::cout << "Warning: max_new_tokens should be <= max_tokens" << std::endl; - } else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) { + } else if ((prompt.size() + max_generated_tokens) > max_tokens) { std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens." << std::endl; } @@ -570,7 +590,7 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, auto pos_gen_start = pos_offset; token = prompt.at(pos_offset); size_t generate_pos = 0; - for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens; + for (; pos < max_tokens && generate_pos < max_generated_tokens; ++pos, ++pos_offset, ++generate_pos) { Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool); float* final_activation = activations.x.data(); @@ -583,7 +603,7 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, // Barrier: must have all logits so we can subtract max. Softmax(activations.logits.data(), kVocabSize); token = SampleTopK(activations.logits.data(), kVocabSize, gen, - args.temperature, accept_token); + temperature, accept_token); } if (!stream_token(token, activations.logits[token])) { token = EOS_ID; @@ -592,7 +612,8 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, if (verbosity >= 2) { const double gen_end = hwy::platform::Now(); const double gen_tok_sec = - static_cast(pos_offset - pos_gen_start) / (gen_end - gen_start); + static_cast(pos_offset - pos_gen_start) / + (gen_end - gen_start); std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n"; } break; @@ -600,21 +621,27 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, } } -void Generate2B(GemmaImpl& gemma, const InferenceArgs& args, +void Generate2B(GemmaImpl& gemma, size_t max_tokens, + size_t max_generated_tokens, float temperature, const std::vector& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { - GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token, + KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) { + GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, + start_pos, kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } -void Generate7B(GemmaImpl& gemma, const InferenceArgs& args, +void Generate7B(GemmaImpl& gemma, size_t max_tokens, + size_t max_generated_tokens, float temperature, const std::vector& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { - GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token, + KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) { + GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, + start_pos, kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } @@ -666,10 +693,10 @@ void ForEachTensor(const Weights* weights, template hwy::AlignedFreeUniquePtr GetCompressedWeights( - const Path& model, const Path& cache, hwy::ThreadPool& pool) { + const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) { PROFILER_ZONE("Startup.LoadCache"); - if (!std::filesystem::exists(model.path) && + if (!std::filesystem::exists(weights_path.path) && !std::filesystem::exists(cache.path)) { HWY_ABORT( "Either the model weights (--weights) or cached compressed weights " @@ -689,7 +716,8 @@ hwy::AlignedFreeUniquePtr GetCompressedWeights( if (loader.ReadAll(pool)) return c_weights_u8; // Get weights, compress, and store in cache. - const hwy::AlignedUniquePtr> weights = LoadWeights(model); + const hwy::AlignedUniquePtr> weights = + LoadWeights(weights_path); Compressor compressor(pool); ForEachTensor(weights.get(), *c_weights, compressor); compressor.WriteAll(pool, cache.path.c_str()); @@ -699,14 +727,17 @@ hwy::AlignedFreeUniquePtr GetCompressedWeights( // Type-erased because this function is called via a function pointer. hwy::AlignedFreeUniquePtr GetCompressedWeightsT( - const LoaderArgs& args, hwy::ThreadPool& pool) { - switch (args.ModelType()) { + gcpp::Model model, const Path& weights, const Path& compressed_weights, + hwy::ThreadPool& pool) { + switch (model) { case Model::GEMMA_2B: - return GetCompressedWeights(args.model, args.cache, pool); + return GetCompressedWeights(weights, compressed_weights, + pool); case Model::GEMMA_7B: - return GetCompressedWeights(args.model, args.cache, pool); + return GetCompressedWeights(weights, compressed_weights, + pool); default: - HWY_ABORT("Model type %d unknown.", static_cast(args.ModelType())); + HWY_ABORT("Model type %d unknown.", static_cast(model)); } } @@ -729,75 +760,99 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) { } template -GemmaImpl::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool) - : compressed_weights( - HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)), +GemmaImpl::GemmaImpl( + std::unique_ptr& tokenizer, + hwy::AlignedFreeUniquePtr& compressed_weights, + hwy::ThreadPool& pool) + : tokenizer(std::move(tokenizer)), + compressed_weights(std::move(compressed_weights)), prefill(hwy::MakeUniqueAligned>()), - state(hwy::MakeUniqueAligned>()), - kv_cache( - CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim, - Config::kSeqLen)) { - PROFILER_ZONE("Startup.tokenizer"); - - HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok()); -} + state(hwy::MakeUniqueAligned>()) {} template <> -void GemmaImpl::Generate(const InferenceArgs& args, - const std::vector& prompt, - size_t start_pos, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, - const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { +void GemmaImpl::Generate( + size_t max_tokens, size_t max_generated_tokens, float temperature, + const std::vector& prompt, size_t start_pos, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { HWY_DYNAMIC_DISPATCH(Generate2B) - (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token, - gen, verbosity); + (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, + kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } template <> -void GemmaImpl::Generate(const InferenceArgs& args, - const std::vector& prompt, - size_t start_pos, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, - const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { +void GemmaImpl::Generate( + size_t max_tokens, size_t max_generated_tokens, float temperature, + const std::vector& prompt, size_t start_pos, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { HWY_DYNAMIC_DISPATCH(Generate7B) - (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token, - gen, verbosity); + (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, + kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } -Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { - const Model model_type = args.ModelType(); - model_training = args.ModelTraining(); +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + const Path& weights_path, Model model_type, + hwy::ThreadPool& pool) { + std::unique_ptr tokenizer; + { + PROFILER_ZONE("Startup.tokenizer"); + tokenizer = std::make_unique(); + if (!tokenizer->Load(tokenizer_path.path).ok()) { + HWY_ABORT("Failed to load the tokenizer file."); + } + } + auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)( + model_type, weights_path, compressed_weights_path, pool); switch (model_type) { case Model::GEMMA_2B: - impl_.reset(new GemmaImpl(args, pool)); + impl_.reset( + new GemmaImpl(tokenizer, compressed_weights, pool)); break; case Model::GEMMA_7B: - impl_.reset(new GemmaImpl(args, pool)); + impl_.reset( + new GemmaImpl(tokenizer, compressed_weights, pool)); break; default: HWY_ABORT("Model type %d unknown.", static_cast(model_type)); } } + +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool) + : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, + pool) {} + Gemma::~Gemma() = default; // after GemmaInterface is defined -const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const { +const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { return impl_->Tokenizer(); } -void GenerateGemma(Gemma& gemma, const InferenceArgs& args, - const std::vector& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, +void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, + float temperature, const std::vector& prompt, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) { pool.SetWaitMode(hwy::PoolWaitMode::kSpin); - gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token, + gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt, + start_pos, kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); pool.SetWaitMode(hwy::PoolWaitMode::kBlock); } +void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + const StreamFunc& stream_token, std::mt19937& gen) { + hwy::ThreadPool inner_pool(0); + GenerateGemma( + gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens, + runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool, + stream_token, [](int) { return true; }, gen, runtime_config.verbosity); +} + } // namespace gcpp #endif // HWY_ONCE diff --git a/gemma.h b/gemma.h index 7195bc9..cdd4873 100644 --- a/gemma.h +++ b/gemma.h @@ -64,147 +64,50 @@ struct KVCache { enum class Model { GEMMA_2B, GEMMA_7B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; -struct LoaderArgs : public ArgsBase { - LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - - static std::string ToLower(const std::string& text) { - std::string result = text; - std::transform(begin(result), end(result), begin(result), - [](unsigned char c) { return std::tolower(c); }); - return result; - } - - gcpp::Model ModelType() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") { - return gcpp::Model::GEMMA_2B; - } else { - return gcpp::Model::GEMMA_7B; - } - } - - gcpp::ModelTraining ModelTraining() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") { - return gcpp::ModelTraining::GEMMA_PT; - } else { - return gcpp::ModelTraining::GEMMA_IT; - } - } - - // Returns error string or nullptr if OK. - const char* Validate() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && - model_type_lc != "2b-it" && model_type_lc != "7b-it") { - return "Model type must be 2b-pt, 7b-pt, 2b-it, or " - "7b-it."; - } - if (tokenizer.path.empty()) { - return "Missing --tokenizer flag, a file for the tokenizer is required."; - } - if (model_type.empty()) { - return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " - "2b-it, or 7b-it."; - } - if (cache.path.empty()) { - return "Missing --compressed_weights flag, a file for the compressed " - "model."; - } - return nullptr; - } - - Path tokenizer; - Path model; // uncompressed weights OR - Path cache; // compressed weights - std::string model_type; - - template - void ForEach(const Visitor& visitor) { - visitor(tokenizer, "tokenizer", Path(), - "Path name of tokenizer model file.\n Required argument."); - visitor( - cache, "compressed_weights", Path(), - "Path name of compressed weights file, regenerated from `--weights` " - "file if " - "the compressed weights file does not exist.\n Required argument."); - visitor(model_type, "model", std::string(), - "Model type\n 2b-it (2B parameters, instruction-tuned)\n " - "2b-pt (2B parameters, pretrained)\n 7b-it (7B parameters " - "instruction-tuned)\n 7b-pt (7B parameters, pretrained)\n" - " Required argument."); - visitor(model, "weights", Path(), - "Path name of model weights (.sbs) file. Only required if " - "compressed_weights file is not present and needs to be " - "regenerated. This parameter is only required for compressing" - "new model weight exports, otherwise it is not needed."); - } +struct RuntimeConfig { + size_t max_tokens; + size_t max_generated_tokens; + float temperature; + int verbosity; }; struct GemmaInterface; struct Gemma { - Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. - - const sentencepiece::SentencePieceProcessor& Tokenizer() const; - + const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; gcpp::ModelTraining model_training; }; +KVCache CreateKVCache(Model type); // convenient workaround for now +KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); + // StreamFunc is called with (token, probability). For prompt tokens, // probability is 0.0f. using StreamFunc = std::function; using AcceptFunc = std::function; -struct InferenceArgs : public ArgsBase { - InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - - size_t max_tokens; - size_t max_generated_tokens; - - float temperature; - bool deterministic; - bool multiturn; - - // Returns error string or nullptr if OK. - const char* Validate() const { - if (max_tokens > gcpp::kSeqLen) { - return "max_tokens is larger than the maximum sequence length (see " - "configs.h)."; - } - if (max_generated_tokens > max_tokens) { - return "Maximum number of generated tokens is larger than the maximum " - "total tokens."; - } - return nullptr; - } - - template - void ForEach(const Visitor& visitor) { - visitor(max_tokens, "max_tokens", size_t{3072}, - "Maximum number of tokens in prompt + generation."); - visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, - "Maximum number of tokens to generate."); - - visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); - visitor(deterministic, "deterministic", false, - "Make top-k sampling deterministic", 2); - visitor(multiturn, "multiturn", false, - "Multiturn mode (if 0, this clears the KV cache after every " - "interaction without quitting)\n Default : 0 (conversation " - "resets every turn)"); - } -}; - -void GenerateGemma(Gemma& gemma, const InferenceArgs& args, - const std::vector& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, - const AcceptFunc& accept_token, std::mt19937& g, +void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, + float temperature, const std::vector& prompt, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, int verbosity); +// Convenience function for the common case: +// - Bundle runtime parameters as RuntimeConfig +// - No threadpools within threadpools (inner_pool = dummy) +// - All tokens accepted +void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + const StreamFunc& stream_token, std::mt19937& gen); + constexpr int EOS_ID = 1; } // namespace gcpp diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/ops.h b/ops.h index 3725776..481e1d7 100644 --- a/ops.h +++ b/ops.h @@ -340,11 +340,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, // = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT. static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( const float* HWY_RESTRICT a, size_t size) { - float total = 0.f; - for (size_t i = 0; i < size; ++i) { - total += a[i] * a[i]; + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + HWY_DASSERT(size >= 2 * N); + HWY_DASSERT(size % (2 * N) == 0); + + auto sum0 = hn::Zero(d); + auto sum1 = hn::Zero(d); + for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { + const auto a0 = LoadU(d, a + i); + sum0 = MulAdd(a0, a0, sum0); + const auto a1 = LoadU(d, a + i + N); + sum1 = MulAdd(a1, a1, sum1); } - return total; + + return ReduceSum(d, Add(sum0, sum1)); } static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( @@ -362,12 +372,30 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT out, size_t size) { - constexpr float eps = 1e-6f; - float ss = SquaredL2(x, size); - ss = 1.0f / sqrtf(ss / StaticCast(size) + eps); - for (size_t j = 0; j < size; j++) { - // Note 1.0f centering here - out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]); + namespace hn = hwy::HWY_NAMESPACE; + + constexpr float kEps = 1e-6f; + constexpr size_t kUnrollSize = 2; + + const hn::ScalableTag dbf; + const hn::Repartition df32; + const size_t N32 = hn::Lanes(df32); + + const float ss = SquaredL2(x, size); + const auto vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); + + HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += kUnrollSize * N32) { + const hn::Vec w16 = hn::LoadU(dbf, weight + i); + const auto w0 = hn::PromoteLowerTo(df32, w16); + const auto w1 = hn::PromoteUpperTo(df32, w16); + const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); + const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); + + // (1+weight) * m = m + weight*m = one FMA. + hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i); + hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32); } } diff --git a/run.cc b/run.cc index 507979d..d6bf22d 100644 --- a/run.cc +++ b/run.cc @@ -66,8 +66,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { << std::thread::hardware_concurrency() << std::endl << "Instruction set : " << hwy::TargetName(hwy::DispatchedTarget()) << " (" - << hwy::VectorBytes() * 8 << " bits)" - << "\n" + << hwy::VectorBytes() * 8 << " bits)" << "\n" + << "Compiled config : " << CompiledConfig() << "\n" << "Weight Type : " << gcpp::TypeName(gcpp::WeightT()) << "\n" << "EmbedderInput Type : " @@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, std::cerr << "\n"; } -void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, const InferenceArgs& args, - int verbosity, const gcpp::AcceptFunc& accept_token, - std::string& eot_line) { +void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const InferenceArgs& args, int verbosity, + const gcpp::AcceptFunc& accept_token, std::string& eot_line) { PROFILER_ZONE("Gen.misc"); int abs_pos = 0; // absolute token index over all turns int current_pos = 0; // token index within the current turn @@ -115,7 +115,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, // callback function invoked for each generated token. auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size, - tokenizer = &model.Tokenizer(), + tokenizer = model.Tokenizer(), verbosity](int token, float) { ++abs_pos; ++current_pos; @@ -129,7 +129,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, } } if (verbosity >= 2) { - std::cout << "\n[ End ]" << std::endl; + std::cout << "\n[ End ]\n"; } } else { std::string token_text; @@ -142,7 +142,6 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, std::cout << std::endl << std::endl; } } - // TODO(austinvhuang): is explicit space necessary? std::cout << token_text << std::flush; } return true; @@ -191,7 +190,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, } } - HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok()); + HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok()); // For both pre-trained and instruction-tuned models: prepend "" token // if needed. @@ -204,8 +203,9 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, std::cerr << std::endl << "[ Reading prompt ] " << std::flush; const double time_start = hwy::platform::Now(); - GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token, - accept_token, gen, verbosity); + GenerateGemma(model, args.max_tokens, args.max_generated_tokens, + args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, + stream_token, accept_token, gen, verbosity); const double time_end = hwy::platform::Now(); const double tok_sec = current_pos / (time_end - time_start); if (verbosity >= 2) { @@ -234,7 +234,10 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader, pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); + + auto kv_cache = CreateKVCache(loader.ModelType()); if (const char* error = inference.Validate()) { ShowHelp(loader, inference, app); @@ -272,7 +275,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } ReplGemma( - model, pool, inner_pool, inference, app.verbosity, + model, kv_cache, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int) { return true; }, app.eot_line); } diff --git a/util/app.h b/util/app.h index 7f926a5..cd6cb6c 100644 --- a/util/app.h +++ b/util/app.h @@ -18,10 +18,13 @@ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ +#include #if HWY_OS_LINUX #include +#include #include // IDE does not recognize errno.h as providing errno. +#include #endif #include #include @@ -29,6 +32,14 @@ #include // std::clamp #include // NOLINT> +// copybara:import_next_line:gemma_cpp +#include "configs.h" +// copybara:end + +// copybara:import_next_line:gemma_cpp +#include "gemma.h" +// copybara:end + // copybara:import_next_line:gemma_cpp #include "util/args.h" // copybara:end @@ -36,6 +47,24 @@ namespace gcpp { +static inline const char* CompiledConfig() { + if (HWY_IS_ASAN) { + return "asan"; + } else if (HWY_IS_MSAN) { + return "msan"; + } else if (HWY_IS_TSAN) { + return "tsan"; +#if defined(HWY_IS_UBSAN) + } else if (HWY_IS_UBSAN) { + return "ubsan"; +#endif + } else if (HWY_IS_DEBUG_BUILD) { + return "dbg"; + } else { + return "opt"; + } +} + static inline void PinThreadToCore(size_t cpu_index) { #if HWY_OS_LINUX // Forces the thread to run on the logical processor with the same number. @@ -79,9 +108,9 @@ class AppArgs : public ArgsBase { template void ForEach(const Visitor& visitor) { visitor(verbosity, "verbosity", 1, - "Show verbose developer information\n 0 = only print generation " - "output\n 1 = standard user-facing terminal ui\n 2 = show " - "developer/debug info).\n Default = 1.", + "Show verbose developer information\n 0 = only print generation " + "output\n 1 = standard user-facing terminal ui\n 2 = show " + "developer/debug info).\n Default = 1.", 2); visitor(num_threads, "num_threads", kDefaultNumThreads, // see ChooseNumThreads @@ -98,6 +127,124 @@ class AppArgs : public ArgsBase { } }; +struct LoaderArgs : public ArgsBase { + LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + + static std::string ToLower(const std::string& text) { + std::string result = text; + std::transform(begin(result), end(result), begin(result), + [](unsigned char c) { return std::tolower(c); }); + return result; + } + + gcpp::Model ModelType() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") { + return gcpp::Model::GEMMA_2B; + } else { + return gcpp::Model::GEMMA_7B; + } + } + + gcpp::ModelTraining ModelTraining() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") { + return gcpp::ModelTraining::GEMMA_PT; + } else { + return gcpp::ModelTraining::GEMMA_IT; + } + } + + // Returns error string or nullptr if OK. + const char* Validate() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type.empty()) { + return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " + "2b-it, or 7b-it."; + } + if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && + model_type_lc != "2b-it" && model_type_lc != "7b-it") { + return "Model type must be 2b-pt, 7b-pt, 2b-it, or " + "7b-it."; + } + if (tokenizer.path.empty()) { + return "Missing --tokenizer flag, a file for the tokenizer is required."; + } + if (compressed_weights.path.empty()) { + return "Missing --compressed_weights flag, a file for the compressed " + "model."; + } + return nullptr; + } + + Path tokenizer; + Path weights; // uncompressed weights file location + Path compressed_weights; // compressed weights file location + std::string model_type; + + template + void ForEach(const Visitor& visitor) { + visitor(tokenizer, "tokenizer", Path(), + "Path name of tokenizer model file.\n Required argument."); + visitor( + compressed_weights, "compressed_weights", Path(), + "Path name of compressed weights file, regenerated from `--weights` " + "file if " + "the compressed weights file does not exist.\n Required argument."); + visitor(model_type, "model", std::string(), + "Model type\n 2b-it = 2B parameters, instruction-tuned\n " + "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " + "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n" + " Required argument."); + visitor(weights, "weights", Path(), + "Path name of model weights (.sbs) file. Only required if " + "compressed_weights file is not present and needs to be " + "regenerated. This parameter is only required for compressing" + "new model weight exports, otherwise it is not needed."); + } +}; + +struct InferenceArgs : public ArgsBase { + InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + + size_t max_tokens; + size_t max_generated_tokens; + + float temperature; + bool deterministic; + bool multiturn; + + // Returns error string or nullptr if OK. + const char* Validate() const { + if (max_tokens > gcpp::kSeqLen) { + return "max_tokens is larger than the maximum sequence length (see " + "configs.h)."; + } + if (max_generated_tokens > max_tokens) { + return "Maximum number of generated tokens is larger than the maximum " + "total tokens."; + } + return nullptr; + } + + template + void ForEach(const Visitor& visitor) { + visitor(max_tokens, "max_tokens", size_t{3072}, + "Maximum number of tokens in prompt + generation."); + visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, + "Maximum number of tokens to generate."); + + visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); + visitor(deterministic, "deterministic", false, + "Make top-k sampling deterministic", 2); + visitor(multiturn, "multiturn", false, + "Multiturn mode\n 0 = clear KV cache after every " + "interaction\n 1 = continue KV cache after every interaction\n " + " Default : 0 (conversation " + "resets every turn)"); + } +}; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_