diff --git a/BUILD.bazel b/BUILD.bazel
index 3019030..cc5104c 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -69,6 +69,7 @@ cc_library(
     deps = [
         ":args",
         ":transformer_ops",
+        "//base",
        "//compression:compress",
        "@hwy//:hwy",
        "@hwy//:matvec",
@@ -88,6 +89,7 @@ cc_binary(
        ":app",
        ":args",
        ":gemma_lib",
+        "//base",
        "//compression:compress",
        "@hwy//:hwy",
        "@hwy//:nanobenchmark",
diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt
new file mode 100644
index 0000000..9d44f04
--- /dev/null
+++ b/examples/hello_world/CMakeLists.txt
@@ -0,0 +1,49 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+cmake_minimum_required(VERSION 3.11)
+project(hello_world)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+include(FetchContent)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
+FetchContent_MakeAvailable(highway)
+FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
+FetchContent_MakeAvailable(sentencepiece)
+
+
+
+# Allow for both local and remote builds
+option(BUILD_MODE "'local' or 'remote' git fetch for builds")
+if (NOT BUILD_MODE)
+  set(BUILD_MODE "remote")
+endif()
+if (BUILD_MODE STREQUAL "local")
+  # Relative path to gemma.cpp from examples/hello_world/build/
+  FetchContent_Declare(gemma SOURCE_DIR ../../..)
+else()
+  FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837)
+endif()
+FetchContent_MakeAvailable(gemma)
+
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE "Release")
+endif()
+
+add_executable(hello_world run.cc)
+target_link_libraries(hello_world hwy hwy_contrib sentencepiece libgemma)
+FetchContent_GetProperties(sentencepiece)
+target_include_directories(hello_world PRIVATE ${sentencepiece_SOURCE_DIR})
+target_compile_definitions(hello_world PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
+target_compile_options(hello_world PRIVATE $<$:-Wno-deprecated-declarations>)
diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md
new file mode 100644
index 0000000..63c319e
--- /dev/null
+++ b/examples/hello_world/README.md
@@ -0,0 +1,51 @@
+# Hello World Example
+
+This is a minimal/template project for using `gemma.cpp` as a library. Instead
+of an interactive interface, it sets up the model state and generates text for a
+single hard-coded prompt.
+
+Build steps are similar to those for the main `gemma` executable. For now, only
+`cmake`/`make` builds are available (PRs welcome for other build options).
+
+First, use `cmake` to configure the project, starting from the `hello_world`
+example directory (`gemma.cpp/examples/hello_world`):
+
+```sh
+cmake -B build
+```
+
+This sets up a build configuration in `gemma.cpp/examples/hello_world/build`.
+Note that this fetches `libgemma` from a git commit hash on GitHub.
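For reference, the pinned revision comes from the `FetchContent_Declare` call in the example's `CMakeLists.txt` shown earlier in this patch; changing its `GIT_TAG` is how one would point the build at a different `gemma.cpp` commit. A minimal sketch of that fragment (same repository URL and commit hash as in the CMakeLists above):

```cmake
# Mirror of the remote-fetch setup in examples/hello_world/CMakeLists.txt.
# Swap the GIT_TAG to pin a different gemma.cpp revision; the hash below is
# the one used in this patch.
include(FetchContent)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837)
FetchContent_MakeAvailable(gemma)
```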
+Alternatively, if you want to build using the local version of `gemma.cpp`, use:
+
+```sh
+cmake -B build -DBUILD_MODE=local
+```
+
+Make sure you delete the contents of the build directory before changing
+configurations.
+
+Then use `make` to build the project:
+
+```sh
+cd build
+make hello_world
+```
+
+As with the top-level `gemma.cpp` project, you can use the `make` command's `-j`
+flag to run parallel build threads for faster builds.
+
+From inside the `gemma.cpp/examples/hello_world/build` directory, there should
+be a `hello_world` executable. You can run it with the same three model arguments as
+`gemma.cpp`, specifying the tokenizer, compressed weights file, and model type, for
+example:
+
+```sh
+./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
+```
+
+It should print a greeting to the terminal:
+
+```
+"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar.
+```
diff --git a/examples/hello_world/build/.gitignore b/examples/hello_world/build/.gitignore
new file mode 100644
index 0000000..d6b7ef3
--- /dev/null
+++ b/examples/hello_world/build/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
new file mode 100644
index 0000000..8ec784f
--- /dev/null
+++ b/examples/hello_world/run.cc
@@ -0,0 +1,83 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+
+// copybara:import_next_line:gemma_cpp
+#include "gemma.h"
+// copybara:end
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"
+// copybara:end
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+std::vector<int> tokenize(
+    const std::string& prompt_string,
+    const sentencepiece::SentencePieceProcessor* tokenizer) {
+  std::string formatted = "<start_of_turn>user\n" + prompt_string +
+                          "<end_of_turn>\n<start_of_turn>model\n";
+  std::vector<int> tokens;
+  HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok());
+  tokens.insert(tokens.begin(), 2);  // BOS token
+  return tokens;
+}
+
+int main(int argc, char** argv) {
+  gcpp::LoaderArgs loader(argc, argv);
+
+  // Rough heuristic for the number of threads to use
+  size_t num_threads = static_cast<size_t>(std::clamp(
+      static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
+  hwy::ThreadPool pool(num_threads);
+
+  // Instantiate model and KV Cache
+  gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
+                    loader.ModelType(), pool);
+  auto kv_cache = CreateKVCache(loader.ModelType());
+  size_t pos = 0;  // KV Cache position
+
+  // Initialize random number generator
+  std::mt19937 gen;
+  std::random_device rd;
+  gen.seed(rd());
+
+  // Tokenize instruction
+  std::vector<int> tokens =
+      tokenize("Write a greeting to the world.", model.Tokenizer());
+  size_t ntokens = tokens.size();
+
+  // This callback function gets invoked every time a token is generated
+  auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
+                          int token, float) {
+    ++pos;
+    if (pos < ntokens) {
+      // print feedback
+    } else if (token != gcpp::EOS_ID) {
+      std::string token_text;
+      HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
+      std::cout << token_text << std::flush;
+    }
+    return true;
+  };
+
+  GenerateGemma(model,
+                {.max_tokens = 2048,
+                 .max_generated_tokens = 1024,
+                 .temperature = 1.0,
+                 .verbosity = 0},
+                tokens, /*KV cache position = */ 0, kv_cache, pool,
+                stream_token, gen);
+  std::cout << std::endl;
+}
diff --git a/gemma.cc b/gemma.cc
index eb43e81..4fe2782 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -25,8 +25,6 @@
 #include "compression/compress-inl.h"
 // copybara:import_next_line:gemma_cpp
 #include "ops.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
 #include "hwy/contrib/matvec/matvec-inl.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
@@ -231,20 +229,39 @@ struct Activations {
 struct GemmaInterface {
   virtual ~GemmaInterface() = default;
 
-  virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0;
+  virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
 
-  // TODO: group pool/callbacks into struct
-  virtual void Generate(const InferenceArgs& args,
-                        const std::vector<int>& prompt, size_t start_pos,
+  virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
+                        float temperature, const std::vector<int>& prompt,
+                        size_t start_pos, KVCache& kv_cache,
                         hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                         const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity) = 0;
 };
 
+template <class Config>
+KVCache CreateKVCache() {
+  return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
+                       Config::kSeqLen);
+}
+
+KVCache CreateKVCache(Model type) {
+  switch (type) {
+    case Model::GEMMA_2B:
+      return CreateKVCache<ConfigGemma2B>();
+    case Model::GEMMA_7B:
+      return CreateKVCache<ConfigGemma7B>();
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
+  }
+}
+
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
-  GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool);
+
GemmaImpl(std::unique_ptr& tokenizer, + hwy::AlignedFreeUniquePtr& compressed_weights, + hwy::ThreadPool& pool); ~GemmaImpl() { using CWeights = CompressedWeights; @@ -252,22 +269,21 @@ struct GemmaImpl : public GemmaInterface { c_weights->c_layer_ptrs.~CompressedLayerPointers(); } - const sentencepiece::SentencePieceProcessor& Tokenizer() const { - return tokenizer; + const sentencepiece::SentencePieceProcessor* Tokenizer() const override { + return tokenizer.get(); } - void Generate(const InferenceArgs& args, const std::vector& prompt, - size_t start_pos, hwy::ThreadPool& pool, + void Generate(size_t max_tokens, size_t max_generated_tokens, + float temperature, const std::vector& prompt, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, - const AcceptFunc& accept_token, std::mt19937&, int verbosity); + const AcceptFunc& accept_token, std::mt19937&, + int verbosity) override; - sentencepiece::SentencePieceProcessor tokenizer; - - // CompressedWeights + std::unique_ptr tokenizer; hwy::AlignedFreeUniquePtr compressed_weights; hwy::AlignedUniquePtr> prefill; hwy::AlignedUniquePtr> state; - KVCache kv_cache; }; } // namespace gcpp @@ -294,7 +310,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, static constexpr size_t kModelDim = gcpp::Activations::kModelDim; static constexpr size_t kHeads = TConfig::kHeads; - const float kQueryScale = 1.0 / sqrtf(static_cast(kQKVDim)); + static const float kQueryScale = + static_cast(1.0 / sqrt(static_cast(kQKVDim))); pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // linear projections to QKV @@ -417,7 +434,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, hwy::ThreadPool& inner_pool) { PROFILER_ZONE("Gen.Prefill\\Att\\FFW"); static constexpr size_t kModelDim = TConfig::kModelDim; - static const float kEmbScaling = sqrtf(static_cast(kModelDim)); + static const float kEmbScaling = + static_cast(sqrt(static_cast(kModelDim))); pool.Run( 0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { @@ -472,7 +490,8 @@ void Transformer(int token, size_t pos, static constexpr size_t kLayers = TConfig::kLayers; static constexpr size_t kModelDim = TConfig::kModelDim; - static const float kEmbScaling = sqrtf(static_cast(kModelDim)); + static const float kEmbScaling = + static_cast(sqrt(static_cast(kModelDim))); Decompress(c_weights.c_embedder_input_embedding, token * kModelDim, activations.x.data(), kModelDim); @@ -495,8 +514,9 @@ void Transformer(int token, size_t pos, } template -void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, - const std::vector& prompt, size_t pos, +void GenerateImpl(GemmaImpl& gemma, size_t max_tokens, + size_t max_generated_tokens, float temperature, + const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, @@ -510,7 +530,6 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, const CompressedWeights& c_weights = *reinterpret_cast*>( gemma.compressed_weights.get()); - KVCache& kv_cache = gemma.kv_cache; int token; // pos indexes the KV cache. In the first turn of a chat, pos = 0. 
@@ -548,8 +567,9 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, // in the future this output should not occur in GenerateImpl but instead // should be available as observable state for frontend code to handle I/O. const double prefill_end = hwy::platform::Now(); - const double prefill_tok_sec = static_cast(pos_offset) / (prefill_end - prefill_start); - std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n"; + const double prefill_tok_sec = + static_cast(pos_offset) / (prefill_end - prefill_start); + std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]"; } const double gen_start = hwy::platform::Now(); @@ -558,10 +578,10 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, if (verbosity >= 2) { // Provide usage warnings if max_new_tokens is out of range. - if (args.max_generated_tokens > args.max_tokens) { + if (max_generated_tokens > max_tokens) { std::cout << "Warning: max_new_tokens should be <= max_tokens" << std::endl; - } else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) { + } else if ((prompt.size() + max_generated_tokens) > max_tokens) { std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens." << std::endl; } @@ -570,7 +590,7 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, auto pos_gen_start = pos_offset; token = prompt.at(pos_offset); size_t generate_pos = 0; - for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens; + for (; pos < max_tokens && generate_pos < max_generated_tokens; ++pos, ++pos_offset, ++generate_pos) { Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool); float* final_activation = activations.x.data(); @@ -583,7 +603,7 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, // Barrier: must have all logits so we can subtract max. 
Softmax(activations.logits.data(), kVocabSize); token = SampleTopK(activations.logits.data(), kVocabSize, gen, - args.temperature, accept_token); + temperature, accept_token); } if (!stream_token(token, activations.logits[token])) { token = EOS_ID; @@ -592,7 +612,8 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, if (verbosity >= 2) { const double gen_end = hwy::platform::Now(); const double gen_tok_sec = - static_cast(pos_offset - pos_gen_start) / (gen_end - gen_start); + static_cast(pos_offset - pos_gen_start) / + (gen_end - gen_start); std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n"; } break; @@ -600,21 +621,27 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, } } -void Generate2B(GemmaImpl& gemma, const InferenceArgs& args, +void Generate2B(GemmaImpl& gemma, size_t max_tokens, + size_t max_generated_tokens, float temperature, const std::vector& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { - GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token, + KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) { + GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, + start_pos, kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } -void Generate7B(GemmaImpl& gemma, const InferenceArgs& args, +void Generate7B(GemmaImpl& gemma, size_t max_tokens, + size_t max_generated_tokens, float temperature, const std::vector& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { - GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token, + KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) { + GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, + start_pos, kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } @@ -666,10 +693,10 @@ void ForEachTensor(const Weights* weights, template hwy::AlignedFreeUniquePtr GetCompressedWeights( - const Path& model, const Path& cache, hwy::ThreadPool& pool) { + const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) { PROFILER_ZONE("Startup.LoadCache"); - if (!std::filesystem::exists(model.path) && + if (!std::filesystem::exists(weights_path.path) && !std::filesystem::exists(cache.path)) { HWY_ABORT( "Either the model weights (--weights) or cached compressed weights " @@ -689,7 +716,8 @@ hwy::AlignedFreeUniquePtr GetCompressedWeights( if (loader.ReadAll(pool)) return c_weights_u8; // Get weights, compress, and store in cache. - const hwy::AlignedUniquePtr> weights = LoadWeights(model); + const hwy::AlignedUniquePtr> weights = + LoadWeights(weights_path); Compressor compressor(pool); ForEachTensor(weights.get(), *c_weights, compressor); compressor.WriteAll(pool, cache.path.c_str()); @@ -699,14 +727,17 @@ hwy::AlignedFreeUniquePtr GetCompressedWeights( // Type-erased because this function is called via a function pointer. 
hwy::AlignedFreeUniquePtr GetCompressedWeightsT( - const LoaderArgs& args, hwy::ThreadPool& pool) { - switch (args.ModelType()) { + gcpp::Model model, const Path& weights, const Path& compressed_weights, + hwy::ThreadPool& pool) { + switch (model) { case Model::GEMMA_2B: - return GetCompressedWeights(args.model, args.cache, pool); + return GetCompressedWeights(weights, compressed_weights, + pool); case Model::GEMMA_7B: - return GetCompressedWeights(args.model, args.cache, pool); + return GetCompressedWeights(weights, compressed_weights, + pool); default: - HWY_ABORT("Model type %d unknown.", static_cast(args.ModelType())); + HWY_ABORT("Model type %d unknown.", static_cast(model)); } } @@ -729,75 +760,99 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) { } template -GemmaImpl::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool) - : compressed_weights( - HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)), +GemmaImpl::GemmaImpl( + std::unique_ptr& tokenizer, + hwy::AlignedFreeUniquePtr& compressed_weights, + hwy::ThreadPool& pool) + : tokenizer(std::move(tokenizer)), + compressed_weights(std::move(compressed_weights)), prefill(hwy::MakeUniqueAligned>()), - state(hwy::MakeUniqueAligned>()), - kv_cache( - CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim, - Config::kSeqLen)) { - PROFILER_ZONE("Startup.tokenizer"); - - HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok()); -} + state(hwy::MakeUniqueAligned>()) {} template <> -void GemmaImpl::Generate(const InferenceArgs& args, - const std::vector& prompt, - size_t start_pos, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, - const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { +void GemmaImpl::Generate( + size_t max_tokens, size_t max_generated_tokens, float temperature, + const std::vector& prompt, size_t start_pos, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { HWY_DYNAMIC_DISPATCH(Generate2B) - (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token, - gen, verbosity); + (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, + kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } template <> -void GemmaImpl::Generate(const InferenceArgs& args, - const std::vector& prompt, - size_t start_pos, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, - const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { +void GemmaImpl::Generate( + size_t max_tokens, size_t max_generated_tokens, float temperature, + const std::vector& prompt, size_t start_pos, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { HWY_DYNAMIC_DISPATCH(Generate7B) - (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token, - gen, verbosity); + (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, + kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } -Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { - const Model model_type = args.ModelType(); - model_training = args.ModelTraining(); +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + const Path& weights_path, Model model_type, + hwy::ThreadPool& pool) { + std::unique_ptr 
tokenizer; + { + PROFILER_ZONE("Startup.tokenizer"); + tokenizer = std::make_unique(); + if (!tokenizer->Load(tokenizer_path.path).ok()) { + HWY_ABORT("Failed to load the tokenizer file."); + } + } + auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)( + model_type, weights_path, compressed_weights_path, pool); switch (model_type) { case Model::GEMMA_2B: - impl_.reset(new GemmaImpl(args, pool)); + impl_.reset( + new GemmaImpl(tokenizer, compressed_weights, pool)); break; case Model::GEMMA_7B: - impl_.reset(new GemmaImpl(args, pool)); + impl_.reset( + new GemmaImpl(tokenizer, compressed_weights, pool)); break; default: HWY_ABORT("Model type %d unknown.", static_cast(model_type)); } } + +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool) + : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, + pool) {} + Gemma::~Gemma() = default; // after GemmaInterface is defined -const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const { +const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { return impl_->Tokenizer(); } -void GenerateGemma(Gemma& gemma, const InferenceArgs& args, - const std::vector& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, +void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, + float temperature, const std::vector& prompt, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) { pool.SetWaitMode(hwy::PoolWaitMode::kSpin); - gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token, + gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt, + start_pos, kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); pool.SetWaitMode(hwy::PoolWaitMode::kBlock); } +void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + const StreamFunc& stream_token, std::mt19937& gen) { + hwy::ThreadPool inner_pool(0); + GenerateGemma( + gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens, + runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool, + stream_token, [](int) { return true; }, gen, runtime_config.verbosity); +} + } // namespace gcpp #endif // HWY_ONCE diff --git a/gemma.h b/gemma.h index 7195bc9..48fb52b 100644 --- a/gemma.h +++ b/gemma.h @@ -64,6 +64,13 @@ struct KVCache { enum class Model { GEMMA_2B, GEMMA_7B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; +struct RuntimeConfig { + size_t max_tokens; + size_t max_generated_tokens; + float temperature; + int verbosity; +}; + struct LoaderArgs : public ArgsBase { LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } @@ -95,6 +102,10 @@ struct LoaderArgs : public ArgsBase { // Returns error string or nullptr if OK. 
const char* Validate() const { const std::string model_type_lc = ToLower(model_type); + if (model_type.empty()) { + return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " + "2b-it, or 7b-it."; + } if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && model_type_lc != "2b-it" && model_type_lc != "7b-it") { return "Model type must be 2b-pt, 7b-pt, 2b-it, or " @@ -103,11 +114,7 @@ struct LoaderArgs : public ArgsBase { if (tokenizer.path.empty()) { return "Missing --tokenizer flag, a file for the tokenizer is required."; } - if (model_type.empty()) { - return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " - "2b-it, or 7b-it."; - } - if (cache.path.empty()) { + if (compressed_weights.path.empty()) { return "Missing --compressed_weights flag, a file for the compressed " "model."; } @@ -115,8 +122,8 @@ struct LoaderArgs : public ArgsBase { } Path tokenizer; - Path model; // uncompressed weights OR - Path cache; // compressed weights + Path weights; // uncompressed weights file location + Path compressed_weights; // compressed weights file location std::string model_type; template @@ -124,16 +131,16 @@ struct LoaderArgs : public ArgsBase { visitor(tokenizer, "tokenizer", Path(), "Path name of tokenizer model file.\n Required argument."); visitor( - cache, "compressed_weights", Path(), + compressed_weights, "compressed_weights", Path(), "Path name of compressed weights file, regenerated from `--weights` " "file if " "the compressed weights file does not exist.\n Required argument."); visitor(model_type, "model", std::string(), - "Model type\n 2b-it (2B parameters, instruction-tuned)\n " - "2b-pt (2B parameters, pretrained)\n 7b-it (7B parameters " - "instruction-tuned)\n 7b-pt (7B parameters, pretrained)\n" + "Model type\n 2b-it = 2B parameters, instruction-tuned\n " + "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " + "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n" " Required argument."); - visitor(model, "weights", Path(), + visitor(weights, "weights", Path(), "Path name of model weights (.sbs) file. Only required if " "compressed_weights file is not present and needs to be " "regenerated. This parameter is only required for compressing" @@ -141,23 +148,6 @@ struct LoaderArgs : public ArgsBase { } }; -struct GemmaInterface; - -struct Gemma { - Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); - ~Gemma(); // must be defined after GemmaInterface's dtor is defined. - - const sentencepiece::SentencePieceProcessor& Tokenizer() const; - - std::unique_ptr impl_; - gcpp::ModelTraining model_training; -}; - -// StreamFunc is called with (token, probability). For prompt tokens, -// probability is 0.0f. 
-using StreamFunc = std::function; -using AcceptFunc = std::function; - struct InferenceArgs : public ArgsBase { InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } @@ -192,19 +182,50 @@ struct InferenceArgs : public ArgsBase { visitor(deterministic, "deterministic", false, "Make top-k sampling deterministic", 2); visitor(multiturn, "multiturn", false, - "Multiturn mode (if 0, this clears the KV cache after every " - "interaction without quitting)\n Default : 0 (conversation " + "Multiturn mode\n 0 = clear KV cache after every " + "interaction\n 1 = continue KV cache after every interaction\n " + " Default : 0 (conversation " "resets every turn)"); } }; -void GenerateGemma(Gemma& gemma, const InferenceArgs& args, - const std::vector& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, - const AcceptFunc& accept_token, std::mt19937& g, +struct GemmaInterface; + +struct Gemma { + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool); + ~Gemma(); // must be defined after GemmaInterface's dtor is defined. + const sentencepiece::SentencePieceProcessor* Tokenizer() const; + std::unique_ptr impl_; + gcpp::ModelTraining model_training; +}; + +KVCache CreateKVCache(Model type); // convenient workaround for now +KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); + +// StreamFunc is called with (token, probability). For prompt tokens, +// probability is 0.0f. +using StreamFunc = std::function; +using AcceptFunc = std::function; + +void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, + float temperature, const std::vector& prompt, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, int verbosity); +// Convenience function for the common case: +// - Bundle runtime parameters as RuntimeConfig +// - No threadpools within threadpools (inner_pool = dummy) +// - All tokens accepted +void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + const StreamFunc& stream_token, std::mt19937& gen); + constexpr int EOS_ID = 1; } // namespace gcpp diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/run.cc b/run.cc index 610b824..8bc0910 100644 --- a/run.cc +++ b/run.cc @@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, std::cerr << "\n"; } -void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, const InferenceArgs& args, - int verbosity, const gcpp::AcceptFunc& accept_token, - std::string& eot_line) { +void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const InferenceArgs& args, int verbosity, + const gcpp::AcceptFunc& accept_token, std::string& eot_line) { PROFILER_ZONE("Gen.misc"); int abs_pos = 0; // absolute token index over all turns int current_pos = 0; // token index within the current turn @@ -115,11 +115,11 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, // callback function invoked for each generated token. 
auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size, - tokenizer = &model.Tokenizer(), + tokenizer = model.Tokenizer(), verbosity](int token, float) { ++abs_pos; ++current_pos; - if (current_pos < prompt_size) { + if (current_pos <= prompt_size) { std::cerr << "." << std::flush; } else if (token == gcpp::EOS_ID) { if (!args.multiturn) { @@ -129,7 +129,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, } } if (verbosity >= 2) { - std::cout << "\n[ End ]" << std::endl; + std::cout << "\n[ End ]\n"; } } else { std::string token_text; @@ -142,7 +142,6 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, std::cout << std::endl << std::endl; } } - // TODO(austinvhuang): is explicit space necessary? std::cout << token_text << std::flush; } return true; @@ -191,7 +190,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, } } - HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok()); + HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok()); // For both pre-trained and instruction-tuned models: prepend "" token // if needed. @@ -204,8 +203,9 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, std::cerr << std::endl << "[ Reading prompt ] " << std::flush; const double time_start = hwy::platform::Now(); - GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token, - accept_token, gen, verbosity); + GenerateGemma(model, args.max_tokens, args.max_generated_tokens, + args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, + stream_token, accept_token, gen, verbosity); const double time_end = hwy::platform::Now(); const double tok_sec = current_pos / (time_end - time_start); if (verbosity >= 2) { @@ -234,7 +234,10 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader, pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); + + auto kv_cache = CreateKVCache(loader.ModelType()); if (const char* error = inference.Validate()) { ShowHelp(loader, inference, app); @@ -272,7 +275,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } ReplGemma( - model, pool, inner_pool, inference, app.verbosity, + model, kv_cache, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int) { return true; }, app.eot_line); } diff --git a/util/app.h b/util/app.h index 79956be..ea45818 100644 --- a/util/app.h +++ b/util/app.h @@ -97,9 +97,9 @@ class AppArgs : public ArgsBase { template void ForEach(const Visitor& visitor) { visitor(verbosity, "verbosity", 1, - "Show verbose developer information\n 0 = only print generation " - "output\n 1 = standard user-facing terminal ui\n 2 = show " - "developer/debug info).\n Default = 1.", + "Show verbose developer information\n 0 = only print generation " + "output\n 1 = standard user-facing terminal ui\n 2 = show " + "developer/debug info).\n Default = 1.", 2); visitor(num_threads, "num_threads", kDefaultNumThreads, // see ChooseNumThreads