mirror of https://github.com/google/gemma.cpp.git
Prep for sharding gemma.cc: split into kv_cache, tokenizer.
Move activations.h to backprop/ to make space for another activations.h. PiperOrigin-RevId: 648744500
This commit is contained in:
parent
85fcd3cd80
commit
09a7e75ead
29
BUILD.bazel
29
BUILD.bazel
|
|
@ -94,13 +94,35 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tokenizer",
|
||||
srcs = ["gemma/tokenizer.cc"],
|
||||
hdrs = ["gemma/tokenizer.h"],
|
||||
deps = [
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark", # timer
|
||||
"@hwy//:profiler",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kv_cache",
|
||||
srcs = ["gemma/kv_cache.cc"],
|
||||
hdrs = ["gemma/kv_cache.h"],
|
||||
deps = [
|
||||
":common",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_lib",
|
||||
srcs = [
|
||||
"gemma/gemma.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/activations.h",
|
||||
"gemma/gemma.h",
|
||||
],
|
||||
exec_properties = {
|
||||
|
|
@ -114,6 +136,8 @@ cc_library(
|
|||
deps = [
|
||||
":common",
|
||||
":ops",
|
||||
":tokenizer",
|
||||
":kv_cache",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
|
|
@ -122,7 +146,6 @@ cc_library(
|
|||
"@hwy//:nanobenchmark", # timer
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -321,6 +344,7 @@ cc_library(
|
|||
"backprop/forward.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"backprop/activations.h",
|
||||
"backprop/backward.h",
|
||||
"backprop/backward-inl.h",
|
||||
"backprop/forward.h",
|
||||
|
|
@ -340,6 +364,7 @@ cc_library(
|
|||
cc_library(
|
||||
name = "backprop_scalar",
|
||||
hdrs = [
|
||||
"backprop/activations.h",
|
||||
"backprop/backward_scalar.h",
|
||||
"backprop/common_scalar.h",
|
||||
"backprop/forward_scalar.h",
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ set(SOURCES
|
|||
compression/sfp.h
|
||||
compression/sfp-inl.h
|
||||
compression/test_util.h
|
||||
backprop/activations.h
|
||||
backprop/backward.cc
|
||||
backprop/backward.h
|
||||
backprop/backward-inl.h
|
||||
|
|
@ -62,18 +63,21 @@ set(SOURCES
|
|||
backprop/optimizer.h
|
||||
evals/cross_entropy.cc
|
||||
evals/cross_entropy.h
|
||||
gemma/configs.h
|
||||
gemma/activations.h
|
||||
gemma/benchmark_helper.cc
|
||||
gemma/benchmark_helper.h
|
||||
gemma/common.cc
|
||||
gemma/common.h
|
||||
gemma/configs.h
|
||||
gemma/gemma.cc
|
||||
gemma/gemma.h
|
||||
gemma/kv_cache.cc
|
||||
gemma/kv_cache.h
|
||||
gemma/ops.h
|
||||
gemma/tokenizer.cc
|
||||
gemma/tokenizer.h
|
||||
gemma/weights_raw.h
|
||||
gemma/weights.cc
|
||||
gemma/weights.h
|
||||
gemma/weights_raw.h
|
||||
util/app.h
|
||||
util/args.h
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
|
|
@ -86,4 +86,4 @@ class ActivationsWrapper {
|
|||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||
|
|
@ -22,11 +22,11 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@
|
|||
|
||||
#include "backprop/backward.h"
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -22,9 +22,9 @@
|
|||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/weights_raw.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -26,12 +26,12 @@
|
|||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/forward_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights_raw.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -24,12 +24,12 @@
|
|||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/backward_scalar.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/forward_scalar.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights_raw.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/activations.h"
|
||||
#include "backprop/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@
|
|||
|
||||
#include "backprop/forward.h"
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@
|
|||
#include <complex>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/weights_raw.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -20,12 +20,12 @@
|
|||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/backward.h"
|
||||
#include "backprop/forward.h"
|
||||
#include "backprop/optimizer.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
||||
|
||||
#include <math.h> // sqrtf
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
|
|
@ -35,6 +36,12 @@ ByteStorageT AllocateSizeof() {
|
|||
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
||||
}
|
||||
|
||||
constexpr size_t kPrefillBatchSize = 16;
|
||||
constexpr size_t kDecodeBatchSize = 1;
|
||||
constexpr size_t kBatchedQueryBatchSize = 16;
|
||||
constexpr size_t kMinAdjustedPrefillBatchSize =
|
||||
HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize);
|
||||
|
||||
// Model variants: see configs.h for details.
|
||||
enum class Model {
|
||||
GEMMA_2B,
|
||||
|
|
@ -51,6 +58,13 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
|||
// Tensor types for loading weights.
|
||||
enum class Type { kF32, kBF16, kSFP };
|
||||
|
||||
// TODO(janwas): merge with parser/ToString.
|
||||
struct ModelInfo {
|
||||
Model model;
|
||||
ModelTraining training;
|
||||
Type weight;
|
||||
};
|
||||
|
||||
// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
|
||||
// Config* is selected via `model`. Typically called by CallForModelAndWeight,
|
||||
// but can also be called directly when FuncT does not actually use TWeight.
|
||||
|
|
|
|||
149
gemma/gemma.cc
149
gemma/gemma.cc
|
|
@ -37,7 +37,6 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility> // std::move
|
||||
#include <vector>
|
||||
|
|
@ -54,14 +53,9 @@
|
|||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
#include "src/sentencepiece_processor.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Set this to true to debug tokenizer tokens.
|
||||
constexpr bool kShowTokenization = false;
|
||||
|
||||
// Must be aligned.
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
struct Activations {
|
||||
|
|
@ -115,114 +109,22 @@ struct Activations {
|
|||
griffin_multiplier;
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
template <class TConfig>
|
||||
struct CreateKVCache {
|
||||
KVCache operator()() const {
|
||||
KVCache kv_cache = {};
|
||||
|
||||
const size_t size_cache_pos = CachePosSize<TConfig>()();
|
||||
if (size_cache_pos != 0) {
|
||||
const size_t seq_len =
|
||||
(TConfig::kSeqLen + kPrefillBatchSize);
|
||||
kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
}
|
||||
|
||||
// TODO(patrickms): Add query batching support for Griffin.
|
||||
if (TConfig::kGriffinLayers) {
|
||||
constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
const size_t conv1d_cache_size =
|
||||
TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
|
||||
TConfig::kModelDim;
|
||||
if (conv1d_cache_size != 0) {
|
||||
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
|
||||
hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
|
||||
conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
|
||||
}
|
||||
|
||||
const size_t rglru_cache_size =
|
||||
TConfig::kGriffinLayers * TConfig::kModelDim;
|
||||
if (rglru_cache_size != 0) {
|
||||
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
|
||||
hwy::ZeroBytes(kv_cache.rglru_cache.get(),
|
||||
rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
|
||||
}
|
||||
} // kGriffinLayers
|
||||
|
||||
return kv_cache;
|
||||
template <typename TConfig>
|
||||
struct AllocateState {
|
||||
void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
|
||||
// When batching queries, the prefill batch size is reduced by a factor
|
||||
// of kBatchedQueryBatchSize
|
||||
prefill =
|
||||
AllocateSizeof<Activations<TConfig, kMinAdjustedPrefillBatchSize *
|
||||
kBatchedQueryBatchSize>>();
|
||||
decode = AllocateSizeof<
|
||||
Activations<TConfig, kDecodeBatchSize * kBatchedQueryBatchSize>>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
KVCache KVCache::Create(Model model_type) {
|
||||
// TWeight=float is a placeholder and unused because CreateKVCache does not
|
||||
// use TConfig::Weight.
|
||||
return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
|
||||
}
|
||||
|
||||
class GemmaTokenizer::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
explicit Impl(const Path& tokenizer_path) {
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||
if (!spp_->Load(tokenizer_path.path).ok()) {
|
||||
HWY_ABORT("Failed to load the tokenizer file.");
|
||||
}
|
||||
}
|
||||
|
||||
bool Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
return spp_ && spp_->Encode(input, pieces).ok();
|
||||
}
|
||||
|
||||
bool Encode(const std::string& input, std::vector<int>* ids) const {
|
||||
if constexpr (kShowTokenization) {
|
||||
bool is_ok = spp_ && spp_->Encode(input, ids).ok();
|
||||
for (int i = 0; i < static_cast<int>(ids->size()); i++) {
|
||||
fprintf(stderr, "%3d: %d\n", i, (*ids)[i]);
|
||||
}
|
||||
return is_ok;
|
||||
} else {
|
||||
return spp_ && spp_->Encode(input, ids).ok();
|
||||
}
|
||||
}
|
||||
|
||||
// Given a sequence of ids, decodes it into a detokenized output.
|
||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const {
|
||||
return spp_ && spp_->Decode(ids, detokenized).ok();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_;
|
||||
};
|
||||
|
||||
GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
|
||||
impl_ = std::make_unique<Impl>(tokenizer_path);
|
||||
}
|
||||
|
||||
// Default suffices, but they must be defined after GemmaTokenizer::Impl.
|
||||
GemmaTokenizer::GemmaTokenizer() = default;
|
||||
GemmaTokenizer::~GemmaTokenizer() = default;
|
||||
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
|
||||
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
|
||||
|
||||
bool GemmaTokenizer::Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
return impl_->Encode(input, pieces);
|
||||
}
|
||||
|
||||
bool GemmaTokenizer::Encode(const std::string& input,
|
||||
std::vector<int>* ids) const {
|
||||
return impl_->Encode(input, ids);
|
||||
}
|
||||
|
||||
// Given a sequence of ids, decodes it into a detokenized output.
|
||||
bool GemmaTokenizer::Decode(const std::vector<int>& ids,
|
||||
std::string* detokenized) const {
|
||||
return impl_->Decode(ids, detokenized);
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
|
||||
return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(state_u8.get());
|
||||
}
|
||||
|
||||
// Placeholder for internal test2, do not remove
|
||||
|
|
@ -797,15 +699,9 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
|||
HWY_ASSERT(prompt_size > 0);
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
Activations<TConfig, kBatchSize>& GetActivations(
|
||||
const ByteStorageT& state_u8) {
|
||||
return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(
|
||||
state_u8.get());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// TODO(janwas): move into RuntimeConfig
|
||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob,
|
||||
const RuntimeConfig& runtime_config) {
|
||||
if (runtime_config.batch_stream_token) {
|
||||
|
|
@ -1069,22 +965,6 @@ HWY_AFTER_NAMESPACE();
|
|||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
namespace {
|
||||
template <typename TConfig>
|
||||
struct AllocateState {
|
||||
void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
|
||||
// When batching queries, the prefill batch size is reduced by a factor
|
||||
// of kBatchedQueryBatchSize
|
||||
prefill = AllocateSizeof<
|
||||
Activations<TConfig,
|
||||
kMinAdjustedPrefillBatchSize * kBatchedQueryBatchSize>>();
|
||||
decode = AllocateSizeof<
|
||||
Activations<TConfig, kDecodeBatchSize * kBatchedQueryBatchSize>>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
|
||||
const ModelInfo& info, hwy::ThreadPool& pool)
|
||||
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
|
||||
|
|
@ -1136,6 +1016,7 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
|||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
// TODO(janwas): move to common.h.
|
||||
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
|
||||
|
||||
// Instruction-tuned models are trained to expect control tokens.
|
||||
|
|
|
|||
|
|
@ -17,59 +17,22 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
constexpr size_t kPrefillBatchSize = 16;
|
||||
constexpr size_t kDecodeBatchSize = 1;
|
||||
constexpr size_t kBatchedQueryBatchSize = 16;
|
||||
constexpr size_t kMinAdjustedPrefillBatchSize =
|
||||
HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize);
|
||||
|
||||
struct KVCache {
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
kv_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
rglru_cache; // kModelDim * kGriffinLayers
|
||||
|
||||
static KVCache Create(Model type);
|
||||
};
|
||||
|
||||
// The tokenizer's end of sentence and beginning of sentence token ids.
|
||||
constexpr int EOS_ID = 1;
|
||||
constexpr int BOS_ID = 2;
|
||||
|
||||
class GemmaTokenizer {
|
||||
public:
|
||||
GemmaTokenizer();
|
||||
explicit GemmaTokenizer(const Path& tokenizer_path);
|
||||
|
||||
// must come after definition of Impl
|
||||
~GemmaTokenizer();
|
||||
GemmaTokenizer(GemmaTokenizer&& other);
|
||||
GemmaTokenizer& operator=(GemmaTokenizer&& other);
|
||||
|
||||
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
|
||||
bool Encode(const std::string& input, std::vector<int>* ids) const;
|
||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
// true to continue generation.
|
||||
|
|
@ -93,13 +56,6 @@ using SampleFunc = std::function<int(const float*, size_t)>;
|
|||
using LayersOutputFunc =
|
||||
std::function<void(int, const std::string&, const float*, size_t)>;
|
||||
|
||||
// TODO(janwas): move into common.h, merge with parser/ToString.
|
||||
struct ModelInfo {
|
||||
Model model;
|
||||
ModelTraining training;
|
||||
Type weight;
|
||||
};
|
||||
|
||||
struct RuntimeConfig {
|
||||
size_t max_tokens;
|
||||
size_t max_generated_tokens;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/kv_cache.h"
|
||||
|
||||
#include "gemma/common.h" // CallForModel
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // ZeroBytes
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
template <class TConfig>
|
||||
struct CreateKVCache {
|
||||
KVCache operator()() const {
|
||||
KVCache kv_cache = {};
|
||||
|
||||
const size_t size_cache_pos = CachePosSize<TConfig>()();
|
||||
if (size_cache_pos != 0) {
|
||||
const size_t seq_len = (TConfig::kSeqLen + kPrefillBatchSize);
|
||||
kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
}
|
||||
|
||||
// TODO(patrickms): Add query batching support for Griffin.
|
||||
if (TConfig::kGriffinLayers) {
|
||||
constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
const size_t conv1d_cache_size =
|
||||
TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
|
||||
TConfig::kModelDim;
|
||||
if (conv1d_cache_size != 0) {
|
||||
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
|
||||
hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
|
||||
conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
|
||||
}
|
||||
|
||||
const size_t rglru_cache_size =
|
||||
TConfig::kGriffinLayers * TConfig::kModelDim;
|
||||
if (rglru_cache_size != 0) {
|
||||
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
|
||||
hwy::ZeroBytes(kv_cache.rglru_cache.get(),
|
||||
rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
|
||||
}
|
||||
} // kGriffinLayers
|
||||
|
||||
return kv_cache;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
KVCache KVCache::Create(Model model_type) {
|
||||
// TWeight=float is a placeholder and unused because CreateKVCache does not
|
||||
// use TConfig::Weight.
|
||||
return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
|
||||
|
||||
#include "gemma/common.h" // Model
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
struct KVCache {
|
||||
// kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
|
||||
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
|
||||
|
||||
// (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
||||
hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
|
||||
|
||||
// kModelDim * kGriffinLayers
|
||||
hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
|
||||
|
||||
static KVCache Create(Model type);
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/tokenizer.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/profiler.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
#include "src/sentencepiece_processor.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Set this to true to debug tokenizer tokens.
|
||||
constexpr bool kShowTokenization = false;
|
||||
|
||||
class GemmaTokenizer::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
explicit Impl(const Path& tokenizer_path) {
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||
if (!spp_->Load(tokenizer_path.path).ok()) {
|
||||
HWY_ABORT("Failed to load the tokenizer file.");
|
||||
}
|
||||
}
|
||||
|
||||
bool Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
return spp_ && spp_->Encode(input, pieces).ok();
|
||||
}
|
||||
|
||||
bool Encode(const std::string& input, std::vector<int>* ids) const {
|
||||
if constexpr (kShowTokenization) {
|
||||
bool is_ok = spp_ && spp_->Encode(input, ids).ok();
|
||||
for (int i = 0; i < static_cast<int>(ids->size()); i++) {
|
||||
fprintf(stderr, "%3d: %d\n", i, (*ids)[i]);
|
||||
}
|
||||
return is_ok;
|
||||
} else {
|
||||
return spp_ && spp_->Encode(input, ids).ok();
|
||||
}
|
||||
}
|
||||
|
||||
// Given a sequence of ids, decodes it into a detokenized output.
|
||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const {
|
||||
return spp_ && spp_->Decode(ids, detokenized).ok();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_;
|
||||
};
|
||||
|
||||
GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
|
||||
impl_ = std::make_unique<Impl>(tokenizer_path);
|
||||
}
|
||||
|
||||
// Default suffices, but they must be defined after GemmaTokenizer::Impl.
|
||||
GemmaTokenizer::GemmaTokenizer() = default;
|
||||
GemmaTokenizer::~GemmaTokenizer() = default;
|
||||
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
|
||||
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
|
||||
|
||||
bool GemmaTokenizer::Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
return impl_->Encode(input, pieces);
|
||||
}
|
||||
|
||||
bool GemmaTokenizer::Encode(const std::string& input,
|
||||
std::vector<int>* ids) const {
|
||||
return impl_->Encode(input, ids);
|
||||
}
|
||||
|
||||
// Given a sequence of ids, decodes it into a detokenized output.
|
||||
bool GemmaTokenizer::Decode(const std::vector<int>& ids,
|
||||
std::string* detokenized) const {
|
||||
return impl_->Decode(ids, detokenized);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// The tokenizer's end of sentence and beginning of sentence token ids.
|
||||
constexpr int EOS_ID = 1;
|
||||
constexpr int BOS_ID = 2;
|
||||
|
||||
class GemmaTokenizer {
|
||||
public:
|
||||
GemmaTokenizer();
|
||||
explicit GemmaTokenizer(const Path& tokenizer_path);
|
||||
|
||||
// must come after definition of Impl
|
||||
~GemmaTokenizer();
|
||||
GemmaTokenizer(GemmaTokenizer&& other);
|
||||
GemmaTokenizer& operator=(GemmaTokenizer&& other);
|
||||
|
||||
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
|
||||
bool Encode(const std::string& input, std::vector<int>* ids) const;
|
||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
|
||||
Loading…
Reference in New Issue