mirror of https://github.com/google/gemma.cpp.git
Merge branch 'google:main' into main
This commit is contained in:
commit
e0b912fc46
|
|
@ -1,4 +1,5 @@
|
|||
FormatStyle: file
|
||||
WarningsAsErrors: "*"
|
||||
Checks: "-*,\
|
||||
abseil-*,\
|
||||
-abseil-string-find-startswith,\
|
||||
|
|
@ -204,3 +205,6 @@ Checks: "-*,\
|
|||
-readability-uppercase-literal-suffix,\
|
||||
-readability-use-anyofallof
|
||||
"
|
||||
CheckOptions:
|
||||
- { key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase }
|
||||
- { key: readability-identifier-naming.ConstexprVariablePrefix, value: k }
|
||||
|
|
|
|||
25
BUILD.bazel
25
BUILD.bazel
|
|
@ -46,17 +46,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "app",
|
||||
hdrs = [
|
||||
"util/app.h",
|
||||
],
|
||||
deps = [
|
||||
":args",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_lib",
|
||||
srcs = [
|
||||
|
|
@ -69,6 +58,7 @@ cc_library(
|
|||
deps = [
|
||||
":args",
|
||||
":transformer_ops",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:matvec",
|
||||
|
|
@ -79,6 +69,18 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "app",
|
||||
hdrs = [
|
||||
"util/app.h",
|
||||
],
|
||||
deps = [
|
||||
":args",
|
||||
":gemma_lib",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "gemma",
|
||||
srcs = [
|
||||
|
|
@ -88,6 +90,7 @@ cc_binary(
|
|||
":app",
|
||||
":args",
|
||||
":gemma_lib",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
|
|
|
|||
|
|
@ -341,7 +341,7 @@ BlobError BlobReader::Open(const char* filename) {
|
|||
#endif
|
||||
if (fd_ < 0) return __LINE__;
|
||||
|
||||
#if HWY_OS_LINUX
|
||||
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
||||
// Doubles the readahead window, which seems slightly faster when cached.
|
||||
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
cmake_minimum_required(VERSION 3.11)
project(hello_world)

# run.cc relies on C++17 library features (e.g. std::clamp), so request that
# standard explicitly; CMAKE_CXX_STANDARD_REQUIRED alone does not select one.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Default to an optimized build. This is decided before any dependency is added
# so that the subdirectories created by FetchContent_MakeAvailable inherit it.
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE "Release")
endif()

include(FetchContent)

FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)

# Allow for both local and remote building.
option(BUILD_MODE "'local' or 'remote' git fetch for builds")
if (NOT BUILD_MODE)
  set(BUILD_MODE "remote")
endif()
if (BUILD_MODE STREQUAL "local")
  # Relative path to gemma.cpp from examples/hello_world/build/
  FetchContent_Declare(gemma SOURCE_DIR ../../..)
else()
  FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
endif()
FetchContent_MakeAvailable(gemma)

add_executable(hello_world run.cc)
target_link_libraries(hello_world hwy hwy_contrib sentencepiece libgemma)
FetchContent_GetProperties(sentencepiece)
target_include_directories(hello_world PRIVATE ${sentencepiece_SOURCE_DIR})
target_compile_definitions(hello_world PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
target_compile_options(hello_world PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
# Hello World Example
|
||||
|
||||
This is a minimal/template project for using `gemma.cpp` as a library. Instead
|
||||
of an interactive interface, it sets up the model state and generates text for a
|
||||
single hard-coded prompt.
|
||||
|
||||
Build steps are similar to the main `gemma` executable. For now only
|
||||
`cmake`/`make` is available for builds (PRs welcome for other build options).
|
||||
|
||||
First use `cmake` to configure the project, starting from the `hello_world`
|
||||
example directory (`gemma.cpp/examples/hello_world`):
|
||||
|
||||
```sh
|
||||
cmake -B build
|
||||
```
|
||||
|
||||
This sets up a build configuration in `gemma.cpp/examples/hello_world/build`.
|
||||
Note that this fetches `libgemma` from a pinned git commit hash on GitHub.
|
||||
Alternatively if you want to build using the local version of `gemma.cpp` use:
|
||||
|
||||
```sh
|
||||
cmake -B build -DBUILD_MODE=local
|
||||
```
|
||||
|
||||
Make sure you delete the contents of the build directory before changing
|
||||
configurations.
|
||||
|
||||
Then use `make` to build the project:
|
||||
|
||||
```sh
|
||||
cd build
|
||||
make hello_world
|
||||
```
|
||||
|
||||
As with the top-level `gemma.cpp` project you can use the `make` command's `-j`
|
||||
flag to use parallel threads for faster builds.
|
||||
|
||||
From inside the `gemma.cpp/examples/hello_world/build` directory, there should
|
||||
be a `hello_world` executable. You can run it with the same 3 model arguments as
|
||||
gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
|
||||
example:
|
||||
|
||||
```sh
|
||||
./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
|
||||
```
|
||||
|
||||
This should print a greeting to the terminal:
|
||||
|
||||
```
|
||||
"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar.
|
||||
```
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
*
|
||||
!.gitignore
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "gemma.h"
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h"
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/app.h" // LoaderArgs
|
||||
// copybara:end
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
std::vector<int> tokenize(
|
||||
const std::string& prompt_string,
|
||||
const sentencepiece::SentencePieceProcessor* tokenizer) {
|
||||
std::string formatted = "<start_of_turn>user\n" + prompt_string +
|
||||
"<end_of_turn>\n<start_of_turn>model\n";
|
||||
std::vector<int> tokens;
|
||||
HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok());
|
||||
tokens.insert(tokens.begin(), 2); // BOS token
|
||||
return tokens;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
|
||||
// Rough heuristic for the number of threads to use
|
||||
size_t num_threads = static_cast<size_t>(std::clamp(
|
||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
||||
hwy::ThreadPool pool(num_threads);
|
||||
|
||||
// Instantiate model and KV Cache
|
||||
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
|
||||
loader.ModelType(), pool);
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
size_t pos = 0; // KV Cache position
|
||||
|
||||
// Initialize random number generator
|
||||
std::mt19937 gen;
|
||||
std::random_device rd;
|
||||
gen.seed(rd());
|
||||
|
||||
// Tokenize instruction
|
||||
std::vector<int> tokens =
|
||||
tokenize("Write a greeting to the world.", model.Tokenizer());
|
||||
size_t ntokens = tokens.size();
|
||||
|
||||
// This callback function gets invoked everytime a token is generated
|
||||
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
|
||||
int token, float) {
|
||||
++pos;
|
||||
if (pos < ntokens) {
|
||||
// print feedback
|
||||
} else if (token != gcpp::EOS_ID) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
|
||||
std::cout << token_text << std::flush;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
GenerateGemma(model,
|
||||
{.max_tokens = 2048,
|
||||
.max_generated_tokens = 1024,
|
||||
.temperature = 1.0,
|
||||
.verbosity = 0},
|
||||
tokens, /*KV cache position = */ 0, kv_cache, pool,
|
||||
stream_token, gen);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
229
gemma.cc
229
gemma.cc
|
|
@ -25,8 +25,6 @@
|
|||
#include "compression/compress-inl.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "ops.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // Path
|
||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
|
@ -231,20 +229,39 @@ struct Activations {
|
|||
struct GemmaInterface {
|
||||
virtual ~GemmaInterface() = default;
|
||||
|
||||
virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0;
|
||||
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
||||
|
||||
// TODO: group pool/callbacks into struct
|
||||
virtual void Generate(const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) = 0;
|
||||
};
|
||||
|
||||
template <class Config>
|
||||
KVCache CreateKVCache() {
|
||||
return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
|
||||
Config::kSeqLen);
|
||||
}
|
||||
|
||||
// Runtime dispatch from a Model enumerator to the matching compile-time
// configuration. Aborts for any enumerator that has no configuration here.
KVCache CreateKVCache(Model type) {
  if (type == Model::GEMMA_2B) {
    return CreateKVCache<ConfigGemma2B>();
  }
  if (type == Model::GEMMA_7B) {
    return CreateKVCache<ConfigGemma7B>();
  }
  HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
}
|
||||
|
||||
template <class Config>
|
||||
struct GemmaImpl : public GemmaInterface {
|
||||
GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||
GemmaImpl(std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
|
||||
hwy::ThreadPool& pool);
|
||||
|
||||
~GemmaImpl() {
|
||||
using CWeights = CompressedWeights<Config>;
|
||||
|
|
@ -252,22 +269,21 @@ struct GemmaImpl : public GemmaInterface {
|
|||
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
||||
}
|
||||
|
||||
const sentencepiece::SentencePieceProcessor& Tokenizer() const {
|
||||
return tokenizer;
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
|
||||
return tokenizer.get();
|
||||
}
|
||||
|
||||
void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
|
||||
const AcceptFunc& accept_token, std::mt19937&,
|
||||
int verbosity) override;
|
||||
|
||||
sentencepiece::SentencePieceProcessor tokenizer;
|
||||
|
||||
// CompressedWeights<Config>
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
|
||||
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
||||
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
||||
KVCache kv_cache;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -294,7 +310,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
static constexpr size_t kModelDim =
|
||||
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
const float kQueryScale = 1.0 / sqrtf(static_cast<float>(kQKVDim));
|
||||
static const float kQueryScale =
|
||||
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
// linear projections to QKV
|
||||
|
|
@ -417,7 +434,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
hwy::ThreadPool& inner_pool) {
|
||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
|
||||
static const float kEmbScaling =
|
||||
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
|
||||
pool.Run(
|
||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||
|
|
@ -472,7 +490,8 @@ void Transformer(int token, size_t pos,
|
|||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
|
||||
static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
|
||||
static const float kEmbScaling =
|
||||
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
|
||||
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
||||
activations.x.data(), kModelDim);
|
||||
|
|
@ -495,8 +514,9 @@ void Transformer(int token, size_t pos,
|
|||
}
|
||||
|
||||
template <class TConfig>
|
||||
void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t pos,
|
||||
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
|
|
@ -510,7 +530,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
const CompressedWeights<TConfig>& c_weights =
|
||||
*reinterpret_cast<CompressedWeights<TConfig>*>(
|
||||
gemma.compressed_weights.get());
|
||||
KVCache& kv_cache = gemma.kv_cache;
|
||||
int token;
|
||||
|
||||
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
||||
|
|
@ -548,8 +567,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
// in the future this output should not occur in GenerateImpl but instead
|
||||
// should be available as observable state for frontend code to handle I/O.
|
||||
const double prefill_end = hwy::platform::Now();
|
||||
const double prefill_tok_sec = static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
||||
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
|
||||
const double prefill_tok_sec =
|
||||
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
||||
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
|
||||
}
|
||||
|
||||
const double gen_start = hwy::platform::Now();
|
||||
|
|
@ -558,10 +578,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
|
||||
if (verbosity >= 2) {
|
||||
// Provide usage warnings if max_new_tokens is out of range.
|
||||
if (args.max_generated_tokens > args.max_tokens) {
|
||||
if (max_generated_tokens > max_tokens) {
|
||||
std::cout << "Warning: max_new_tokens should be <= max_tokens"
|
||||
<< std::endl;
|
||||
} else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) {
|
||||
} else if ((prompt.size() + max_generated_tokens) > max_tokens) {
|
||||
std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
|
||||
<< std::endl;
|
||||
}
|
||||
|
|
@ -570,7 +590,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
auto pos_gen_start = pos_offset;
|
||||
token = prompt.at(pos_offset);
|
||||
size_t generate_pos = 0;
|
||||
for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens;
|
||||
for (; pos < max_tokens && generate_pos < max_generated_tokens;
|
||||
++pos, ++pos_offset, ++generate_pos) {
|
||||
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
|
||||
float* final_activation = activations.x.data();
|
||||
|
|
@ -583,7 +603,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
// Barrier: must have all logits so we can subtract max.
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
token = SampleTopK<kTopK>(activations.logits.data(), kVocabSize, gen,
|
||||
args.temperature, accept_token);
|
||||
temperature, accept_token);
|
||||
}
|
||||
if (!stream_token(token, activations.logits[token])) {
|
||||
token = EOS_ID;
|
||||
|
|
@ -592,7 +612,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
if (verbosity >= 2) {
|
||||
const double gen_end = hwy::platform::Now();
|
||||
const double gen_tok_sec =
|
||||
static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
|
||||
static_cast<double>(pos_offset - pos_gen_start) /
|
||||
(gen_end - gen_start);
|
||||
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
|
||||
}
|
||||
break;
|
||||
|
|
@ -600,21 +621,27 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
}
|
||||
}
|
||||
|
||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
|
||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
|
|
@ -666,10 +693,10 @@ void ForEachTensor(const Weights<TConfig>* weights,
|
|||
|
||||
template <class TConfig>
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
|
||||
const Path& model, const Path& cache, hwy::ThreadPool& pool) {
|
||||
const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Startup.LoadCache");
|
||||
|
||||
if (!std::filesystem::exists(model.path) &&
|
||||
if (!std::filesystem::exists(weights_path.path) &&
|
||||
!std::filesystem::exists(cache.path)) {
|
||||
HWY_ABORT(
|
||||
"Either the model weights (--weights) or cached compressed weights "
|
||||
|
|
@ -689,7 +716,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
|
|||
if (loader.ReadAll(pool)) return c_weights_u8;
|
||||
|
||||
// Get weights, compress, and store in cache.
|
||||
const hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
|
||||
const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
|
||||
LoadWeights<TConfig>(weights_path);
|
||||
Compressor compressor(pool);
|
||||
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
|
||||
compressor.WriteAll(pool, cache.path.c_str());
|
||||
|
|
@ -699,14 +727,17 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
|
|||
|
||||
// Type-erased because this function is called via a function pointer.
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
|
||||
const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||
switch (args.ModelType()) {
|
||||
gcpp::Model model, const Path& weights, const Path& compressed_weights,
|
||||
hwy::ThreadPool& pool) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return GetCompressedWeights<ConfigGemma2B>(args.model, args.cache, pool);
|
||||
return GetCompressedWeights<ConfigGemma2B>(weights, compressed_weights,
|
||||
pool);
|
||||
case Model::GEMMA_7B:
|
||||
return GetCompressedWeights<ConfigGemma7B>(args.model, args.cache, pool);
|
||||
return GetCompressedWeights<ConfigGemma7B>(weights, compressed_weights,
|
||||
pool);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(args.ModelType()));
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -729,75 +760,99 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
|
|||
}
|
||||
|
||||
template <class Config>
|
||||
GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
|
||||
: compressed_weights(
|
||||
HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
|
||||
GemmaImpl<Config>::GemmaImpl(
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
|
||||
hwy::ThreadPool& pool)
|
||||
: tokenizer(std::move(tokenizer)),
|
||||
compressed_weights(std::move(compressed_weights)),
|
||||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
|
||||
kv_cache(
|
||||
CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
|
||||
Config::kSeqLen)) {
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
|
||||
HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
||||
}
|
||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(const InferenceArgs& args,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||
(*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
|
||||
gen, verbosity);
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
}
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
void GemmaImpl<ConfigGemma7B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||
(*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
|
||||
gen, verbosity);
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||
const Model model_type = args.ModelType();
|
||||
model_training = args.ModelTraining();
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||
const Path& weights_path, Model model_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||
{
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
tokenizer = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||
if (!tokenizer->Load(tokenizer_path.path).ok()) {
|
||||
HWY_ABORT("Failed to load the tokenizer file.");
|
||||
}
|
||||
}
|
||||
auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
|
||||
model_type, weights_path, compressed_weights_path, pool);
|
||||
switch (model_type) {
|
||||
case Model::GEMMA_2B:
|
||||
impl_.reset(new GemmaImpl<ConfigGemma2B>(args, pool));
|
||||
impl_.reset(
|
||||
new GemmaImpl<ConfigGemma2B>(tokenizer, compressed_weights, pool));
|
||||
break;
|
||||
case Model::GEMMA_7B:
|
||||
impl_.reset(new GemmaImpl<ConfigGemma7B>(args, pool));
|
||||
impl_.reset(
|
||||
new GemmaImpl<ConfigGemma7B>(tokenizer, compressed_weights, pool));
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience overload without a `weights_path`: delegates to the
// five-argument constructor, supplying an empty Path in that position.
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||
Model model_type, hwy::ThreadPool& pool)
|
||||
: Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
|
||||
pool) {}
|
||||
|
||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||
|
||||
const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const {
|
||||
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
||||
return impl_->Tokenizer();
|
||||
}
|
||||
|
||||
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token,
|
||||
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
// Short-form entry point: unpacks `runtime_config` and forwards to the
// long-form GenerateGemma overload, supplying an inner thread pool built with
// a size of 0 and an accept function that admits every token.
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                   const std::vector<int>& prompt, size_t start_pos,
                   KVCache& kv_cache, hwy::ThreadPool& pool,
                   const StreamFunc& stream_token, std::mt19937& gen) {
  hwy::ThreadPool inner_pool(0);
  const AcceptFunc accept_every_token = [](int) { return true; };

  const size_t max_tokens = runtime_config.max_tokens;
  const size_t max_generated_tokens = runtime_config.max_generated_tokens;
  const float temperature = runtime_config.temperature;
  const int verbosity = runtime_config.verbosity;

  GenerateGemma(gemma, max_tokens, max_generated_tokens, temperature, prompt,
                start_pos, kv_cache, pool, inner_pool, stream_token,
                accept_every_token, gen, verbosity);
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
151
gemma.h
151
gemma.h
|
|
@ -64,147 +64,50 @@ struct KVCache {
|
|||
enum class Model { GEMMA_2B, GEMMA_7B };
|
||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
static std::string ToLower(const std::string& text) {
|
||||
std::string result = text;
|
||||
std::transform(begin(result), end(result), begin(result),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
return result;
|
||||
}
|
||||
|
||||
gcpp::Model ModelType() const {
|
||||
const std::string model_type_lc = ToLower(model_type);
|
||||
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
|
||||
return gcpp::Model::GEMMA_2B;
|
||||
} else {
|
||||
return gcpp::Model::GEMMA_7B;
|
||||
}
|
||||
}
|
||||
|
||||
gcpp::ModelTraining ModelTraining() const {
|
||||
const std::string model_type_lc = ToLower(model_type);
|
||||
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
|
||||
return gcpp::ModelTraining::GEMMA_PT;
|
||||
} else {
|
||||
return gcpp::ModelTraining::GEMMA_IT;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() const {
|
||||
const std::string model_type_lc = ToLower(model_type);
|
||||
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
|
||||
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
|
||||
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
|
||||
"7b-it.";
|
||||
}
|
||||
if (tokenizer.path.empty()) {
|
||||
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
||||
}
|
||||
if (model_type.empty()) {
|
||||
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
|
||||
"2b-it, or 7b-it.";
|
||||
}
|
||||
if (cache.path.empty()) {
|
||||
return "Missing --compressed_weights flag, a file for the compressed "
|
||||
"model.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Path tokenizer;
|
||||
Path model; // uncompressed weights OR
|
||||
Path cache; // compressed weights
|
||||
std::string model_type;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path name of tokenizer model file.\n Required argument.");
|
||||
visitor(
|
||||
cache, "compressed_weights", Path(),
|
||||
"Path name of compressed weights file, regenerated from `--weights` "
|
||||
"file if "
|
||||
"the compressed weights file does not exist.\n Required argument.");
|
||||
visitor(model_type, "model", std::string(),
|
||||
"Model type\n 2b-it (2B parameters, instruction-tuned)\n "
|
||||
"2b-pt (2B parameters, pretrained)\n 7b-it (7B parameters "
|
||||
"instruction-tuned)\n 7b-pt (7B parameters, pretrained)\n"
|
||||
" Required argument.");
|
||||
visitor(model, "weights", Path(),
|
||||
"Path name of model weights (.sbs) file. Only required if "
|
||||
"compressed_weights file is not present and needs to be "
|
||||
"regenerated. This parameter is only required for compressing"
|
||||
"new model weight exports, otherwise it is not needed.");
|
||||
}
|
||||
// Bundle of runtime generation parameters, accepted by the GenerateGemma
// overload that takes a RuntimeConfig. The fields have the same names as the
// max_tokens / max_generated_tokens / temperature / verbosity parameters of
// the longer GenerateGemma overload; presumably they carry the same meaning
// (TODO confirm against the implementation).
struct RuntimeConfig {
|
||||
size_t max_tokens;
|
||||
size_t max_generated_tokens;
|
||||
float temperature;
|
||||
int verbosity;
|
||||
};
|
||||
|
||||
struct GemmaInterface;
|
||||
|
||||
struct Gemma {
|
||||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||
const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
|
||||
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||
Model model_type, hwy::ThreadPool& pool);
|
||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||
|
||||
const sentencepiece::SentencePieceProcessor& Tokenizer() const;
|
||||
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||
std::unique_ptr<GemmaInterface> impl_;
|
||||
gcpp::ModelTraining model_training;
|
||||
};
|
||||
|
||||
KVCache CreateKVCache(Model type); // convenient workaround for now
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
using AcceptFunc = std::function<bool(int)>;
|
||||
|
||||
// Command-line flags that control generation (token limits and sampling).
struct InferenceArgs : ArgsBase<InferenceArgs> {
  InferenceArgs(int argc, char** argv) { InitAndParse(argc, argv); }

  // Returns nullptr when the token limits are consistent, otherwise a
  // message describing the first violated constraint.
  const char* Validate() const {
    const bool exceeds_seq_len = max_tokens > gcpp::kSeqLen;
    if (exceeds_seq_len) {
      return "max_tokens is larger than the maximum sequence length (see "
             "configs.h).";
    }
    const bool generation_exceeds_total = max_generated_tokens > max_tokens;
    return generation_exceeds_total
               ? "Maximum number of generated tokens is larger than the maximum "
                 "total tokens."
               : nullptr;
  }

  // Calls `visitor` once per flag with the destination member, the flag name,
  // the third argument (the flag's default, judging by the help strings) and
  // the help text; some calls pass an extra trailing integer.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    // Token budget.
    visitor(max_tokens, "max_tokens", size_t{3072},
            "Maximum number of tokens in prompt + generation.");
    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
            "Maximum number of tokens to generate.");

    // Sampling and conversation behaviour.
    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
    visitor(deterministic, "deterministic", false,
            "Make top-k sampling deterministic", 2);
    visitor(multiturn, "multiturn", false,
            "Multiturn mode (if 0, this clears the KV cache after every "
            "interaction without quitting)\n Default : 0 (conversation "
            "resets every turn)");
  }

  size_t max_tokens;
  size_t max_generated_tokens;

  float temperature;
  bool deterministic;
  bool multiturn;
};
|
||||
|
||||
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& g,
|
||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity);
|
||||
|
||||
// Convenience function for the common case:
|
||||
// - Bundle runtime parameters as RuntimeConfig
|
||||
// - No threadpools within threadpools (inner_pool = dummy)
|
||||
// - All tokens accepted
|
||||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, std::mt19937& gen);
|
||||
|
||||
constexpr int EOS_ID = 1;
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
48
ops.h
48
ops.h
|
|
@ -340,11 +340,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
|||
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
||||
const float* HWY_RESTRICT a, size_t size) {
|
||||
float total = 0.f;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
total += a[i] * a[i];
|
||||
const hn::ScalableTag<float> d;
|
||||
const size_t N = hn::Lanes(d);
|
||||
HWY_DASSERT(size >= 2 * N);
|
||||
HWY_DASSERT(size % (2 * N) == 0);
|
||||
|
||||
auto sum0 = hn::Zero(d);
|
||||
auto sum1 = hn::Zero(d);
|
||||
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
|
||||
const auto a0 = LoadU(d, a + i);
|
||||
sum0 = MulAdd(a0, a0, sum0);
|
||||
const auto a1 = LoadU(d, a + i + N);
|
||||
sum1 = MulAdd(a1, a1, sum1);
|
||||
}
|
||||
return total;
|
||||
|
||||
return ReduceSum(d, Add(sum0, sum1));
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||
|
|
@ -362,12 +372,30 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
||||
float* HWY_RESTRICT out, size_t size) {
|
||||
constexpr float eps = 1e-6f;
|
||||
float ss = SquaredL2(x, size);
|
||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
// Note 1.0f centering here
|
||||
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
constexpr float kEps = 1e-6f;
|
||||
constexpr size_t kUnrollSize = 2;
|
||||
|
||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||
const hn::Repartition<float, decltype(dbf)> df32;
|
||||
const size_t N32 = hn::Lanes(df32);
|
||||
|
||||
const float ss = SquaredL2(x, size);
|
||||
const auto vss =
|
||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
||||
|
||||
HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += kUnrollSize * N32) {
|
||||
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
||||
const auto w0 = hn::PromoteLowerTo(df32, w16);
|
||||
const auto w1 = hn::PromoteUpperTo(df32, w16);
|
||||
const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
||||
const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
||||
|
||||
// (1+weight) * m = m + weight*m = one FMA.
|
||||
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i);
|
||||
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
31
run.cc
31
run.cc
|
|
@ -66,8 +66,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
<< std::thread::hardware_concurrency() << std::endl
|
||||
<< "Instruction set : "
|
||||
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
|
||||
<< hwy::VectorBytes() * 8 << " bits)"
|
||||
<< "\n"
|
||||
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
|
||||
<< "Compiled config : " << CompiledConfig() << "\n"
|
||||
<< "Weight Type : "
|
||||
<< gcpp::TypeName(gcpp::WeightT()) << "\n"
|
||||
<< "EmbedderInput Type : "
|
||||
|
|
@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
|||
std::cerr << "\n";
|
||||
}
|
||||
|
||||
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
||||
int verbosity, const gcpp::AcceptFunc& accept_token,
|
||||
std::string& eot_line) {
|
||||
void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const InferenceArgs& args, int verbosity,
|
||||
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
int abs_pos = 0; // absolute token index over all turns
|
||||
int current_pos = 0; // token index within the current turn
|
||||
|
|
@ -115,7 +115,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
|
||||
// callback function invoked for each generated token.
|
||||
auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size,
|
||||
tokenizer = &model.Tokenizer(),
|
||||
tokenizer = model.Tokenizer(),
|
||||
verbosity](int token, float) {
|
||||
++abs_pos;
|
||||
++current_pos;
|
||||
|
|
@ -129,7 +129,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
}
|
||||
}
|
||||
if (verbosity >= 2) {
|
||||
std::cout << "\n[ End ]" << std::endl;
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
} else {
|
||||
std::string token_text;
|
||||
|
|
@ -142,7 +142,6 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
std::cout << std::endl << std::endl;
|
||||
}
|
||||
}
|
||||
// TODO(austinvhuang): is explicit space necessary?
|
||||
std::cout << token_text << std::flush;
|
||||
}
|
||||
return true;
|
||||
|
|
@ -191,7 +190,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
}
|
||||
}
|
||||
|
||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
|
||||
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
|
|
@ -204,8 +203,9 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
|
||||
stream_token, accept_token, gen, verbosity);
|
||||
const double time_end = hwy::platform::Now();
|
||||
const double tok_sec = current_pos / (time_end - time_start);
|
||||
if (verbosity >= 2) {
|
||||
|
|
@ -234,7 +234,10 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
||||
}
|
||||
|
||||
gcpp::Gemma model(loader, pool);
|
||||
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
|
||||
loader.ModelType(), pool);
|
||||
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
|
||||
if (const char* error = inference.Validate()) {
|
||||
ShowHelp(loader, inference, app);
|
||||
|
|
@ -272,7 +275,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
}
|
||||
|
||||
ReplGemma(
|
||||
model, pool, inner_pool, inference, app.verbosity,
|
||||
model, kv_cache, pool, inner_pool, inference, app.verbosity,
|
||||
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
||||
}
|
||||
|
||||
|
|
|
|||
153
util/app.h
153
util/app.h
|
|
@ -18,10 +18,13 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||
|
||||
#include <iterator>
|
||||
#if HWY_OS_LINUX
|
||||
#include <sched.h>
|
||||
|
||||
#include <cctype>
|
||||
#include <cerrno> // IDE does not recognize errno.h as providing errno.
|
||||
#include <string>
|
||||
#endif
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
|
@ -29,6 +32,14 @@
|
|||
#include <algorithm> // std::clamp
|
||||
#include <thread> // NOLINT>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "configs.h"
|
||||
// copybara:end
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "gemma.h"
|
||||
// copybara:end
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h"
|
||||
// copybara:end
|
||||
|
|
@ -36,6 +47,24 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
static inline const char* CompiledConfig() {
|
||||
if (HWY_IS_ASAN) {
|
||||
return "asan";
|
||||
} else if (HWY_IS_MSAN) {
|
||||
return "msan";
|
||||
} else if (HWY_IS_TSAN) {
|
||||
return "tsan";
|
||||
#if defined(HWY_IS_UBSAN)
|
||||
} else if (HWY_IS_UBSAN) {
|
||||
return "ubsan";
|
||||
#endif
|
||||
} else if (HWY_IS_DEBUG_BUILD) {
|
||||
return "dbg";
|
||||
} else {
|
||||
return "opt";
|
||||
}
|
||||
}
|
||||
|
||||
static inline void PinThreadToCore(size_t cpu_index) {
|
||||
#if HWY_OS_LINUX
|
||||
// Forces the thread to run on the logical processor with the same number.
|
||||
|
|
@ -79,9 +108,9 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(verbosity, "verbosity", 1,
|
||||
"Show verbose developer information\n 0 = only print generation "
|
||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
"developer/debug info).\n Default = 1.",
|
||||
"Show verbose developer information\n 0 = only print generation "
|
||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
"developer/debug info).\n Default = 1.",
|
||||
2);
|
||||
visitor(num_threads, "num_threads",
|
||||
kDefaultNumThreads, // see ChooseNumThreads
|
||||
|
|
@ -98,6 +127,124 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
}
|
||||
};
|
||||
|
||||
// Command-line flags that select which model to load and where its tokenizer
// and weight files live.
struct LoaderArgs : public ArgsBase<LoaderArgs> {
  LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  // Returns a copy of `text` with every character passed through
  // std::tolower. The lambda takes unsigned char because std::tolower must
  // not be given a negative value.
  static std::string ToLower(const std::string& text) {
    std::string result = text;
    std::transform(begin(result), end(result), begin(result),
                   [](unsigned char c) { return std::tolower(c); });
    return result;
  }

