mirror of https://github.com/google/gemma.cpp.git
Prep for sharding gemma.cc: split into kv_cache, tokenizer.
Move activations.h to backprop/ to make space for another activations.h. PiperOrigin-RevId: 648744500
This commit is contained in:
parent
85fcd3cd80
commit
09a7e75ead
29
BUILD.bazel
29
BUILD.bazel
|
|
@ -94,13 +94,35 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tokenizer",
|
||||||
|
srcs = ["gemma/tokenizer.cc"],
|
||||||
|
hdrs = ["gemma/tokenizer.h"],
|
||||||
|
deps = [
|
||||||
|
"//compression:io",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:nanobenchmark", # timer
|
||||||
|
"@hwy//:profiler",
|
||||||
|
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "kv_cache",
|
||||||
|
srcs = ["gemma/kv_cache.cc"],
|
||||||
|
hdrs = ["gemma/kv_cache.h"],
|
||||||
|
deps = [
|
||||||
|
":common",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gemma_lib",
|
name = "gemma_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
"gemma/gemma.cc",
|
"gemma/gemma.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"gemma/activations.h",
|
|
||||||
"gemma/gemma.h",
|
"gemma/gemma.h",
|
||||||
],
|
],
|
||||||
exec_properties = {
|
exec_properties = {
|
||||||
|
|
@ -114,6 +136,8 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
":ops",
|
":ops",
|
||||||
|
":tokenizer",
|
||||||
|
":kv_cache",
|
||||||
":weights",
|
":weights",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
|
|
@ -122,7 +146,6 @@ cc_library(
|
||||||
"@hwy//:nanobenchmark", # timer
|
"@hwy//:nanobenchmark", # timer
|
||||||
"@hwy//:profiler",
|
"@hwy//:profiler",
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -321,6 +344,7 @@ cc_library(
|
||||||
"backprop/forward.cc",
|
"backprop/forward.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
"backprop/activations.h",
|
||||||
"backprop/backward.h",
|
"backprop/backward.h",
|
||||||
"backprop/backward-inl.h",
|
"backprop/backward-inl.h",
|
||||||
"backprop/forward.h",
|
"backprop/forward.h",
|
||||||
|
|
@ -340,6 +364,7 @@ cc_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "backprop_scalar",
|
name = "backprop_scalar",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
"backprop/activations.h",
|
||||||
"backprop/backward_scalar.h",
|
"backprop/backward_scalar.h",
|
||||||
"backprop/common_scalar.h",
|
"backprop/common_scalar.h",
|
||||||
"backprop/forward_scalar.h",
|
"backprop/forward_scalar.h",
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ set(SOURCES
|
||||||
compression/sfp.h
|
compression/sfp.h
|
||||||
compression/sfp-inl.h
|
compression/sfp-inl.h
|
||||||
compression/test_util.h
|
compression/test_util.h
|
||||||
|
backprop/activations.h
|
||||||
backprop/backward.cc
|
backprop/backward.cc
|
||||||
backprop/backward.h
|
backprop/backward.h
|
||||||
backprop/backward-inl.h
|
backprop/backward-inl.h
|
||||||
|
|
@ -62,18 +63,21 @@ set(SOURCES
|
||||||
backprop/optimizer.h
|
backprop/optimizer.h
|
||||||
evals/cross_entropy.cc
|
evals/cross_entropy.cc
|
||||||
evals/cross_entropy.h
|
evals/cross_entropy.h
|
||||||
gemma/configs.h
|
|
||||||
gemma/activations.h
|
|
||||||
gemma/benchmark_helper.cc
|
gemma/benchmark_helper.cc
|
||||||
gemma/benchmark_helper.h
|
gemma/benchmark_helper.h
|
||||||
gemma/common.cc
|
gemma/common.cc
|
||||||
gemma/common.h
|
gemma/common.h
|
||||||
|
gemma/configs.h
|
||||||
gemma/gemma.cc
|
gemma/gemma.cc
|
||||||
gemma/gemma.h
|
gemma/gemma.h
|
||||||
|
gemma/kv_cache.cc
|
||||||
|
gemma/kv_cache.h
|
||||||
gemma/ops.h
|
gemma/ops.h
|
||||||
|
gemma/tokenizer.cc
|
||||||
|
gemma/tokenizer.h
|
||||||
|
gemma/weights_raw.h
|
||||||
gemma/weights.cc
|
gemma/weights.cc
|
||||||
gemma/weights.h
|
gemma/weights.h
|
||||||
gemma/weights_raw.h
|
|
||||||
util/app.h
|
util/app.h
|
||||||
util/args.h
|
util/args.h
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
#define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
|
@ -86,4 +86,4 @@ class ActivationsWrapper {
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||||
|
|
@ -22,11 +22,11 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "backprop/activations.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "gemma/activations.h"
|
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,8 @@
|
||||||
|
|
||||||
#include "backprop/backward.h"
|
#include "backprop/backward.h"
|
||||||
|
|
||||||
|
#include "backprop/activations.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "gemma/activations.h"
|
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,9 @@
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "backprop/activations.h"
|
||||||
#include "backprop/common_scalar.h"
|
#include "backprop/common_scalar.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "gemma/activations.h"
|
|
||||||
#include "gemma/common.h" // EmbeddingScaling
|
#include "gemma/common.h" // EmbeddingScaling
|
||||||
#include "gemma/weights_raw.h"
|
#include "gemma/weights_raw.h"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,12 +26,12 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
#include "backprop/activations.h"
|
||||||
#include "backprop/common_scalar.h"
|
#include "backprop/common_scalar.h"
|
||||||
#include "backprop/forward_scalar.h"
|
#include "backprop/forward_scalar.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "backprop/sampler.h"
|
#include "backprop/sampler.h"
|
||||||
#include "backprop/test_util.h"
|
#include "backprop/test_util.h"
|
||||||
#include "gemma/activations.h"
|
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/weights_raw.h"
|
#include "gemma/weights_raw.h"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,12 +24,12 @@
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "backprop/activations.h"
|
||||||
#include "backprop/backward_scalar.h"
|
#include "backprop/backward_scalar.h"
|
||||||
#include "backprop/common_scalar.h"
|
#include "backprop/common_scalar.h"
|
||||||
#include "backprop/forward_scalar.h"
|
#include "backprop/forward_scalar.h"
|
||||||
#include "backprop/sampler.h"
|
#include "backprop/sampler.h"
|
||||||
#include "backprop/test_util.h"
|
#include "backprop/test_util.h"
|
||||||
#include "gemma/activations.h"
|
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/weights_raw.h"
|
#include "gemma/weights_raw.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gemma/activations.h"
|
#include "backprop/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,8 @@
|
||||||
|
|
||||||
#include "backprop/forward.h"
|
#include "backprop/forward.h"
|
||||||
|
|
||||||
|
#include "backprop/activations.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "gemma/activations.h"
|
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,9 @@
|
||||||
#include <complex>
|
#include <complex>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "backprop/activations.h"
|
||||||
#include "backprop/common_scalar.h"
|
#include "backprop/common_scalar.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "gemma/activations.h"
|
|
||||||
#include "gemma/common.h" // EmbeddingScaling
|
#include "gemma/common.h" // EmbeddingScaling
|
||||||
#include "gemma/weights_raw.h"
|
#include "gemma/weights_raw.h"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,12 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
#include "backprop/activations.h"
|
||||||
#include "backprop/backward.h"
|
#include "backprop/backward.h"
|
||||||
#include "backprop/forward.h"
|
#include "backprop/forward.h"
|
||||||
#include "backprop/optimizer.h"
|
#include "backprop/optimizer.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "backprop/sampler.h"
|
#include "backprop/sampler.h"
|
||||||
#include "gemma/activations.h"
|
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
||||||
|
|
||||||
#include <math.h> // sqrtf
|
#include <math.h> // sqrtf
|
||||||
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -35,6 +36,12 @@ ByteStorageT AllocateSizeof() {
|
||||||
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constexpr size_t kPrefillBatchSize = 16;
|
||||||
|
constexpr size_t kDecodeBatchSize = 1;
|
||||||
|
constexpr size_t kBatchedQueryBatchSize = 16;
|
||||||
|
constexpr size_t kMinAdjustedPrefillBatchSize =
|
||||||
|
HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize);
|
||||||
|
|
||||||
// Model variants: see configs.h for details.
|
// Model variants: see configs.h for details.
|
||||||
enum class Model {
|
enum class Model {
|
||||||
GEMMA_2B,
|
GEMMA_2B,
|
||||||
|
|
@ -51,6 +58,13 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||||
// Tensor types for loading weights.
|
// Tensor types for loading weights.
|
||||||
enum class Type { kF32, kBF16, kSFP };
|
enum class Type { kF32, kBF16, kSFP };
|
||||||
|
|
||||||
|
// TODO(janwas): merge with parser/ToString.
|
||||||
|
struct ModelInfo {
|
||||||
|
Model model;
|
||||||
|
ModelTraining training;
|
||||||
|
Type weight;
|
||||||
|
};
|
||||||
|
|
||||||
// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
|
// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
|
||||||
// Config* is selected via `model`. Typically called by CallForModelAndWeight,
|
// Config* is selected via `model`. Typically called by CallForModelAndWeight,
|
||||||
// but can also be called directly when FuncT does not actually use TWeight.
|
// but can also be called directly when FuncT does not actually use TWeight.
|
||||||
|
|
|
||||||
149
gemma/gemma.cc
149
gemma/gemma.cc
|
|
@ -37,7 +37,6 @@
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility> // std::move
|
#include <utility> // std::move
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
@ -54,14 +53,9 @@
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
// copybara:import_next_line:sentencepiece
|
|
||||||
#include "src/sentencepiece_processor.h"
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// Set this to true to debug tokenizer tokens.
|
|
||||||
constexpr bool kShowTokenization = false;
|
|
||||||
|
|
||||||
// Must be aligned.
|
// Must be aligned.
|
||||||
template <class TConfig, size_t kBatchSize>
|
template <class TConfig, size_t kBatchSize>
|
||||||
struct Activations {
|
struct Activations {
|
||||||
|
|
@ -115,114 +109,22 @@ struct Activations {
|
||||||
griffin_multiplier;
|
griffin_multiplier;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace {
|
template <typename TConfig>
|
||||||
|
struct AllocateState {
|
||||||
template <class TConfig>
|
void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
|
||||||
struct CreateKVCache {
|
// When batching queries, the prefill batch size is reduced by a factor
|
||||||
KVCache operator()() const {
|
// of kBatchedQueryBatchSize
|
||||||
KVCache kv_cache = {};
|
prefill =
|
||||||
|
AllocateSizeof<Activations<TConfig, kMinAdjustedPrefillBatchSize *
|
||||||
const size_t size_cache_pos = CachePosSize<TConfig>()();
|
kBatchedQueryBatchSize>>();
|
||||||
if (size_cache_pos != 0) {
|
decode = AllocateSizeof<
|
||||||
const size_t seq_len =
|
Activations<TConfig, kDecodeBatchSize * kBatchedQueryBatchSize>>();
|
||||||
(TConfig::kSeqLen + kPrefillBatchSize);
|
|
||||||
kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(patrickms): Add query batching support for Griffin.
|
|
||||||
if (TConfig::kGriffinLayers) {
|
|
||||||
constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
|
||||||
const size_t conv1d_cache_size =
|
|
||||||
TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
|
|
||||||
TConfig::kModelDim;
|
|
||||||
if (conv1d_cache_size != 0) {
|
|
||||||
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
|
|
||||||
hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
|
|
||||||
conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t rglru_cache_size =
|
|
||||||
TConfig::kGriffinLayers * TConfig::kModelDim;
|
|
||||||
if (rglru_cache_size != 0) {
|
|
||||||
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
|
|
||||||
hwy::ZeroBytes(kv_cache.rglru_cache.get(),
|
|
||||||
rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
|
|
||||||
}
|
|
||||||
} // kGriffinLayers
|
|
||||||
|
|
||||||
return kv_cache;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
template <class TConfig, size_t kBatchSize>
|
||||||
|
Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
|
||||||
KVCache KVCache::Create(Model model_type) {
|
return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(state_u8.get());
|
||||||
// TWeight=float is a placeholder and unused because CreateKVCache does not
|
|
||||||
// use TConfig::Weight.
|
|
||||||
return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
|
|
||||||
}
|
|
||||||
|
|
||||||
class GemmaTokenizer::Impl {
|
|
||||||
public:
|
|
||||||
Impl() = default;
|
|
||||||
explicit Impl(const Path& tokenizer_path) {
|
|
||||||
PROFILER_ZONE("Startup.tokenizer");
|
|
||||||
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
|
||||||
if (!spp_->Load(tokenizer_path.path).ok()) {
|
|
||||||
HWY_ABORT("Failed to load the tokenizer file.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Encode(const std::string& input,
|
|
||||||
std::vector<std::string>* pieces) const {
|
|
||||||
return spp_ && spp_->Encode(input, pieces).ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Encode(const std::string& input, std::vector<int>* ids) const {
|
|
||||||
if constexpr (kShowTokenization) {
|
|
||||||
bool is_ok = spp_ && spp_->Encode(input, ids).ok();
|
|
||||||
for (int i = 0; i < static_cast<int>(ids->size()); i++) {
|
|
||||||
fprintf(stderr, "%3d: %d\n", i, (*ids)[i]);
|
|
||||||
}
|
|
||||||
return is_ok;
|
|
||||||
} else {
|
|
||||||
return spp_ && spp_->Encode(input, ids).ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Given a sequence of ids, decodes it into a detokenized output.
|
|
||||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const {
|
|
||||||
return spp_ && spp_->Decode(ids, detokenized).ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_;
|
|
||||||
};
|
|
||||||
|
|
||||||
GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
|
|
||||||
impl_ = std::make_unique<Impl>(tokenizer_path);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default suffices, but they must be defined after GemmaTokenizer::Impl.
|
|
||||||
GemmaTokenizer::GemmaTokenizer() = default;
|
|
||||||
GemmaTokenizer::~GemmaTokenizer() = default;
|
|
||||||
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
|
|
||||||
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
|
|
||||||
|
|
||||||
bool GemmaTokenizer::Encode(const std::string& input,
|
|
||||||
std::vector<std::string>* pieces) const {
|
|
||||||
return impl_->Encode(input, pieces);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool GemmaTokenizer::Encode(const std::string& input,
|
|
||||||
std::vector<int>* ids) const {
|
|
||||||
return impl_->Encode(input, ids);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Given a sequence of ids, decodes it into a detokenized output.
|
|
||||||
bool GemmaTokenizer::Decode(const std::vector<int>& ids,
|
|
||||||
std::string* detokenized) const {
|
|
||||||
return impl_->Decode(ids, detokenized);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Placeholder for internal test2, do not remove
|
// Placeholder for internal test2, do not remove
|
||||||
|
|
@ -797,15 +699,9 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
||||||
HWY_ASSERT(prompt_size > 0);
|
HWY_ASSERT(prompt_size > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize>
|
|
||||||
Activations<TConfig, kBatchSize>& GetActivations(
|
|
||||||
const ByteStorageT& state_u8) {
|
|
||||||
return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(
|
|
||||||
state_u8.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// TODO(janwas): move into RuntimeConfig
|
||||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob,
|
bool StreamToken(size_t query_idx, size_t pos, int token, float prob,
|
||||||
const RuntimeConfig& runtime_config) {
|
const RuntimeConfig& runtime_config) {
|
||||||
if (runtime_config.batch_stream_token) {
|
if (runtime_config.batch_stream_token) {
|
||||||
|
|
@ -1069,22 +965,6 @@ HWY_AFTER_NAMESPACE();
|
||||||
#if HWY_ONCE
|
#if HWY_ONCE
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
namespace {
|
|
||||||
template <typename TConfig>
|
|
||||||
struct AllocateState {
|
|
||||||
void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
|
|
||||||
// When batching queries, the prefill batch size is reduced by a factor
|
|
||||||
// of kBatchedQueryBatchSize
|
|
||||||
prefill = AllocateSizeof<
|
|
||||||
Activations<TConfig,
|
|
||||||
kMinAdjustedPrefillBatchSize * kBatchedQueryBatchSize>>();
|
|
||||||
decode = AllocateSizeof<
|
|
||||||
Activations<TConfig, kDecodeBatchSize * kBatchedQueryBatchSize>>();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
|
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
|
||||||
const ModelInfo& info, hwy::ThreadPool& pool)
|
const ModelInfo& info, hwy::ThreadPool& pool)
|
||||||
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
|
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
|
||||||
|
|
@ -1136,6 +1016,7 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(janwas): move to common.h.
|
||||||
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
|
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
|
||||||
|
|
||||||
// Instruction-tuned models are trained to expect control tokens.
|
// Instruction-tuned models are trained to expect control tokens.
|
||||||
|
|
|
||||||
|
|
@ -17,59 +17,22 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
// IWYU pragma: begin_exports
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
|
#include "gemma/kv_cache.h"
|
||||||
|
#include "gemma/tokenizer.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
// IWYU pragma: end_exports
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
constexpr size_t kPrefillBatchSize = 16;
|
|
||||||
constexpr size_t kDecodeBatchSize = 1;
|
|
||||||
constexpr size_t kBatchedQueryBatchSize = 16;
|
|
||||||
constexpr size_t kMinAdjustedPrefillBatchSize =
|
|
||||||
HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize);
|
|
||||||
|
|
||||||
struct KVCache {
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]>
|
|
||||||
kv_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]>
|
|
||||||
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]>
|
|
||||||
rglru_cache; // kModelDim * kGriffinLayers
|
|
||||||
|
|
||||||
static KVCache Create(Model type);
|
|
||||||
};
|
|
||||||
|
|
||||||
// The tokenizer's end of sentence and beginning of sentence token ids.
|
|
||||||
constexpr int EOS_ID = 1;
|
|
||||||
constexpr int BOS_ID = 2;
|
|
||||||
|
|
||||||
class GemmaTokenizer {
|
|
||||||
public:
|
|
||||||
GemmaTokenizer();
|
|
||||||
explicit GemmaTokenizer(const Path& tokenizer_path);
|
|
||||||
|
|
||||||
// must come after definition of Impl
|
|
||||||
~GemmaTokenizer();
|
|
||||||
GemmaTokenizer(GemmaTokenizer&& other);
|
|
||||||
GemmaTokenizer& operator=(GemmaTokenizer&& other);
|
|
||||||
|
|
||||||
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
|
|
||||||
bool Encode(const std::string& input, std::vector<int>* ids) const;
|
|
||||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
class Impl;
|
|
||||||
std::unique_ptr<Impl> impl_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||||
// true to continue generation.
|
// true to continue generation.
|
||||||
|
|
@ -93,13 +56,6 @@ using SampleFunc = std::function<int(const float*, size_t)>;
|
||||||
using LayersOutputFunc =
|
using LayersOutputFunc =
|
||||||
std::function<void(int, const std::string&, const float*, size_t)>;
|
std::function<void(int, const std::string&, const float*, size_t)>;
|
||||||
|
|
||||||
// TODO(janwas): move into common.h, merge with parser/ToString.
|
|
||||||
struct ModelInfo {
|
|
||||||
Model model;
|
|
||||||
ModelTraining training;
|
|
||||||
Type weight;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct RuntimeConfig {
|
struct RuntimeConfig {
|
||||||
size_t max_tokens;
|
size_t max_tokens;
|
||||||
size_t max_generated_tokens;
|
size_t max_generated_tokens;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "gemma/kv_cache.h"
|
||||||
|
|
||||||
|
#include "gemma/common.h" // CallForModel
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
|
#include "hwy/base.h" // ZeroBytes
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
namespace {
|
||||||
|
template <class TConfig>
|
||||||
|
struct CreateKVCache {
|
||||||
|
KVCache operator()() const {
|
||||||
|
KVCache kv_cache = {};
|
||||||
|
|
||||||
|
const size_t size_cache_pos = CachePosSize<TConfig>()();
|
||||||
|
if (size_cache_pos != 0) {
|
||||||
|
const size_t seq_len = (TConfig::kSeqLen + kPrefillBatchSize);
|
||||||
|
kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(patrickms): Add query batching support for Griffin.
|
||||||
|
if (TConfig::kGriffinLayers) {
|
||||||
|
constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||||
|
const size_t conv1d_cache_size =
|
||||||
|
TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
|
||||||
|
TConfig::kModelDim;
|
||||||
|
if (conv1d_cache_size != 0) {
|
||||||
|
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
|
||||||
|
hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
|
||||||
|
conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t rglru_cache_size =
|
||||||
|
TConfig::kGriffinLayers * TConfig::kModelDim;
|
||||||
|
if (rglru_cache_size != 0) {
|
||||||
|
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
|
||||||
|
hwy::ZeroBytes(kv_cache.rglru_cache.get(),
|
||||||
|
rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
|
||||||
|
}
|
||||||
|
} // kGriffinLayers
|
||||||
|
|
||||||
|
return kv_cache;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Builds the cache for `model_type` by dispatching to CreateKVCache with the
// matching config.
KVCache KVCache::Create(Model model_type) {
  // CreateKVCache never reads TConfig::Weight, so the weight type passed to
  // CallForModel is only a placeholder.
  using PlaceholderWeight = float;
  return CallForModel<PlaceholderWeight, CreateKVCache>(model_type);
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
|
||||||
|
|
||||||
|
#include "gemma/common.h" // Model
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
struct KVCache {
|
||||||
|
// kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
|
||||||
|
|
||||||
|
// (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
|
||||||
|
|
||||||
|
// kModelDim * kGriffinLayers
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
|
||||||
|
|
||||||
|
static KVCache Create(Model type);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "gemma/tokenizer.h"
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/io.h" // Path
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
// copybara:import_next_line:sentencepiece
|
||||||
|
#include "src/sentencepiece_processor.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Set this to true to debug tokenizer tokens.
|
||||||
|
constexpr bool kShowTokenization = false;
|
||||||
|
|
||||||
|
// Private implementation of GemmaTokenizer: a thin wrapper around a
// SentencePiece processor that reports failures as `false`.
class GemmaTokenizer::Impl {
 public:
  // Leaves `spp_` null; every Encode/Decode call then returns false.
  Impl() = default;

  // Loads the SentencePiece model at `tokenizer_path`; aborts if loading fails.
  explicit Impl(const Path& tokenizer_path) {
    PROFILER_ZONE("Startup.tokenizer");
    spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
    const bool loaded = spp_->Load(tokenizer_path.path).ok();
    if (!loaded) {
      HWY_ABORT("Failed to load the tokenizer file.");
    }
  }

  // Splits `input` into string pieces. Returns false if no model is loaded or
  // SentencePiece reports an error.
  bool Encode(const std::string& input,
              std::vector<std::string>* pieces) const {
    if (!spp_) return false;
    return spp_->Encode(input, pieces).ok();
  }

  // Converts `input` into token ids. Returns false if no model is loaded or
  // SentencePiece reports an error. When kShowTokenization is set, the ids are
  // also printed to stderr for debugging.
  bool Encode(const std::string& input, std::vector<int>* ids) const {
    const bool is_ok = spp_ && spp_->Encode(input, ids).ok();
    if constexpr (kShowTokenization) {
      const int num_ids = static_cast<int>(ids->size());
      for (int idx = 0; idx < num_ids; ++idx) {
        fprintf(stderr, "%3d: %d\n", idx, (*ids)[idx]);
      }
    }
    return is_ok;
  }

  // Given a sequence of ids, decodes it into a detokenized output. Returns
  // false if no model is loaded or SentencePiece reports an error.
  bool Decode(const std::vector<int>& ids, std::string* detokenized) const {
    if (!spp_) return false;
    return spp_->Decode(ids, detokenized).ok();
  }

 private:
  std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_;
};
|
||||||
|
|
||||||
|
// Creates the implementation object, which loads the tokenizer model from
// `tokenizer_path`.
GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path)
    : impl_(std::make_unique<Impl>(tokenizer_path)) {}
|
||||||
|
|
||||||
|
// The compiler-generated versions suffice, but they must be defined here,
// after GemmaTokenizer::Impl is defined: impl_ is a std::unique_ptr<Impl> and
// the header only forward-declares Impl.
|
||||||
|
GemmaTokenizer::GemmaTokenizer() = default;
|
||||||
|
GemmaTokenizer::~GemmaTokenizer() = default;
|
||||||
|
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
|
||||||
|
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
|
||||||
|
|
||||||
|
bool GemmaTokenizer::Encode(const std::string& input,
|
||||||
|
std::vector<std::string>* pieces) const {
|
||||||
|
return impl_->Encode(input, pieces);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GemmaTokenizer::Encode(const std::string& input,
|
||||||
|
std::vector<int>* ids) const {
|
||||||
|
return impl_->Encode(input, ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a sequence of ids, decodes it into a detokenized output.
|
||||||
|
bool GemmaTokenizer::Decode(const std::vector<int>& ids,
|
||||||
|
std::string* detokenized) const {
|
||||||
|
return impl_->Decode(ids, detokenized);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/io.h" // Path
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// The tokenizer's end of sentence and beginning of sentence token ids.
// NOTE(review): presumably these match the special-token ids of the
// SentencePiece model loaded by GemmaTokenizer - confirm against the model.
|
||||||
|
constexpr int EOS_ID = 1;
|
||||||
|
constexpr int BOS_ID = 2;
|
||||||
|
|
||||||
|
// Text <-> token conversion. All work is forwarded to a privately defined
// Impl class, which is only forward-declared in this header.
class GemmaTokenizer {
|
||||||
|
public:
|
||||||
|
// Creates a tokenizer without an implementation (impl_ stays null).
GemmaTokenizer();
|
||||||
|
// Creates the implementation from the tokenizer file at `tokenizer_path`.
explicit GemmaTokenizer(const Path& tokenizer_path);
|
||||||
|
|
||||||
|
// The destructor and move operations are only declared here; their
// definitions must come after the definition of Impl (in the .cc file)
// because impl_ is a std::unique_ptr to the forward-declared Impl.
|
||||||
|
~GemmaTokenizer();
|
||||||
|
GemmaTokenizer(GemmaTokenizer&& other);
|
||||||
|
GemmaTokenizer& operator=(GemmaTokenizer&& other);
|
||||||
|
|
||||||
|
// Splits `input` into string pieces written to `pieces`; returns true on
// success.
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
|
||||||
|
// Converts `input` into token ids written to `ids`; returns true on success.
bool Encode(const std::string& input, std::vector<int>* ids) const;
|
||||||
|
// Converts the token sequence `ids` back into text written to `detokenized`;
// returns true on success.
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Defined in the .cc file.
class Impl;
|
||||||
|
// Owning pointer to the implementation.
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
|
||||||
Loading…
Reference in New Issue