diff --git a/BUILD.bazel b/BUILD.bazel index a4d960c..10b1e42 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -94,13 +94,35 @@ cc_library( ], ) +cc_library( + name = "tokenizer", + srcs = ["gemma/tokenizer.cc"], + hdrs = ["gemma/tokenizer.h"], + deps = [ + "//compression:io", + "@hwy//:hwy", + "@hwy//:nanobenchmark", # timer + "@hwy//:profiler", + "@com_google_sentencepiece//:sentencepiece_processor", + ], +) + +cc_library( + name = "kv_cache", + srcs = ["gemma/kv_cache.cc"], + hdrs = ["gemma/kv_cache.h"], + deps = [ + ":common", + "@hwy//:hwy", + ], +) + cc_library( name = "gemma_lib", srcs = [ "gemma/gemma.cc", ], hdrs = [ - "gemma/activations.h", "gemma/gemma.h", ], exec_properties = { @@ -114,6 +136,8 @@ cc_library( deps = [ ":common", ":ops", + ":tokenizer", + ":kv_cache", ":weights", "//compression:compress", "//compression:io", @@ -122,7 +146,6 @@ cc_library( "@hwy//:nanobenchmark", # timer "@hwy//:profiler", "@hwy//:thread_pool", - "@com_google_sentencepiece//:sentencepiece_processor", ], ) @@ -321,6 +344,7 @@ cc_library( "backprop/forward.cc", ], hdrs = [ + "backprop/activations.h", "backprop/backward.h", "backprop/backward-inl.h", "backprop/forward.h", @@ -340,6 +364,7 @@ cc_library( cc_library( name = "backprop_scalar", hdrs = [ + "backprop/activations.h", "backprop/backward_scalar.h", "backprop/common_scalar.h", "backprop/forward_scalar.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index 5e0a9d1..7da189b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,7 @@ set(SOURCES compression/sfp.h compression/sfp-inl.h compression/test_util.h + backprop/activations.h backprop/backward.cc backprop/backward.h backprop/backward-inl.h @@ -62,18 +63,21 @@ set(SOURCES backprop/optimizer.h evals/cross_entropy.cc evals/cross_entropy.h - gemma/configs.h - gemma/activations.h gemma/benchmark_helper.cc gemma/benchmark_helper.h gemma/common.cc gemma/common.h + gemma/configs.h gemma/gemma.cc gemma/gemma.h + gemma/kv_cache.cc + gemma/kv_cache.h gemma/ops.h + gemma/tokenizer.cc + gemma/tokenizer.h + gemma/weights_raw.h gemma/weights.cc gemma/weights.h - gemma/weights_raw.h util/app.h util/args.h ) diff --git a/gemma/activations.h b/backprop/activations.h similarity index 94% rename from gemma/activations.h rename to backprop/activations.h index 6d2bc22..b3bb455 100644 --- a/gemma/activations.h +++ b/backprop/activations.h @@ -13,8 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ +#ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_ +#define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_ #include @@ -86,4 +86,4 @@ class ActivationsWrapper { } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ +#endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_ diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 1ef6658..67b0aa4 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -22,11 +22,11 @@ #include -#include #include +#include +#include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/activations.h" #include "gemma/common.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" diff --git a/backprop/backward.cc b/backprop/backward.cc index 87ede98..89bbef3 100644 --- a/backprop/backward.cc +++ b/backprop/backward.cc @@ -15,8 +15,8 @@ #include "backprop/backward.h" +#include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/activations.h" #include "gemma/common.h" #include "hwy/contrib/thread_pool/thread_pool.h" diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h index 77cd76f..8a23272 100644 --- a/backprop/backward_scalar.h +++ b/backprop/backward_scalar.h @@ -22,9 +22,9 @@ #include #include +#include "backprop/activations.h" #include "backprop/common_scalar.h" #include "backprop/prompt.h" -#include "gemma/activations.h" #include "gemma/common.h" // EmbeddingScaling #include "gemma/weights_raw.h" diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc index 85f63bc..706b0ef 100644 --- a/backprop/backward_scalar_test.cc +++ b/backprop/backward_scalar_test.cc @@ -26,12 +26,12 @@ #include #include "gtest/gtest.h" +#include "backprop/activations.h" #include "backprop/common_scalar.h" #include "backprop/forward_scalar.h" #include "backprop/prompt.h" #include "backprop/sampler.h" #include "backprop/test_util.h" -#include "gemma/activations.h" #include "gemma/configs.h" #include "gemma/weights_raw.h" diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index 432882e..0cbf69d 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -24,12 +24,12 @@ #include #include +#include "backprop/activations.h" #include "backprop/backward_scalar.h" #include "backprop/common_scalar.h" #include "backprop/forward_scalar.h" #include "backprop/sampler.h" #include "backprop/test_util.h" -#include "gemma/activations.h" #include "gemma/configs.h" #include "gemma/weights_raw.h" #include "hwy/base.h" diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index 4b58036..c24116f 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -23,7 +23,7 @@ #include #include -#include "gemma/activations.h" +#include "backprop/activations.h" #include "gemma/common.h" #include "gemma/configs.h" #include "hwy/base.h" diff --git a/backprop/forward.cc b/backprop/forward.cc index 1c8670e..1357276 100644 --- a/backprop/forward.cc +++ b/backprop/forward.cc @@ -15,8 +15,8 @@ #include "backprop/forward.h" +#include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/activations.h" #include "gemma/common.h" #include "hwy/contrib/thread_pool/thread_pool.h" diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h index 6fd58d2..60c8025 100644 --- a/backprop/forward_scalar.h +++ b/backprop/forward_scalar.h @@ -23,9 +23,9 @@ #include #include +#include "backprop/activations.h" #include "backprop/common_scalar.h" #include "backprop/prompt.h" -#include "gemma/activations.h" #include "gemma/common.h" // EmbeddingScaling #include "gemma/weights_raw.h" diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index e0ccd90..2e031d0 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -20,12 +20,12 @@ #include #include "gtest/gtest.h" +#include "backprop/activations.h" #include "backprop/backward.h" #include "backprop/forward.h" #include "backprop/optimizer.h" #include "backprop/prompt.h" #include "backprop/sampler.h" -#include "gemma/activations.h" #include "gemma/common.h" #include "gemma/gemma.h" #include "gemma/weights.h" diff --git a/gemma/common.h b/gemma/common.h index 35f6e78..f4ed10d 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -17,6 +17,7 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ #include // sqrtf +#include #include #include @@ -35,6 +36,12 @@ ByteStorageT AllocateSizeof() { return hwy::AllocateAligned(sizeof(T)); } +constexpr size_t kPrefillBatchSize = 16; +constexpr size_t kDecodeBatchSize = 1; +constexpr size_t kBatchedQueryBatchSize = 16; +constexpr size_t kMinAdjustedPrefillBatchSize = + HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize); + // Model variants: see configs.h for details. enum class Model { GEMMA_2B, @@ -51,6 +58,13 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT }; // Tensor types for loading weights. enum class Type { kF32, kBF16, kSFP }; +// TODO(janwas): merge with parser/ToString. +struct ModelInfo { + Model model; + ModelTraining training; + Type weight; +}; + // Returns the return value of FuncT>().operator()(args), where // Config* is selected via `model`. Typically called by CallForModelAndWeight, // but can also be called directly when FuncT does not actually use TWeight. diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 48b4ddc..99ab616 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -37,7 +37,6 @@ #include #include -#include #include #include // std::move #include @@ -54,14 +53,9 @@ #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" #include "hwy/timer.h" -// copybara:import_next_line:sentencepiece -#include "src/sentencepiece_processor.h" namespace gcpp { -// Set this to true to debug tokenizer tokens. -constexpr bool kShowTokenization = false; - // Must be aligned. template struct Activations { @@ -115,114 +109,22 @@ struct Activations { griffin_multiplier; }; -namespace { - -template -struct CreateKVCache { - KVCache operator()() const { - KVCache kv_cache = {}; - - const size_t size_cache_pos = CachePosSize()(); - if (size_cache_pos != 0) { - const size_t seq_len = - (TConfig::kSeqLen + kPrefillBatchSize); - kv_cache.kv_cache = hwy::AllocateAligned(seq_len * size_cache_pos); - } - - // TODO(patrickms): Add query batching support for Griffin. - if (TConfig::kGriffinLayers) { - constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; - const size_t conv1d_cache_size = - TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) * - TConfig::kModelDim; - if (conv1d_cache_size != 0) { - kv_cache.conv1d_cache = hwy::AllocateAligned(conv1d_cache_size); - hwy::ZeroBytes(kv_cache.conv1d_cache.get(), - conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0])); - } - - const size_t rglru_cache_size = - TConfig::kGriffinLayers * TConfig::kModelDim; - if (rglru_cache_size != 0) { - kv_cache.rglru_cache = hwy::AllocateAligned(rglru_cache_size); - hwy::ZeroBytes(kv_cache.rglru_cache.get(), - rglru_cache_size * sizeof(kv_cache.rglru_cache[0])); - } - } // kGriffinLayers - - return kv_cache; +template +struct AllocateState { + void operator()(ByteStorageT& prefill, ByteStorageT& decode) const { + // When batching queries, the prefill batch size is reduced by a factor + // of kBatchedQueryBatchSize + prefill = + AllocateSizeof>(); + decode = AllocateSizeof< + Activations>(); } }; -} // namespace - -KVCache KVCache::Create(Model model_type) { - // TWeight=float is a placeholder and unused because CreateKVCache does not - // use TConfig::Weight. - return CallForModel(model_type); -} - -class GemmaTokenizer::Impl { - public: - Impl() = default; - explicit Impl(const Path& tokenizer_path) { - PROFILER_ZONE("Startup.tokenizer"); - spp_ = std::make_unique(); - if (!spp_->Load(tokenizer_path.path).ok()) { - HWY_ABORT("Failed to load the tokenizer file."); - } - } - - bool Encode(const std::string& input, - std::vector* pieces) const { - return spp_ && spp_->Encode(input, pieces).ok(); - } - - bool Encode(const std::string& input, std::vector* ids) const { - if constexpr (kShowTokenization) { - bool is_ok = spp_ && spp_->Encode(input, ids).ok(); - for (int i = 0; i < static_cast(ids->size()); i++) { - fprintf(stderr, "%3d: %d\n", i, (*ids)[i]); - } - return is_ok; - } else { - return spp_ && spp_->Encode(input, ids).ok(); - } - } - - // Given a sequence of ids, decodes it into a detokenized output. - bool Decode(const std::vector& ids, std::string* detokenized) const { - return spp_ && spp_->Decode(ids, detokenized).ok(); - } - - private: - std::unique_ptr spp_; -}; - -GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) { - impl_ = std::make_unique(tokenizer_path); -} - -// Default suffices, but they must be defined after GemmaTokenizer::Impl. -GemmaTokenizer::GemmaTokenizer() = default; -GemmaTokenizer::~GemmaTokenizer() = default; -GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default; -GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default; - -bool GemmaTokenizer::Encode(const std::string& input, - std::vector* pieces) const { - return impl_->Encode(input, pieces); -} - -bool GemmaTokenizer::Encode(const std::string& input, - std::vector* ids) const { - return impl_->Encode(input, ids); -} - -// Given a sequence of ids, decodes it into a detokenized output. -bool GemmaTokenizer::Decode(const std::vector& ids, - std::string* detokenized) const { - return impl_->Decode(ids, detokenized); +template +Activations& GetActivations(const ByteStorageT& state_u8) { + return *reinterpret_cast*>(state_u8.get()); } // Placeholder for internal test2, do not remove @@ -797,15 +699,9 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, HWY_ASSERT(prompt_size > 0); } -template -Activations& GetActivations( - const ByteStorageT& state_u8) { - return *reinterpret_cast*>( - state_u8.get()); -} - } // namespace +// TODO(janwas): move into RuntimeConfig bool StreamToken(size_t query_idx, size_t pos, int token, float prob, const RuntimeConfig& runtime_config) { if (runtime_config.batch_stream_token) { @@ -1069,22 +965,6 @@ HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace gcpp { -namespace { -template -struct AllocateState { - void operator()(ByteStorageT& prefill, ByteStorageT& decode) const { - // When batching queries, the prefill batch size is reduced by a factor - // of kBatchedQueryBatchSize - prefill = AllocateSizeof< - Activations>(); - decode = AllocateSizeof< - Activations>(); - } -}; - -} // namespace - Gemma::Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, hwy::ThreadPool& pool) : pool_(pool), tokenizer_(tokenizer_path), info_(info) { @@ -1136,6 +1016,7 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } +// TODO(janwas): move to common.h. void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { // Instruction-tuned models are trained to expect control tokens. diff --git a/gemma/gemma.h b/gemma/gemma.h index 4b2afc2..e35f0ef 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -17,59 +17,22 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ #include -#include #include #include #include +// IWYU pragma: begin_exports #include "compression/io.h" // Path #include "gemma/common.h" +#include "gemma/kv_cache.h" +#include "gemma/tokenizer.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +// IWYU pragma: end_exports #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t -#include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -constexpr size_t kPrefillBatchSize = 16; -constexpr size_t kDecodeBatchSize = 1; -constexpr size_t kBatchedQueryBatchSize = 16; -constexpr size_t kMinAdjustedPrefillBatchSize = - HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize); - -struct KVCache { - hwy::AlignedFreeUniquePtr - kv_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2 - hwy::AlignedFreeUniquePtr - conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers - hwy::AlignedFreeUniquePtr - rglru_cache; // kModelDim * kGriffinLayers - - static KVCache Create(Model type); -}; - -// The tokenizer's end of sentence and beginning of sentence token ids. -constexpr int EOS_ID = 1; -constexpr int BOS_ID = 2; - -class GemmaTokenizer { - public: - GemmaTokenizer(); - explicit GemmaTokenizer(const Path& tokenizer_path); - - // must come after definition of Impl - ~GemmaTokenizer(); - GemmaTokenizer(GemmaTokenizer&& other); - GemmaTokenizer& operator=(GemmaTokenizer&& other); - - bool Encode(const std::string& input, std::vector* pieces) const; - bool Encode(const std::string& input, std::vector* ids) const; - bool Decode(const std::vector& ids, std::string* detokenized) const; - - private: - class Impl; - std::unique_ptr impl_; -}; - // StreamFunc is called with (token, probability). For prompt tokens, // probability is 0.0f. StreamFunc should return false to stop generation and // true to continue generation. @@ -93,13 +56,6 @@ using SampleFunc = std::function; using LayersOutputFunc = std::function; -// TODO(janwas): move into common.h, merge with parser/ToString. -struct ModelInfo { - Model model; - ModelTraining training; - Type weight; -}; - struct RuntimeConfig { size_t max_tokens; size_t max_generated_tokens; diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc new file mode 100644 index 0000000..10e76e7 --- /dev/null +++ b/gemma/kv_cache.cc @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/kv_cache.h" + +#include "gemma/common.h" // CallForModel +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" // ZeroBytes + +namespace gcpp { +namespace { +template +struct CreateKVCache { + KVCache operator()() const { + KVCache kv_cache = {}; + + const size_t size_cache_pos = CachePosSize()(); + if (size_cache_pos != 0) { + const size_t seq_len = (TConfig::kSeqLen + kPrefillBatchSize); + kv_cache.kv_cache = hwy::AllocateAligned(seq_len * size_cache_pos); + } + + // TODO(patrickms): Add query batching support for Griffin. + if (TConfig::kGriffinLayers) { + constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; + const size_t conv1d_cache_size = + TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) * + TConfig::kModelDim; + if (conv1d_cache_size != 0) { + kv_cache.conv1d_cache = hwy::AllocateAligned(conv1d_cache_size); + hwy::ZeroBytes(kv_cache.conv1d_cache.get(), + conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0])); + } + + const size_t rglru_cache_size = + TConfig::kGriffinLayers * TConfig::kModelDim; + if (rglru_cache_size != 0) { + kv_cache.rglru_cache = hwy::AllocateAligned(rglru_cache_size); + hwy::ZeroBytes(kv_cache.rglru_cache.get(), + rglru_cache_size * sizeof(kv_cache.rglru_cache[0])); + } + } // kGriffinLayers + + return kv_cache; + } +}; +} // namespace + +KVCache KVCache::Create(Model model_type) { + // TWeight=float is a placeholder and unused because CreateKVCache does not + // use TConfig::Weight. + return CallForModel(model_type); +} + +} // namespace gcpp diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h new file mode 100644 index 0000000..1c92b40 --- /dev/null +++ b/gemma/kv_cache.h @@ -0,0 +1,39 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ + +#include "gemma/common.h" // Model +#include "hwy/aligned_allocator.h" + +namespace gcpp { + +struct KVCache { + // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2 + hwy::AlignedFreeUniquePtr kv_cache; + + // (kConv1dWidth - 1) * kModelDim * kGriffinLayers + hwy::AlignedFreeUniquePtr conv1d_cache; + + // kModelDim * kGriffinLayers + hwy::AlignedFreeUniquePtr rglru_cache; + + static KVCache Create(Model type); +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc new file mode 100644 index 0000000..0142573 --- /dev/null +++ b/gemma/tokenizer.cc @@ -0,0 +1,98 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/tokenizer.h" + +#include + +#include +#include +#include + +#include "compression/io.h" // Path +#include "hwy/base.h" +#include "hwy/profiler.h" +// copybara:import_next_line:sentencepiece +#include "src/sentencepiece_processor.h" + +namespace gcpp { + +// Set this to true to debug tokenizer tokens. +constexpr bool kShowTokenization = false; + +class GemmaTokenizer::Impl { + public: + Impl() = default; + explicit Impl(const Path& tokenizer_path) { + PROFILER_ZONE("Startup.tokenizer"); + spp_ = std::make_unique(); + if (!spp_->Load(tokenizer_path.path).ok()) { + HWY_ABORT("Failed to load the tokenizer file."); + } + } + + bool Encode(const std::string& input, + std::vector* pieces) const { + return spp_ && spp_->Encode(input, pieces).ok(); + } + + bool Encode(const std::string& input, std::vector* ids) const { + if constexpr (kShowTokenization) { + bool is_ok = spp_ && spp_->Encode(input, ids).ok(); + for (int i = 0; i < static_cast(ids->size()); i++) { + fprintf(stderr, "%3d: %d\n", i, (*ids)[i]); + } + return is_ok; + } else { + return spp_ && spp_->Encode(input, ids).ok(); + } + } + + // Given a sequence of ids, decodes it into a detokenized output. + bool Decode(const std::vector& ids, std::string* detokenized) const { + return spp_ && spp_->Decode(ids, detokenized).ok(); + } + + private: + std::unique_ptr spp_; +}; + +GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) { + impl_ = std::make_unique(tokenizer_path); +} + +// Default suffices, but they must be defined after GemmaTokenizer::Impl. +GemmaTokenizer::GemmaTokenizer() = default; +GemmaTokenizer::~GemmaTokenizer() = default; +GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default; +GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default; + +bool GemmaTokenizer::Encode(const std::string& input, + std::vector* pieces) const { + return impl_->Encode(input, pieces); +} + +bool GemmaTokenizer::Encode(const std::string& input, + std::vector* ids) const { + return impl_->Encode(input, ids); +} + +// Given a sequence of ids, decodes it into a detokenized output. +bool GemmaTokenizer::Decode(const std::vector& ids, + std::string* detokenized) const { + return impl_->Decode(ids, detokenized); +} + +} // namespace gcpp diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h new file mode 100644 index 0000000..f42daa7 --- /dev/null +++ b/gemma/tokenizer.h @@ -0,0 +1,52 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_ + +#include +#include +#include + +#include "compression/io.h" // Path + +namespace gcpp { + +// The tokenizer's end of sentence and beginning of sentence token ids. +constexpr int EOS_ID = 1; +constexpr int BOS_ID = 2; + +class GemmaTokenizer { + public: + GemmaTokenizer(); + explicit GemmaTokenizer(const Path& tokenizer_path); + + // must come after definition of Impl + ~GemmaTokenizer(); + GemmaTokenizer(GemmaTokenizer&& other); + GemmaTokenizer& operator=(GemmaTokenizer&& other); + + bool Encode(const std::string& input, std::vector* pieces) const; + bool Encode(const std::string& input, std::vector* ids) const; + bool Decode(const std::vector& ids, std::string* detokenized) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_