Merge pull request #82 from google:examples

PiperOrigin-RevId: 615066980
commit ccd055e06b
Copybara-Service, 2024-03-12 09:24:24 -07:00
10 changed files with 405 additions and 139 deletions

--- a/BUILD.bazel
+++ b/BUILD.bazel

@@ -69,6 +69,7 @@ cc_library(
     deps = [
         ":args",
         ":transformer_ops",
+        "//base",
         "//compression:compress",
         "@hwy//:hwy",
         "@hwy//:matvec",
@@ -88,6 +89,7 @@ cc_binary(
         ":app",
         ":args",
         ":gemma_lib",
+        "//base",
         "//compression:compress",
         "@hwy//:hwy",
         "@hwy//:nanobenchmark",

--- /dev/null
+++ b/examples/hello_world/CMakeLists.txt

@@ -0,0 +1,49 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.11)
project(hello_world)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
# Allow for both local and remote builds
option(BUILD_MODE "'local' or 'remote' git fetch for builds")
if (NOT BUILD_MODE)
set(BUILD_MODE "remote")
endif()
if (BUILD_MODE STREQUAL "local")
# Relative path to gemma.cpp from examples/hello_world/build/
FetchContent_Declare(gemma SOURCE_DIR ../../..)
else()
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837)
endif()
FetchContent_MakeAvailable(gemma)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
endif()
add_executable(hello_world run.cc)
target_link_libraries(hello_world hwy hwy_contrib sentencepiece libgemma)
FetchContent_GetProperties(sentencepiece)
target_include_directories(hello_world PRIVATE ${sentencepiece_SOURCE_DIR})
target_compile_definitions(hello_world PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
target_compile_options(hello_world PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)

--- /dev/null
+++ b/examples/hello_world/README.md

@@ -0,0 +1,51 @@
# Hello World Example

This is a minimal template project for using `gemma.cpp` as a library. Instead
of an interactive interface, it sets up the model state and generates text for
a single hard-coded prompt.

Build steps are similar to the main `gemma` executable. For now, only
`cmake`/`make` builds are available (PRs welcome for other build options).

First, use `cmake` to configure the project, starting from the `hello_world`
example directory (`gemma.cpp/examples/hello_world`):

```sh
cmake -B build
```

This sets up a build configuration in `gemma.cpp/examples/hello_world/build`.
Note that this fetches `libgemma` from a git commit hash on GitHub.
Alternatively, if you want to build against your local version of `gemma.cpp`,
use:

```sh
cmake -B build -DBUILD_MODE=local
```

Make sure you delete the contents of the build directory before changing
configurations.

Then use `make` to build the project:

```sh
cd build
make hello_world
```

As with the top-level `gemma.cpp` project, you can use the `make` command's
`-j` flag to build with parallel threads.

From inside the `gemma.cpp/examples/hello_world/build` directory, there should
now be a `hello_world` executable. You can run it with the same 3 model
arguments as gemma.cpp, specifying the tokenizer, compressed weights file, and
model type, for example:

```sh
./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
```

This should print a greeting to the terminal:

```
"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar.
```

--- /dev/null
+++ b/examples/hello_world/build/.gitignore

@@ -0,0 +1,2 @@
*
!.gitignore

--- /dev/null
+++ b/examples/hello_world/run.cc

@@ -0,0 +1,83 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <iostream>
#include <random>
#include <string>
#include <thread>
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
// copybara:end
#include "hwy/contrib/thread_pool/thread_pool.h"
std::vector<int> tokenize(
const std::string& prompt_string,
const sentencepiece::SentencePieceProcessor* tokenizer) {
std::string formatted = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n";
std::vector<int> tokens;
HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok());
tokens.insert(tokens.begin(), 2); // BOS token
return tokens;
}
int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv);
// Rough heuristic for the number of threads to use
size_t num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
hwy::ThreadPool pool(num_threads);
// Instantiate model and KV Cache
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());
size_t pos = 0; // KV Cache position
// Initialize random number generator
std::mt19937 gen;
std::random_device rd;
gen.seed(rd());
// Tokenize instruction
std::vector<int> tokens =
tokenize("Write a greeting to the world.", model.Tokenizer());
size_t ntokens = tokens.size();
// This callback function gets invoked every time a token is generated
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
int token, float) {
++pos;
if (pos < ntokens) {
// no feedback printed while the prompt tokens are being processed
} else if (token != gcpp::EOS_ID) {
std::string token_text;
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
std::cout << token_text << std::flush;
}
return true;
};
GenerateGemma(model,
{.max_tokens = 2048,
.max_generated_tokens = 1024,
.temperature = 1.0,
.verbosity = 0},
tokens, /*KV cache position = */ 0, kv_cache, pool,
stream_token, gen);
std::cout << std::endl;
}
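A note on `pos` in the example above: it indexes the KV cache (per the comment in `gemma.cc`, pos = 0 only in the first turn of a chat), so carrying it into a second `GenerateGemma` call continues the conversation instead of re-prefilling from scratch. Below is a rough, hypothetical sketch of a follow-up turn reusing the variables from `main`; the follow-up prompt and the `ntokens` bookkeeping are illustrative assumptions, not part of the committed example:

```cpp
// Hypothetical second turn (sketch): continue from the current KV cache
// position rather than starting over at 0.
std::vector<int> followup =
    tokenize("Now say it in French.", model.Tokenizer());
// Caveat: tokenize() above also prepends a BOS token, which a continuation
// would not normally repeat.
// stream_token stays silent while pos < ntokens, so advance ntokens past the
// new prompt tokens before generating.
ntokens = pos + followup.size();
GenerateGemma(model,
              {.max_tokens = 2048,
               .max_generated_tokens = 1024,
               .temperature = 1.0,
               .verbosity = 0},
              followup, /*start_pos=*/pos, kv_cache, pool, stream_token, gen);
```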

--- a/gemma.cc
+++ b/gemma.cc

@@ -25,8 +25,6 @@
 #include "compression/compress-inl.h"
 // copybara:import_next_line:gemma_cpp
 #include "ops.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
 #include "hwy/contrib/matvec/matvec-inl.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
@@ -231,20 +229,39 @@ struct Activations {
 struct GemmaInterface {
   virtual ~GemmaInterface() = default;
 
-  virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0;
+  virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
 
-  // TODO: group pool/callbacks into struct
-  virtual void Generate(const InferenceArgs& args,
-                        const std::vector<int>& prompt, size_t start_pos,
+  virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
+                        float temperature, const std::vector<int>& prompt,
+                        size_t start_pos, KVCache& kv_cache,
                         hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                         const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity) = 0;
 };
 
+template <class Config>
+KVCache CreateKVCache() {
+  return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
+                       Config::kSeqLen);
+}
+
+KVCache CreateKVCache(Model type) {
+  switch (type) {
+    case Model::GEMMA_2B:
+      return CreateKVCache<ConfigGemma2B>();
+    case Model::GEMMA_7B:
+      return CreateKVCache<ConfigGemma7B>();
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
+  }
+}
+
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
-  GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool);
+  GemmaImpl(std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
+            hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+            hwy::ThreadPool& pool);
 
   ~GemmaImpl() {
     using CWeights = CompressedWeights<Config>;
@@ -252,22 +269,21 @@ struct GemmaImpl : public GemmaInterface {
     c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
   }
 
-  const sentencepiece::SentencePieceProcessor& Tokenizer() const {
-    return tokenizer;
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
+    return tokenizer.get();
   }
 
-  void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
-                size_t start_pos, hwy::ThreadPool& pool,
+  void Generate(size_t max_tokens, size_t max_generated_tokens,
+                float temperature, const std::vector<int>& prompt,
+                size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&, int verbosity);
+                const AcceptFunc& accept_token, std::mt19937&,
+                int verbosity) override;
 
-  sentencepiece::SentencePieceProcessor tokenizer;
-  // CompressedWeights<Config>
+  std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
   hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
-  KVCache kv_cache;
 };
 
 }  // namespace gcpp
@@ -294,7 +310,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   static constexpr size_t kModelDim =
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
-  const float kQueryScale = 1.0 / sqrtf(static_cast<float>(kQKVDim));
+  static const float kQueryScale =
+      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
   pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     // linear projections to QKV
@@ -417,7 +434,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           hwy::ThreadPool& inner_pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
   static constexpr size_t kModelDim = TConfig::kModelDim;
-  static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
+  static const float kEmbScaling =
+      static_cast<float>(sqrt(static_cast<double>(kModelDim)));
 
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
@@ -472,7 +490,8 @@ void Transformer(int token, size_t pos,
   static constexpr size_t kLayers = TConfig::kLayers;
   static constexpr size_t kModelDim = TConfig::kModelDim;
 
-  static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
+  static const float kEmbScaling =
+      static_cast<float>(sqrt(static_cast<double>(kModelDim)));
 
   Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);
@@ -495,8 +514,9 @@
 }
 
 template <class TConfig>
-void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
-                  const std::vector<int>& prompt, size_t pos,
+void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
+                  size_t max_generated_tokens, float temperature,
+                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                   const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
@@ -510,7 +530,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
   const CompressedWeights<TConfig>& c_weights =
       *reinterpret_cast<CompressedWeights<TConfig>*>(
           gemma.compressed_weights.get());
-  KVCache& kv_cache = gemma.kv_cache;
 
   int token;
   // pos indexes the KV cache. In the first turn of a chat, pos = 0.
@@ -548,8 +567,9 @@
     // in the future this output should not occur in GenerateImpl but instead
     // should be available as observable state for frontend code to handle I/O.
     const double prefill_end = hwy::platform::Now();
-    const double prefill_tok_sec = static_cast<double>(pos_offset) / (prefill_end - prefill_start);
-    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
+    const double prefill_tok_sec =
+        static_cast<double>(pos_offset) / (prefill_end - prefill_start);
+    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
   }
 
   const double gen_start = hwy::platform::Now();
@@ -558,10 +578,10 @@
   if (verbosity >= 2) {
     // Provide usage warnings if max_new_tokens is out of range.
-    if (args.max_generated_tokens > args.max_tokens) {
+    if (max_generated_tokens > max_tokens) {
       std::cout << "Warning: max_new_tokens should be <= max_tokens"
                 << std::endl;
-    } else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) {
+    } else if ((prompt.size() + max_generated_tokens) > max_tokens) {
       std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
                 << std::endl;
     }
@@ -570,7 +590,7 @@
   auto pos_gen_start = pos_offset;
   token = prompt.at(pos_offset);
   size_t generate_pos = 0;
-  for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens;
+  for (; pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
     float* final_activation = activations.x.data();
@@ -583,7 +603,7 @@
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
       token = SampleTopK<kTopK>(activations.logits.data(), kVocabSize, gen,
-                                args.temperature, accept_token);
+                                temperature, accept_token);
     }
     if (!stream_token(token, activations.logits[token])) {
       token = EOS_ID;
@@ -592,7 +612,8 @@
       if (verbosity >= 2) {
         const double gen_end = hwy::platform::Now();
         const double gen_tok_sec =
-            static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
+            static_cast<double>(pos_offset - pos_gen_start) /
+            (gen_end - gen_start);
         std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
       }
       break;
@@ -600,21 +621,27 @@
   }
 }
 
-void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
+void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
+                size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
-                hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                const AcceptFunc& accept_token, std::mt19937& gen,
+                int verbosity) {
+  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
+               start_pos, kv_cache, pool, inner_pool, stream_token,
                accept_token, gen, verbosity);
 }
 
-void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
+void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
+                size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
-                hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                const AcceptFunc& accept_token, std::mt19937& gen,
+                int verbosity) {
+  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
+               start_pos, kv_cache, pool, inner_pool, stream_token,
               accept_token, gen, verbosity);
 }
 
@@ -666,10 +693,10 @@ void ForEachTensor(const Weights<TConfig>* weights,
 template <class TConfig>
 hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
-    const Path& model, const Path& cache, hwy::ThreadPool& pool) {
+    const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Startup.LoadCache");
-  if (!std::filesystem::exists(model.path) &&
+  if (!std::filesystem::exists(weights_path.path) &&
       !std::filesystem::exists(cache.path)) {
     HWY_ABORT(
         "Either the model weights (--weights) or cached compressed weights "
@@ -689,7 +716,8 @@
   if (loader.ReadAll(pool)) return c_weights_u8;
 
   // Get weights, compress, and store in cache.
-  const hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
+  const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
+      LoadWeights<TConfig>(weights_path);
   Compressor compressor(pool);
   ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
   compressor.WriteAll(pool, cache.path.c_str());
@@ -699,14 +727,17 @@
 // Type-erased because this function is called via a function pointer.
 hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
-    const LoaderArgs& args, hwy::ThreadPool& pool) {
-  switch (args.ModelType()) {
+    gcpp::Model model, const Path& weights, const Path& compressed_weights,
+    hwy::ThreadPool& pool) {
+  switch (model) {
     case Model::GEMMA_2B:
-      return GetCompressedWeights<ConfigGemma2B>(args.model, args.cache, pool);
+      return GetCompressedWeights<ConfigGemma2B>(weights, compressed_weights,
+                                                 pool);
     case Model::GEMMA_7B:
-      return GetCompressedWeights<ConfigGemma7B>(args.model, args.cache, pool);
+      return GetCompressedWeights<ConfigGemma7B>(weights, compressed_weights,
+                                                 pool);
     default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(args.ModelType()));
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
 }
 
@@ -729,75 +760,99 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
 }
 
 template <class Config>
-GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
-    : compressed_weights(
-          HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
+GemmaImpl<Config>::GemmaImpl(
+    std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
+    hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+    hwy::ThreadPool& pool)
+    : tokenizer(std::move(tokenizer)),
+      compressed_weights(std::move(compressed_weights)),
       prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
-      state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
-      kv_cache(
-          CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
-                        Config::kSeqLen)) {
-  PROFILER_ZONE("Startup.tokenizer");
-  HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
-}
+      state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
 
 template <>
-void GemmaImpl<ConfigGemma2B>::Generate(const InferenceArgs& args,
-                                        const std::vector<int>& prompt,
-                                        size_t start_pos, hwy::ThreadPool& pool,
-                                        hwy::ThreadPool& inner_pool,
-                                        const StreamFunc& stream_token,
-                                        const AcceptFunc& accept_token,
-                                        std::mt19937& gen, int verbosity) {
+void GemmaImpl<ConfigGemma2B>::Generate(
+    size_t max_tokens, size_t max_generated_tokens, float temperature,
+    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+    const StreamFunc& stream_token, const AcceptFunc& accept_token,
+    std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
-  (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
-   gen, verbosity);
+  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 template <>
-void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
-                                        const std::vector<int>& prompt,
-                                        size_t start_pos, hwy::ThreadPool& pool,
-                                        hwy::ThreadPool& inner_pool,
-                                        const StreamFunc& stream_token,
-                                        const AcceptFunc& accept_token,
-                                        std::mt19937& gen, int verbosity) {
+void GemmaImpl<ConfigGemma7B>::Generate(
+    size_t max_tokens, size_t max_generated_tokens, float temperature,
+    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+    const StreamFunc& stream_token, const AcceptFunc& accept_token,
+    std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
-  (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
-   gen, verbosity);
+  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
-Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
-  const Model model_type = args.ModelType();
-  model_training = args.ModelTraining();
+Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+             const Path& weights_path, Model model_type,
+             hwy::ThreadPool& pool) {
+  std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
+  {
+    PROFILER_ZONE("Startup.tokenizer");
+    tokenizer = std::make_unique<sentencepiece::SentencePieceProcessor>();
+    if (!tokenizer->Load(tokenizer_path.path).ok()) {
+      HWY_ABORT("Failed to load the tokenizer file.");
+    }
+  }
+  auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
+      model_type, weights_path, compressed_weights_path, pool);
   switch (model_type) {
     case Model::GEMMA_2B:
-      impl_.reset(new GemmaImpl<ConfigGemma2B>(args, pool));
+      impl_.reset(
+          new GemmaImpl<ConfigGemma2B>(tokenizer, compressed_weights, pool));
       break;
     case Model::GEMMA_7B:
-      impl_.reset(new GemmaImpl<ConfigGemma7B>(args, pool));
+      impl_.reset(
+          new GemmaImpl<ConfigGemma7B>(tokenizer, compressed_weights, pool));
       break;
     default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
  }
}
+
+Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+             Model model_type, hwy::ThreadPool& pool)
+    : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
+            pool) {}
 
 Gemma::~Gemma() = default;  // after GemmaInterface is defined
 
-const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const {
+const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
   return impl_->Tokenizer();
 }
 
-void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
-                   const std::vector<int>& prompt, size_t start_pos,
-                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                   const StreamFunc& stream_token,
+void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
+                   float temperature, const std::vector<int>& prompt,
+                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
+                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token,
+  gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
+                        start_pos, kv_cache, pool, inner_pool, stream_token,
                         accept_token, gen, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
+void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
+                   const std::vector<int>& prompt, size_t start_pos,
+                   KVCache& kv_cache, hwy::ThreadPool& pool,
+                   const StreamFunc& stream_token, std::mt19937& gen) {
+  hwy::ThreadPool inner_pool(0);
+  GenerateGemma(
+      gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
+      runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
+      stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
}
+
 }  // namespace gcpp
 #endif  // HWY_ONCE
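Taken together, the gemma.cc changes make call sites explicit about what used to hide inside `LoaderArgs`/`InferenceArgs`: the caller now owns the KV cache and passes generation limits and temperature directly. A hedged before/after sketch of a typical call site, with names taken from the diffs in this commit (`loader`, `args`, `prompt`, and the callbacks are assumed to be in scope):

```cpp
// Before this commit: the model owned its KV cache internally, and
// GenerateGemma read the limits out of InferenceArgs.
//   gcpp::Gemma model(loader, pool);
//   GenerateGemma(model, args, prompt, start_pos, pool, inner_pool,
//                 stream_token, accept_token, gen, verbosity);

// After: the caller creates the cache and passes the parameters explicitly.
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
                  loader.ModelType(), pool);
gcpp::KVCache kv_cache = gcpp::CreateKVCache(loader.ModelType());
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
              args.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
              stream_token, accept_token, gen, verbosity);
```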

--- a/gemma.h
+++ b/gemma.h

@@ -64,6 +64,13 @@ struct KVCache {
 enum class Model { GEMMA_2B, GEMMA_7B };
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 
+struct RuntimeConfig {
+  size_t max_tokens;
+  size_t max_generated_tokens;
+  float temperature;
+  int verbosity;
+};
+
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
@@ -95,6 +102,10 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   // Returns error string or nullptr if OK.
   const char* Validate() const {
     const std::string model_type_lc = ToLower(model_type);
+    if (model_type.empty()) {
+      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
+             "2b-it, or 7b-it.";
+    }
     if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
         model_type_lc != "2b-it" && model_type_lc != "7b-it") {
       return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
@@ -103,11 +114,7 @@
     if (tokenizer.path.empty()) {
       return "Missing --tokenizer flag, a file for the tokenizer is required.";
     }
-    if (model_type.empty()) {
-      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
-             "2b-it, or 7b-it.";
-    }
-    if (cache.path.empty()) {
+    if (compressed_weights.path.empty()) {
       return "Missing --compressed_weights flag, a file for the compressed "
              "model.";
     }
@@ -115,8 +122,8 @@
   }
 
   Path tokenizer;
-  Path model;  // uncompressed weights OR
-  Path cache;  // compressed weights
+  Path weights;             // uncompressed weights file location
+  Path compressed_weights;  // compressed weights file location
   std::string model_type;
 
   template <class Visitor>
@@ -124,16 +131,16 @@
     visitor(tokenizer, "tokenizer", Path(),
             "Path name of tokenizer model file.\n    Required argument.");
     visitor(
-        cache, "compressed_weights", Path(),
+        compressed_weights, "compressed_weights", Path(),
         "Path name of compressed weights file, regenerated from `--weights` "
         "file if "
         "the compressed weights file does not exist.\n    Required argument.");
     visitor(model_type, "model", std::string(),
-            "Model type\n    2b-it (2B parameters, instruction-tuned)\n    "
-            "2b-pt (2B parameters, pretrained)\n    7b-it (7B parameters "
-            "instruction-tuned)\n    7b-pt (7B parameters, pretrained)\n"
+            "Model type\n    2b-it = 2B parameters, instruction-tuned\n    "
+            "2b-pt = 2B parameters, pretrained\n    7b-it = 7B parameters "
+            "instruction-tuned\n    7b-pt = 7B parameters, pretrained\n"
             "    Required argument.");
-    visitor(model, "weights", Path(),
+    visitor(weights, "weights", Path(),
             "Path name of model weights (.sbs) file. Only required if "
             "compressed_weights file is not present and needs to be "
             "regenerated. This parameter is only required for compressing"
@@ -141,23 +148,6 @@
   }
 };
 
-struct GemmaInterface;
-
-struct Gemma {
-  Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
-  ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
-
-  const sentencepiece::SentencePieceProcessor& Tokenizer() const;
-
-  std::unique_ptr<GemmaInterface> impl_;
-  gcpp::ModelTraining model_training;
-};
-
-// StreamFunc is called with (token, probability). For prompt tokens,
-// probability is 0.0f.
-using StreamFunc = std::function<bool(int, float)>;
-using AcceptFunc = std::function<bool(int)>;
-
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
@@ -192,19 +182,50 @@
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
-            "Multiturn mode (if 0, this clears the KV cache after every "
-            "interaction without quitting)\n   Default : 0 (conversation "
+            "Multiturn mode\n    0 = clear KV cache after every "
+            "interaction\n    1 = continue KV cache after every interaction\n "
+            "   Default : 0 (conversation "
             "resets every turn)");
   }
 };
 
-void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
-                   const std::vector<int>& prompt, size_t start_pos,
-                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                   const StreamFunc& stream_token,
-                   const AcceptFunc& accept_token, std::mt19937& g,
+struct GemmaInterface;
+
+struct Gemma {
+  Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+        const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
+  Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+        Model model_type, hwy::ThreadPool& pool);
+  ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
+
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const;
+
+  std::unique_ptr<GemmaInterface> impl_;
+  gcpp::ModelTraining model_training;
+};
+
+KVCache CreateKVCache(Model type);  // convenient workaround for now
+KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
+
+// StreamFunc is called with (token, probability). For prompt tokens,
+// probability is 0.0f.
+using StreamFunc = std::function<bool(int, float)>;
+using AcceptFunc = std::function<bool(int)>;
+
+void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
+                   float temperature, const std::vector<int>& prompt,
+                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
+                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity);
 
+// Convenience function for the common case:
+// - Bundle runtime parameters as RuntimeConfig
+// - No threadpools within threadpools (inner_pool = dummy)
+// - All tokens accepted
+void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
+                   const std::vector<int>& prompt, size_t start_pos,
+                   KVCache& kv_cache, hwy::ThreadPool& pool,
+                   const StreamFunc& stream_token, std::mt19937& gen);
+
 constexpr int EOS_ID = 1;
 
 }  // namespace gcpp
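For library users, the new `RuntimeConfig` overload declared above is the shortest entry point: it bundles the runtime parameters, supplies a dummy inner pool, and accepts every token. A minimal sketch, assuming `model`, `prompt`, `kv_cache`, `pool`, `stream_token`, and `gen` are already set up as in the hello_world example; the limit and temperature values are illustrative only:

```cpp
gcpp::RuntimeConfig config{.max_tokens = 2048,
                           .max_generated_tokens = 1024,
                           .temperature = 1.0f,
                           .verbosity = 0};
gcpp::GenerateGemma(model, config, prompt, /*start_pos=*/0, kv_cache, pool,
                    stream_token, gen);
```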

--- /dev/null
+++ b/models/.gitignore
(new, empty file)

--- a/run.cc
+++ b/run.cc

@@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
   std::cerr << "\n";
 }
 
-void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
-               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
-               int verbosity, const gcpp::AcceptFunc& accept_token,
-               std::string& eot_line) {
+void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
+               hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+               const InferenceArgs& args, int verbosity,
+               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   int abs_pos = 0;      // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
@@ -115,11 +115,11 @@
   // callback function invoked for each generated token.
   auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
-                       tokenizer = &model.Tokenizer(),
+                       tokenizer = model.Tokenizer(),
                        verbosity](int token, float) {
     ++abs_pos;
     ++current_pos;
-    if (current_pos < prompt_size) {
+    if (current_pos <= prompt_size) {
       std::cerr << "." << std::flush;
     } else if (token == gcpp::EOS_ID) {
       if (!args.multiturn) {
@@ -129,7 +129,7 @@
         }
       }
       if (verbosity >= 2) {
-        std::cout << "\n[ End ]" << std::endl;
+        std::cout << "\n[ End ]\n";
       }
     } else {
       std::string token_text;
@@ -142,7 +142,6 @@
           std::cout << std::endl << std::endl;
         }
       }
-      // TODO(austinvhuang): is explicit space necessary?
       std::cout << token_text << std::flush;
     }
     return true;
@@ -191,7 +190,7 @@
       }
     }
 
-    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+    HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
 
     // For both pre-trained and instruction-tuned models: prepend "<bos>" token
     // if needed.
@@ -204,8 +203,9 @@
     std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
     const double time_start = hwy::platform::Now();
-    GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token,
-                  accept_token, gen, verbosity);
+    GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
+                  args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
+                  stream_token, accept_token, gen, verbosity);
     const double time_end = hwy::platform::Now();
     const double tok_sec = current_pos / (time_end - time_start);
     if (verbosity >= 2) {
@@ -234,7 +234,10 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
         [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
   }
 
-  gcpp::Gemma model(loader, pool);
+  gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
+                    loader.ModelType(), pool);
+
+  auto kv_cache = CreateKVCache(loader.ModelType());
 
   if (const char* error = inference.Validate()) {
     ShowHelp(loader, inference, app);
@@ -272,7 +275,7 @@
   }
 
   ReplGemma(
-      model, pool, inner_pool, inference, app.verbosity,
+      model, kv_cache, pool, inner_pool, inference, app.verbosity,
       /*accept_token=*/[](int) { return true; }, app.eot_line);
 }