  // Model size chosen by --model (case-insensitive): "2b-pt" and "2b-it" give
  // GEMMA_2B, any other value gives GEMMA_7B.
  gcpp::Model ModelType() const {
    const std::string model_type_lc = ToLower(model_type);
    if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
      return gcpp::Model::GEMMA_2B;
    } else {
      return gcpp::Model::GEMMA_7B;
    }
  }

  // Training variant chosen by --model (case-insensitive): "7b-pt" and
  // "2b-pt" give GEMMA_PT, any other value gives GEMMA_IT.
  gcpp::ModelTraining ModelTraining() const {
    const std::string model_type_lc = ToLower(model_type);
    if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
      return gcpp::ModelTraining::GEMMA_PT;
    } else {
      return gcpp::ModelTraining::GEMMA_IT;
    }
  }

  // Returns error string or nullptr if OK.
  // Checks run in order: --model present, --model is a known name,
  // --tokenizer present, --compressed_weights present.
  const char* Validate() const {
    const std::string model_type_lc = ToLower(model_type);
    if (model_type.empty()) {
      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
             "2b-it, or 7b-it.";
    }
    if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
        model_type_lc != "2b-it" && model_type_lc != "7b-it") {
      return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
             "7b-it.";
    }
    if (tokenizer.path.empty()) {
      return "Missing --tokenizer flag, a file for the tokenizer is required.";
    }
    if (compressed_weights.path.empty()) {
      return "Missing --compressed_weights flag, a file for the compressed "
             "model.";
    }
    return nullptr;
  }

  Path tokenizer;
  Path weights;             // uncompressed weights file location
  Path compressed_weights;  // compressed weights file location
  std::string model_type;

  // Invokes `visitor` once per loader flag, passing the member that receives
  // the value, the flag name, a third argument that is presumably the default
  // value (confirm against ArgsBase), and the help text.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(tokenizer, "tokenizer", Path(),
            "Path name of tokenizer model file.\n Required argument.");
    visitor(
        compressed_weights, "compressed_weights", Path(),
        "Path name of compressed weights file, regenerated from `--weights` "
        "file if "
        "the compressed weights file does not exist.\n Required argument.");
    visitor(model_type, "model", std::string(),
            "Model type\n 2b-it = 2B parameters, instruction-tuned\n "
            "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
            "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
            " Required argument.");
    // Adjacent string literals are concatenated verbatim, so the trailing
    // space after "compressing" is required to keep the words apart.
    visitor(weights, "weights", Path(),
            "Path name of model weights (.sbs) file. Only required if "
            "compressed_weights file is not present and needs to be "
            "regenerated. This parameter is only required for compressing "
            "new model weight exports, otherwise it is not needed.");
  }
};
|
||||
|
||||
// Command-line flags that control generation (token limits and sampling).
struct InferenceArgs : ArgsBase<InferenceArgs> {
  InferenceArgs(int argc, char** argv) { InitAndParse(argc, argv); }

