Prep for sharding gemma.cc: split into kv_cache, tokenizer.

Move activations.h to backprop/ to make space for another activations.h.

PiperOrigin-RevId: 648744500
This commit is contained in:
Jan Wassenberg 2024-07-02 09:30:24 -07:00 committed by Copybara-Service
parent 85fcd3cd80
commit 09a7e75ead
19 changed files with 337 additions and 201 deletions

View File

@ -94,13 +94,35 @@ cc_library(
], ],
) )
cc_library(
name = "tokenizer",
srcs = ["gemma/tokenizer.cc"],
hdrs = ["gemma/tokenizer.h"],
deps = [
"//compression:io",
"@hwy//:hwy",
"@hwy//:nanobenchmark", # timer
"@hwy//:profiler",
"@com_google_sentencepiece//:sentencepiece_processor",
],
)
cc_library(
name = "kv_cache",
srcs = ["gemma/kv_cache.cc"],
hdrs = ["gemma/kv_cache.h"],
deps = [
":common",
"@hwy//:hwy",
],
)
cc_library( cc_library(
name = "gemma_lib", name = "gemma_lib",
srcs = [ srcs = [
"gemma/gemma.cc", "gemma/gemma.cc",
], ],
hdrs = [ hdrs = [
"gemma/activations.h",
"gemma/gemma.h", "gemma/gemma.h",
], ],
exec_properties = { exec_properties = {
@ -114,6 +136,8 @@ cc_library(
deps = [ deps = [
":common", ":common",
":ops", ":ops",
":tokenizer",
":kv_cache",
":weights", ":weights",
"//compression:compress", "//compression:compress",
"//compression:io", "//compression:io",
@ -122,7 +146,6 @@ cc_library(
"@hwy//:nanobenchmark", # timer "@hwy//:nanobenchmark", # timer
"@hwy//:profiler", "@hwy//:profiler",
"@hwy//:thread_pool", "@hwy//:thread_pool",
"@com_google_sentencepiece//:sentencepiece_processor",
], ],
) )
@ -321,6 +344,7 @@ cc_library(
"backprop/forward.cc", "backprop/forward.cc",
], ],
hdrs = [ hdrs = [
"backprop/activations.h",
"backprop/backward.h", "backprop/backward.h",
"backprop/backward-inl.h", "backprop/backward-inl.h",
"backprop/forward.h", "backprop/forward.h",
@ -340,6 +364,7 @@ cc_library(
cc_library( cc_library(
name = "backprop_scalar", name = "backprop_scalar",
hdrs = [ hdrs = [
"backprop/activations.h",
"backprop/backward_scalar.h", "backprop/backward_scalar.h",
"backprop/common_scalar.h", "backprop/common_scalar.h",
"backprop/forward_scalar.h", "backprop/forward_scalar.h",

View File

@ -49,6 +49,7 @@ set(SOURCES
compression/sfp.h compression/sfp.h
compression/sfp-inl.h compression/sfp-inl.h
compression/test_util.h compression/test_util.h
backprop/activations.h
backprop/backward.cc backprop/backward.cc
backprop/backward.h backprop/backward.h
backprop/backward-inl.h backprop/backward-inl.h
@ -62,18 +63,21 @@ set(SOURCES
backprop/optimizer.h backprop/optimizer.h
evals/cross_entropy.cc evals/cross_entropy.cc
evals/cross_entropy.h evals/cross_entropy.h
gemma/configs.h
gemma/activations.h
gemma/benchmark_helper.cc gemma/benchmark_helper.cc
gemma/benchmark_helper.h gemma/benchmark_helper.h
gemma/common.cc gemma/common.cc
gemma/common.h gemma/common.h
gemma/configs.h
gemma/gemma.cc gemma/gemma.cc
gemma/gemma.h gemma/gemma.h
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/ops.h gemma/ops.h
gemma/tokenizer.cc
gemma/tokenizer.h
gemma/weights_raw.h
gemma/weights.cc gemma/weights.cc
gemma/weights.h gemma/weights.h
gemma/weights_raw.h
util/app.h util/app.h
util/args.h util/args.h
) )

View File

@ -13,8 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ #ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ #define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
#include <stddef.h> #include <stddef.h>
@ -86,4 +86,4 @@ class ActivationsWrapper {
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ #endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_

View File

@ -22,11 +22,11 @@
#include <stddef.h> #include <stddef.h>
#include <algorithm>
#include <cmath> #include <cmath>
#include <vector>
#include "backprop/activations.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"

View File

@ -15,8 +15,8 @@
#include "backprop/backward.h" #include "backprop/backward.h"
#include "backprop/activations.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"

View File

@ -22,9 +22,9 @@
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include "backprop/activations.h"
#include "backprop/common_scalar.h" #include "backprop/common_scalar.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" // EmbeddingScaling #include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights_raw.h" #include "gemma/weights_raw.h"

View File

@ -26,12 +26,12 @@
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "backprop/activations.h"
#include "backprop/common_scalar.h" #include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h" #include "backprop/forward_scalar.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "backprop/sampler.h" #include "backprop/sampler.h"
#include "backprop/test_util.h" #include "backprop/test_util.h"
#include "gemma/activations.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/weights_raw.h" #include "gemma/weights_raw.h"

View File

@ -24,12 +24,12 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "backprop/activations.h"
#include "backprop/backward_scalar.h" #include "backprop/backward_scalar.h"
#include "backprop/common_scalar.h" #include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h" #include "backprop/forward_scalar.h"
#include "backprop/sampler.h" #include "backprop/sampler.h"
#include "backprop/test_util.h" #include "backprop/test_util.h"
#include "gemma/activations.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/weights_raw.h" #include "gemma/weights_raw.h"
#include "hwy/base.h" #include "hwy/base.h"

View File

@ -23,7 +23,7 @@
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include "gemma/activations.h" #include "backprop/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "hwy/base.h" #include "hwy/base.h"

View File

@ -15,8 +15,8 @@
#include "backprop/forward.h" #include "backprop/forward.h"
#include "backprop/activations.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"

View File

@ -23,9 +23,9 @@
#include <complex> #include <complex>
#include <vector> #include <vector>
#include "backprop/activations.h"
#include "backprop/common_scalar.h" #include "backprop/common_scalar.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" // EmbeddingScaling #include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights_raw.h" #include "gemma/weights_raw.h"

View File

@ -20,12 +20,12 @@
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "backprop/activations.h"
#include "backprop/backward.h" #include "backprop/backward.h"
#include "backprop/forward.h" #include "backprop/forward.h"
#include "backprop/optimizer.h" #include "backprop/optimizer.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "backprop/sampler.h" #include "backprop/sampler.h"
#include "gemma/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/weights.h" #include "gemma/weights.h"

View File

@ -17,6 +17,7 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#include <math.h> // sqrtf #include <math.h> // sqrtf
#include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <string> #include <string>
@ -35,6 +36,12 @@ ByteStorageT AllocateSizeof() {
return hwy::AllocateAligned<uint8_t>(sizeof(T)); return hwy::AllocateAligned<uint8_t>(sizeof(T));
} }
// Batch sizes used to size the Activations buffers (see gemma.cc).
constexpr size_t kPrefillBatchSize = 16;
constexpr size_t kDecodeBatchSize = 1;
// Number of queries processed together when batching queries.
constexpr size_t kBatchedQueryBatchSize = 16;
// When batching queries, the per-query prefill batch size is reduced by a
// factor of kBatchedQueryBatchSize, clamped to at least 1.
constexpr size_t kMinAdjustedPrefillBatchSize =
    HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize);
// Model variants: see configs.h for details. // Model variants: see configs.h for details.
enum class Model { enum class Model {
GEMMA_2B, GEMMA_2B,
@ -51,6 +58,13 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };
// Tensor types for loading weights. // Tensor types for loading weights.
enum class Type { kF32, kBF16, kSFP }; enum class Type { kF32, kBF16, kSFP };
// Bundle of user-selected model options passed to the Gemma constructor.
// TODO(janwas): merge with parser/ToString.
struct ModelInfo {
  Model model;             // model variant; see configs.h for details.
  ModelTraining training;  // instruction-tuned (GEMMA_IT) vs. pretrained (GEMMA_PT).
  Type weight;             // weight tensor type: kF32, kBF16 or kSFP.
};
// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where // Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
// Config* is selected via `model`. Typically called by CallForModelAndWeight, // Config* is selected via `model`. Typically called by CallForModelAndWeight,
// but can also be called directly when FuncT does not actually use TWeight. // but can also be called directly when FuncT does not actually use TWeight.

View File

@ -37,7 +37,6 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <memory>
#include <string> #include <string>
#include <utility> // std::move #include <utility> // std::move
#include <vector> #include <vector>
@ -54,14 +53,9 @@
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/timer.h" #include "hwy/timer.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
namespace gcpp { namespace gcpp {
// Set this to true to debug tokenizer tokens.
constexpr bool kShowTokenization = false;
// Must be aligned. // Must be aligned.
template <class TConfig, size_t kBatchSize> template <class TConfig, size_t kBatchSize>
struct Activations { struct Activations {
@ -115,114 +109,22 @@ struct Activations {
griffin_multiplier; griffin_multiplier;
}; };
namespace { template <typename TConfig>
struct AllocateState {
template <class TConfig> void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
struct CreateKVCache { // When batching queries, the prefill batch size is reduced by a factor
KVCache operator()() const { // of kBatchedQueryBatchSize
KVCache kv_cache = {}; prefill =
AllocateSizeof<Activations<TConfig, kMinAdjustedPrefillBatchSize *
const size_t size_cache_pos = CachePosSize<TConfig>()(); kBatchedQueryBatchSize>>();
if (size_cache_pos != 0) { decode = AllocateSizeof<
const size_t seq_len = Activations<TConfig, kDecodeBatchSize * kBatchedQueryBatchSize>>();
(TConfig::kSeqLen + kPrefillBatchSize);
kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
}
// TODO(patrickms): Add query batching support for Griffin.
if (TConfig::kGriffinLayers) {
constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
const size_t conv1d_cache_size =
TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
TConfig::kModelDim;
if (conv1d_cache_size != 0) {
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
}
const size_t rglru_cache_size =
TConfig::kGriffinLayers * TConfig::kModelDim;
if (rglru_cache_size != 0) {
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
hwy::ZeroBytes(kv_cache.rglru_cache.get(),
rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
}
} // kGriffinLayers
return kv_cache;
} }
}; };
} // namespace template <class TConfig, size_t kBatchSize>
Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
KVCache KVCache::Create(Model model_type) { return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(state_u8.get());
// TWeight=float is a placeholder and unused because CreateKVCache does not
// use TConfig::Weight.
return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
}
class GemmaTokenizer::Impl {
public:
Impl() = default;
explicit Impl(const Path& tokenizer_path) {
PROFILER_ZONE("Startup.tokenizer");
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
if (!spp_->Load(tokenizer_path.path).ok()) {
HWY_ABORT("Failed to load the tokenizer file.");
}
}
bool Encode(const std::string& input,
std::vector<std::string>* pieces) const {
return spp_ && spp_->Encode(input, pieces).ok();
}
bool Encode(const std::string& input, std::vector<int>* ids) const {
if constexpr (kShowTokenization) {
bool is_ok = spp_ && spp_->Encode(input, ids).ok();
for (int i = 0; i < static_cast<int>(ids->size()); i++) {
fprintf(stderr, "%3d: %d\n", i, (*ids)[i]);
}
return is_ok;
} else {
return spp_ && spp_->Encode(input, ids).ok();
}
}
// Given a sequence of ids, decodes it into a detokenized output.
bool Decode(const std::vector<int>& ids, std::string* detokenized) const {
return spp_ && spp_->Decode(ids, detokenized).ok();
}
private:
std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_;
};
GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
impl_ = std::make_unique<Impl>(tokenizer_path);
}
// Default suffices, but they must be defined after GemmaTokenizer::Impl.
GemmaTokenizer::GemmaTokenizer() = default;
GemmaTokenizer::~GemmaTokenizer() = default;
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
bool GemmaTokenizer::Encode(const std::string& input,
std::vector<std::string>* pieces) const {
return impl_->Encode(input, pieces);
}
bool GemmaTokenizer::Encode(const std::string& input,
std::vector<int>* ids) const {
return impl_->Encode(input, ids);
}
// Given a sequence of ids, decodes it into a detokenized output.
bool GemmaTokenizer::Decode(const std::vector<int>& ids,
std::string* detokenized) const {
return impl_->Decode(ids, detokenized);
} }
// Placeholder for internal test2, do not remove // Placeholder for internal test2, do not remove
@ -797,15 +699,9 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
HWY_ASSERT(prompt_size > 0); HWY_ASSERT(prompt_size > 0);
} }
template <class TConfig, size_t kBatchSize>
Activations<TConfig, kBatchSize>& GetActivations(
const ByteStorageT& state_u8) {
return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(
state_u8.get());
}
} // namespace } // namespace
// TODO(janwas): move into RuntimeConfig
bool StreamToken(size_t query_idx, size_t pos, int token, float prob, bool StreamToken(size_t query_idx, size_t pos, int token, float prob,
const RuntimeConfig& runtime_config) { const RuntimeConfig& runtime_config) {
if (runtime_config.batch_stream_token) { if (runtime_config.batch_stream_token) {
@ -1069,22 +965,6 @@ HWY_AFTER_NAMESPACE();
#if HWY_ONCE #if HWY_ONCE
namespace gcpp { namespace gcpp {
namespace {
template <typename TConfig>
struct AllocateState {
void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
// When batching queries, the prefill batch size is reduced by a factor
// of kBatchedQueryBatchSize
prefill = AllocateSizeof<
Activations<TConfig,
kMinAdjustedPrefillBatchSize * kBatchedQueryBatchSize>>();
decode = AllocateSizeof<
Activations<TConfig, kDecodeBatchSize * kBatchedQueryBatchSize>>();
}
};
} // namespace
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, hwy::ThreadPool& pool) const ModelInfo& info, hwy::ThreadPool& pool)
: pool_(pool), tokenizer_(tokenizer_path), info_(info) { : pool_(pool), tokenizer_(tokenizer_path), info_(info) {
@ -1136,6 +1016,7 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
} }
// TODO(janwas): move to common.h.
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
// Instruction-tuned models are trained to expect control tokens. // Instruction-tuned models are trained to expect control tokens.

View File

@ -17,59 +17,22 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#include <functional> #include <functional>
#include <memory>
#include <random> #include <random>
#include <string> #include <string>
#include <vector> #include <vector>
// IWYU pragma: begin_exports
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t #include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
constexpr size_t kPrefillBatchSize = 16;
constexpr size_t kDecodeBatchSize = 1;
constexpr size_t kBatchedQueryBatchSize = 16;
constexpr size_t kMinAdjustedPrefillBatchSize =
HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize);
struct KVCache {
hwy::AlignedFreeUniquePtr<float[]>
kv_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
hwy::AlignedFreeUniquePtr<float[]>
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]>
rglru_cache; // kModelDim * kGriffinLayers
static KVCache Create(Model type);
};
// The tokenizer's end of sentence and beginning of sentence token ids.
constexpr int EOS_ID = 1;
constexpr int BOS_ID = 2;
class GemmaTokenizer {
public:
GemmaTokenizer();
explicit GemmaTokenizer(const Path& tokenizer_path);
// must come after definition of Impl
~GemmaTokenizer();
GemmaTokenizer(GemmaTokenizer&& other);
GemmaTokenizer& operator=(GemmaTokenizer&& other);
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
bool Encode(const std::string& input, std::vector<int>* ids) const;
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
// StreamFunc is called with (token, probability). For prompt tokens, // StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return false to stop generation and // probability is 0.0f. StreamFunc should return false to stop generation and
// true to continue generation. // true to continue generation.
@ -93,13 +56,6 @@ using SampleFunc = std::function<int(const float*, size_t)>;
using LayersOutputFunc = using LayersOutputFunc =
std::function<void(int, const std::string&, const float*, size_t)>; std::function<void(int, const std::string&, const float*, size_t)>;
// TODO(janwas): move into common.h, merge with parser/ToString.
struct ModelInfo {
Model model;
ModelTraining training;
Type weight;
};
struct RuntimeConfig { struct RuntimeConfig {
size_t max_tokens; size_t max_tokens;
size_t max_generated_tokens; size_t max_generated_tokens;

67
gemma/kv_cache.cc Normal file
View File

@ -0,0 +1,67 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/kv_cache.h"
#include "gemma/common.h" // CallForModel
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ZeroBytes
namespace gcpp {
namespace {
// Functor invoked via CallForModel: allocates the caches required by the
// model configuration TConfig and returns them by value.
template <class TConfig>
struct CreateKVCache {
  KVCache operator()() const {
    KVCache kv_cache = {};

    // Attention KV cache. Skipped entirely when the per-position size is
    // zero -- presumably for configs with no attention KV state; confirm
    // against CachePosSize in common.h.
    const size_t size_cache_pos = CachePosSize<TConfig>()();
    if (size_cache_pos != 0) {
      // kPrefillBatchSize extra positions beyond kSeqLen -- presumably so
      // prefill can write a full batch past the end without bounds checks;
      // TODO confirm. NOTE(review): this buffer is left uninitialized,
      // unlike the Griffin caches below.
      const size_t seq_len = (TConfig::kSeqLen + kPrefillBatchSize);
      kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
    }

    // TODO(patrickms): Add query batching support for Griffin.
    if (TConfig::kGriffinLayers) {
      // Conv1d history: kConv1dWidth - 1 previous kModelDim vectors per
      // Griffin layer; zero-initialized so the first step sees empty history.
      constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
      const size_t conv1d_cache_size =
          TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
          TConfig::kModelDim;
      if (conv1d_cache_size != 0) {
        kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
        hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
                       conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
      }

      // Recurrent state: one kModelDim vector per Griffin layer, also
      // zero-initialized.
      const size_t rglru_cache_size =
          TConfig::kGriffinLayers * TConfig::kModelDim;
      if (rglru_cache_size != 0) {
        kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
        hwy::ZeroBytes(kv_cache.rglru_cache.get(),
                       rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
      }
    }  // kGriffinLayers

    return kv_cache;
  }
};
} // namespace
// Dispatches on the runtime `model_type` enum to instantiate CreateKVCache
// with the matching compile-time config.
KVCache KVCache::Create(Model model_type) {
  // TWeight=float is a placeholder and unused because CreateKVCache does not
  // use TConfig::Weight.
  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
}
} // namespace gcpp

39
gemma/kv_cache.h Normal file
View File

@ -0,0 +1,39 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
#include "gemma/common.h" // Model
#include "hwy/aligned_allocator.h"
namespace gcpp {
// Per-query state carried across generation steps: the attention KV cache
// plus the Griffin recurrent-layer caches. Griffin members remain null for
// configs without Griffin layers (see kv_cache.cc).
struct KVCache {
  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
  hwy::AlignedFreeUniquePtr<float[]> kv_cache;
  // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
  hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
  // kModelDim * kGriffinLayers
  hwy::AlignedFreeUniquePtr<float[]> rglru_cache;

  // Allocates caches sized for the given model `type`.
  static KVCache Create(Model type);
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_

98
gemma/tokenizer.cc Normal file
View File

@ -0,0 +1,98 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/tokenizer.h"
#include <stdio.h>
#include <memory>
#include <string>
#include <vector>
#include "compression/io.h" // Path
#include "hwy/base.h"
#include "hwy/profiler.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
namespace gcpp {
// Set this to true to debug tokenizer tokens.
constexpr bool kShowTokenization = false;
// Pimpl body wrapping a SentencePiece processor. All methods tolerate a
// null processor (default-constructed Impl) by returning false.
class GemmaTokenizer::Impl {
 public:
  Impl() = default;

  // Loads the SentencePiece model from `tokenizer_path`; aborts on failure.
  explicit Impl(const Path& tokenizer_path) {
    PROFILER_ZONE("Startup.tokenizer");
    spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
    if (!spp_->Load(tokenizer_path.path).ok()) {
      HWY_ABORT("Failed to load the tokenizer file.");
    }
  }

  // Splits `input` into string pieces; returns true on success.
  bool Encode(const std::string& input,
              std::vector<std::string>* pieces) const {
    return spp_ && spp_->Encode(input, pieces).ok();
  }

  // Splits `input` into token ids; returns true on success. When
  // kShowTokenization is set, also dumps each id to stderr for debugging.
  bool Encode(const std::string& input, std::vector<int>* ids) const {
    const bool is_ok = spp_ && spp_->Encode(input, ids).ok();
    if constexpr (kShowTokenization) {
      for (int i = 0; i < static_cast<int>(ids->size()); i++) {
        fprintf(stderr, "%3d: %d\n", i, (*ids)[i]);
      }
    }
    return is_ok;
  }

  // Converts a sequence of token ids back into text; returns true on success.
  bool Decode(const std::vector<int>& ids, std::string* detokenized) const {
    return spp_ && spp_->Decode(ids, detokenized).ok();
  }

 private:
  std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_;
};
// Loads the tokenizer model from `tokenizer_path` (aborts on failure).
GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
  impl_ = std::make_unique<Impl>(tokenizer_path);
}

// Defaults suffice, but they must be defined here, after GemmaTokenizer::Impl
// is complete, because std::unique_ptr<Impl> requires the full type to
// destroy it.
GemmaTokenizer::GemmaTokenizer() = default;
GemmaTokenizer::~GemmaTokenizer() = default;
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;

// Forwards to Impl; splits `input` into string pieces.
bool GemmaTokenizer::Encode(const std::string& input,
                            std::vector<std::string>* pieces) const {
  return impl_->Encode(input, pieces);
}

// Forwards to Impl; splits `input` into token ids.
bool GemmaTokenizer::Encode(const std::string& input,
                            std::vector<int>* ids) const {
  return impl_->Encode(input, ids);
}

// Given a sequence of ids, decodes it into a detokenized output.
bool GemmaTokenizer::Decode(const std::vector<int>& ids,
                            std::string* detokenized) const {
  return impl_->Decode(ids, detokenized);
}
} // namespace gcpp

52
gemma/tokenizer.h Normal file
View File

@ -0,0 +1,52 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
#include <memory>
#include <string>
#include <vector>
#include "compression/io.h" // Path
namespace gcpp {
// The tokenizer's end of sentence and beginning of sentence token ids.
constexpr int EOS_ID = 1;
constexpr int BOS_ID = 2;
// Pimpl wrapper around the SentencePiece tokenizer, keeping the sentencepiece
// headers out of this public header. Move-only (Impl is held by unique_ptr).
class GemmaTokenizer {
 public:
  // Constructs without a loaded model. NOTE(review): Encode/Decode on a
  // default-constructed instance dereference a null impl_ -- callers
  // presumably move-assign a loaded tokenizer first; confirm.
  GemmaTokenizer();
  // Loads the tokenizer model file at `tokenizer_path`.
  explicit GemmaTokenizer(const Path& tokenizer_path);

  // Must come after definition of Impl.
  ~GemmaTokenizer();
  GemmaTokenizer(GemmaTokenizer&& other);
  GemmaTokenizer& operator=(GemmaTokenizer&& other);

  // Each returns true on success, false if tokenization fails.
  bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
  bool Encode(const std::string& input, std::vector<int>* ids) const;
  // Decodes a sequence of token ids back into text.
  bool Decode(const std::vector<int>& ids, std::string* detokenized) const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_