// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #include #include #include #include #include #include #include // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // SfpStream/NuqStream // copybara:import_next_line:gemma_cpp #include "configs.h" // kSeqLen // copybara:import_next_line:gemma_cpp #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" #include "util/args.h" // ArgsBase // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" namespace gcpp { // Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time): // float, hwy::bfloat16_t, SfpStream, NuqStream #ifndef GEMMA_WEIGHT_T #define GEMMA_WEIGHT_T SfpStream #endif // !GEMMA_WEIGHT_T using WeightT = GEMMA_WEIGHT_T; using EmbedderInputT = hwy::bfloat16_t; constexpr size_t kPrefillBatchSize = 16; constexpr bool kSystemPrompt = false; struct KVCache { hwy::AlignedFreeUniquePtr key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim hwy::AlignedFreeUniquePtr value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim }; // Model variants: see configs.h for details. enum class Model { GEMMA_2B, GEMMA_7B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; struct LoaderArgs : public ArgsBase { LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } static std::string ToLower(const std::string& text) { std::string result = text; std::transform(begin(result), end(result), begin(result), [](unsigned char c) { return std::tolower(c); }); return result; } gcpp::Model ModelType() const { const std::string model_type_lc = ToLower(model_type); if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") { return gcpp::Model::GEMMA_2B; } else { return gcpp::Model::GEMMA_7B; } } gcpp::ModelTraining ModelTraining() const { const std::string model_type_lc = ToLower(model_type); if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") { return gcpp::ModelTraining::GEMMA_PT; } else { return gcpp::ModelTraining::GEMMA_IT; } } // Returns error string or nullptr if OK. const char* Validate() const { const std::string model_type_lc = ToLower(model_type); if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && model_type_lc != "2b-it" && model_type_lc != "7b-it") { return "Model type must be 2b-pt, 7b-pt, 2b-it, or " "7b-it."; } if (tokenizer.path.empty()) { return "Missing --tokenizer flag, a file for the tokenizer is required."; } if (model_type.empty()) { return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " "2b-it, or 7b-it."; } if (cache.path.empty()) { return "Missing --compressed_weights flag, a file for the compressed " "model."; } return nullptr; } Path tokenizer; Path model; // uncompressed weights OR Path cache; // compressed weights std::string model_type; template void ForEach(const Visitor& visitor) { visitor(tokenizer, "tokenizer", Path(), "Path name of tokenizer model file. (required)"); visitor( cache, "compressed_weights", Path(), "Path name of compressed weights file, regenerated from `--weights` " "file if " "the compressed weights file does not exist. (required)"); visitor(model_type, "model", std::string(), "Model type - can be 2b-it (2B parameters, instruction-tuned), " "2b-pt (2B parameters, pretrained), 7b-it (7B parameters, " "instruction-tuned), or 7b-pt (7B parameters, pretrained). " "(required)"); visitor(model, "weights", Path(), "Path name of model weights (.sbs) file. Only required if " "compressed_weights file is not present and needs to be " "regenerated. Otherwise, not needed"); } }; struct GemmaInterface; struct Gemma { Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor& Tokenizer() const; std::unique_ptr impl_; gcpp::ModelTraining model_training; }; // StreamFunc is called with (token, probability). For prompt tokens, // probability is 0.0f. using StreamFunc = std::function; using AcceptFunc = std::function; struct InferenceArgs : public ArgsBase { InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } size_t max_tokens; size_t max_generated_tokens; float temperature; bool deterministic; bool multiturn; // Returns error string or nullptr if OK. const char* Validate() const { if (max_tokens > gcpp::kSeqLen) { return "max_tokens is larger than the maximum sequence length (see " "configs.h)."; } if (max_generated_tokens > max_tokens) { return "Maximum number of generated tokens is larger than the maximum " "total tokens."; } return nullptr; } template void ForEach(const Visitor& visitor) { visitor(max_tokens, "max_tokens", size_t{3072}, "Maximum number of tokens in prompt + generation."); visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, "Maximum number of tokens to generate."); visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); visitor(deterministic, "deterministic", false, "Make top-k sampling deterministic", 2); visitor(multiturn, "multiturn", false, "Multiturn mode (if 0, this clears the KV cache after every " "interaction without quitting)\n Default = 0 (conversation " "resets every turn)"); } }; void GenerateGemma(Gemma& gemma, const InferenceArgs& args, const std::vector& prompt, size_t start_pos, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& g, int verbosity); constexpr int EOS_ID = 1; } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_