  // Returns nullptr when the token limits are consistent, otherwise a
  // message describing the first violated constraint.
  const char* Validate() const {
    const bool exceeds_seq_len = max_tokens > gcpp::kSeqLen;
    if (exceeds_seq_len) {
      return "max_tokens is larger than the maximum sequence length (see "
             "configs.h).";
    }
    const bool generation_exceeds_total = max_generated_tokens > max_tokens;
    return generation_exceeds_total
               ? "Maximum number of generated tokens is larger than the maximum "
                 "total tokens."
               : nullptr;
  }

  // Calls `visitor` once per flag with the destination member, the flag name,
  // the third argument (the flag's default, judging by the help strings) and
  // the help text; some calls pass an extra trailing integer.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    // Token budget.
    visitor(max_tokens, "max_tokens", size_t{3072},
            "Maximum number of tokens in prompt + generation.");
    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
            "Maximum number of tokens to generate.");

    // Sampling and conversation behaviour.
    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
    visitor(deterministic, "deterministic", false,
            "Make top-k sampling deterministic", 2);
    visitor(multiturn, "multiturn", false,
            "Multiturn mode\n 0 = clear KV cache after every "
            "interaction\n 1 = continue KV cache after every interaction\n "
            " Default : 0 (conversation "
            "resets every turn)");
  }

  size_t max_tokens;
  size_t max_generated_tokens;

  float temperature;
  bool deterministic;
  bool multiturn;
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||
|
|
|
|||
Loading…
Reference in New Issue