// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #include #include #include #include // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // SfpStream/NuqStream // copybara:import_next_line:gemma_cpp #include "configs.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" // copybara:import_next_line:gemma_cpp #include "util/args.h" // Path // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" namespace gcpp { using GemmaWeightT = GEMMA_WEIGHT_T; using EmbedderInputT = hwy::bfloat16_t; constexpr size_t kPrefillBatchSize = 16; constexpr bool kSystemPrompt = false; struct KVCache { hwy::AlignedFreeUniquePtr key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim hwy::AlignedFreeUniquePtr value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim }; // Model variants: see configs.h for details. enum class Model { GEMMA_2B, GEMMA_7B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; struct RuntimeConfig { size_t max_tokens; size_t max_generated_tokens; float temperature; int verbosity; }; struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; }; KVCache CreateKVCache(Model type); // convenient workaround for now KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); // StreamFunc is called with (token, probability). For prompt tokens, // probability is 0.0f. using StreamFunc = std::function; using AcceptFunc = std::function; void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity); // Convenience function for the common case: // - Bundle runtime parameters as RuntimeConfig // - No threadpools within threadpools (inner_pool = dummy) // - All tokens accepted void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, std::mt19937& gen); void CompressWeights(gcpp::Model model, const Path& weights, const Path& compressed_weights, hwy::ThreadPool& pool); constexpr int EOS_ID = 1; } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_