From afc354dcb19574fe1ddafaf7ffadb9292b3871ab Mon Sep 17 00:00:00 2001
From: Dan Zheng
Date: Mon, 26 Feb 2024 19:04:33 -0800
Subject: [PATCH] Import from GitHub.

PiperOrigin-RevId: 610595796
---
 DEVELOPERS.md | 18 ++++++++++++++++++
 README.md     |  8 ++++++--
 configs.h     | 21 +++++++++++++--------
 gemma.h       | 51 ++++++++++++++++++++++++---------------------------
 4 files changed, 61 insertions(+), 37 deletions(-)

diff --git a/DEVELOPERS.md b/DEVELOPERS.md
index d06b0f8..bdc02c0 100644
--- a/DEVELOPERS.md
+++ b/DEVELOPERS.md
@@ -70,3 +70,21 @@ The implementation code is roughly split into 4 layers, from high to low level:
 
 4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
    highway) supporting the implementations in (3).
+
+## Compile-Time Flags (Advanced)
+
+There are several compile-time flags to be aware of (note that these may or
+may not be exposed through the build system; an example invocation follows
+the list):
+
+- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
+  `WEIGHT_TYPE` in `CMakeLists.txt`). Currently this should be set to
+  `SfpStream` (the default, if no flag is specified) for 8-bit SFP, or
+  `hwy::bfloat16_t` to enable higher-fidelity (but slower) bfloat16 support.
+  This is defined in `gemma.h`.
+- `GEMMA_MAX_SEQLEN` : Sets the maximum sequence length to preallocate for
+  the KV cache. The default is 4096 tokens but can be overridden. This is not
+  exposed through `CMakeLists.txt` yet.
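+
+As a sketch (assuming a CMake build: `GEMMA_MAX_SEQLEN` is not surfaced
+through `CMakeLists.txt`, so it is passed here as a plain compiler define via
+`CMAKE_CXX_FLAGS`), a bfloat16 build with a larger KV cache might be
+configured as:
+
+```sh
+# Higher-fidelity bfloat16 weights and an 8192-token preallocated KV cache.
+cmake -B build -DWEIGHT_TYPE=hwy::bfloat16_t \
+    -DCMAKE_CXX_FLAGS="-DGEMMA_MAX_SEQLEN=8192"
+```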
+
+In the medium term both of these will likely be deprecated in favor of
+handling options at runtime, allowing for multiple weight compression schemes
+in a single build and dynamically resizing the KV cache as needed.
diff --git a/README.md b/README.md
index 5932726..8db6862 100644
--- a/README.md
+++ b/README.md
@@ -114,8 +114,12 @@ convenient directory location (e.g. the `build/` directory in this repo).
 
 The build system uses [CMake](https://cmake.org/). To build the gemma inference
 runtime, create a build directory and generate the build files using `cmake`
-from the top-level project directory. For the 8-bit switched floating point
-weights (sfp), run cmake with no options:
+from the top-level project directory. Note: if you previously ran `cmake` and
+are re-running with a different setting, be sure to clean out the `build/`
+directory with `rm -rf build/*` (warning: this will delete any other files in
+the `build/` directory).
+
+For the 8-bit switched floating point weights (sfp), run cmake with no options:
 
 #### Unix-like Platforms
 
 ```sh
diff --git a/configs.h b/configs.h
index ebe6220..4be5f75 100644
--- a/configs.h
+++ b/configs.h
@@ -18,21 +18,26 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
 #define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
 
+// Allow changing the pre-allocated KV cache size via a compiler flag
+#ifndef GEMMA_MAX_SEQLEN
+#define GEMMA_MAX_SEQLEN 4096
+#endif // !GEMMA_MAX_SEQLEN
+
 #include <stddef.h>
 
 namespace gcpp {
 
-static constexpr size_t kSeqLen = 7168;
+static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 
 struct ConfigGemma7B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256128;
   static constexpr int kLayers = 28;
   static constexpr int kModelDim = 3072;
-  static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
+  static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
   static constexpr int kHeads = 16;
-  static constexpr int kKVHeads = 16;  // standard MHA, no GQA or MQA
-  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kKVHeads = 16; // standard MHA
+  static constexpr int kQKVDim = 256; // query size == key size == value size
   static constexpr int kTopK = 1;
 };
 
@@ -41,13 +46,13 @@ struct ConfigGemma2B {
   static constexpr int kVocabSize = 256128;
   static constexpr int kLayers = 18;
   static constexpr int kModelDim = 2048;
-  static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
+  static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
   static constexpr int kHeads = 8;
   static constexpr int kKVHeads = 8;  // TODO(austinvhuang): add MQA support
-  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kQKVDim = 256; // query size == key size == value size
   static constexpr int kTopK = 1;
 };
 
-}  // namespace gcpp
+} // namespace gcpp
 
-#endif  // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
+#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
diff --git a/gemma.h b/gemma.h
index 5dc9f62..1e76a37 100644
--- a/gemma.h
+++ b/gemma.h
@@ -42,7 +42,7 @@ namespace gcpp {
 // float, hwy::bfloat16_t, SfpStream, NuqStream
 #ifndef GEMMA_WEIGHT_T
 #define GEMMA_WEIGHT_T SfpStream
-#endif  // !GEMMA_WEIGHT_T
+#endif // !GEMMA_WEIGHT_T
 
 using WeightT = GEMMA_WEIGHT_T;
 using EmbedderInputT = hwy::bfloat16_t;
@@ -51,9 +51,9 @@ constexpr bool kSystemPrompt = false;
 
 struct KVCache {
   hwy::AlignedFreeUniquePtr<float[]>
-      key_cache;  // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
+      key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
   hwy::AlignedFreeUniquePtr<float[]>
-      value_cache;  // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
+      value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
 };
 
 // Model variants: see configs.h for details.
@@ -61,9 +61,9 @@ enum class Model { GEMMA_2B, GEMMA_7B };
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
-  LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  LoaderArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
 
-  static std::string ToLower(const std::string& text) {
+  static std::string ToLower(const std::string &text) {
     std::string result = text;
     std::transform(begin(result), end(result), begin(result),
                    [](unsigned char c) { return std::tolower(c); });
@@ -89,7 +89,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   }
 
   // Returns error string or nullptr if OK.
-  const char* Validate() const {
+  const char *Validate() const {
     const std::string model_type_lc = ToLower(model_type);
     if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
         model_type_lc != "2b-it" && model_type_lc != "7b-it") {
@@ -111,12 +111,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   }
 
   Path tokenizer;
-  Path model;  // uncompressed weights OR
-  Path cache;  // compressed weights
+  Path model; // uncompressed weights OR
+  Path cache; // compressed weights
   std::string model_type;
 
-  template <class Visitor>
-  void ForEach(const Visitor& visitor) {
+  template <class Visitor> void ForEach(const Visitor &visitor) {
     visitor(tokenizer, "tokenizer", Path(),
             "Path name of tokenizer model file. (required)");
     visitor(
@@ -139,10 +138,10 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
 struct GemmaInterface;
 
 struct Gemma {
-  Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
-  ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
+  Gemma(const LoaderArgs &args, hwy::ThreadPool &pool);
+  ~Gemma(); // must be defined after GemmaInterface's dtor is defined.
 
-  const sentencepiece::SentencePieceProcessor& Tokenizer() const;
+  const sentencepiece::SentencePieceProcessor &Tokenizer() const;
 
   std::unique_ptr<GemmaInterface> impl_;
   gcpp::ModelTraining model_training;
@@ -154,7 +153,7 @@ using StreamFunc = std::function<bool(int, float)>;
 using AcceptFunc = std::function<bool(int)>;
 
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
-  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
 
   size_t max_tokens;
   size_t max_generated_tokens;
@@ -164,7 +163,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   bool multiturn;
 
   // Returns error string or nullptr if OK.
-  const char* Validate() const {
+  const char *Validate() const {
     if (max_tokens > gcpp::kSeqLen) {
       return "max_tokens is larger than the maximum sequence length (see "
              "configs.h).";
@@ -176,8 +175,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     return nullptr;
   }
 
-  template <class Visitor>
-  void ForEach(const Visitor& visitor) {
+  template <class Visitor> void ForEach(const Visitor &visitor) {
     visitor(max_tokens, "max_tokens", size_t{3072},
             "Maximum number of tokens in prompt + generation.");
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
@@ -186,22 +184,21 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
-    visitor(multiturn, "multiturn", true,
+    visitor(multiturn, "multiturn", false,
             "Multiturn mode (if 0, this clears the KV cache after every "
-            "interaction without quitting)",
-            2);
+            "interaction without quitting)\n    Default = 0 (conversation resets every turn)");
   }
 };
 
-void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
-                   const std::vector<int>& prompt, size_t start_pos,
-                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                   const StreamFunc& stream_token,
-                   const AcceptFunc& accept_token, std::mt19937& g,
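+// Generates up to args.max_generated_tokens continuation tokens for `prompt`,
+// calling `stream_token` on each token produced and `accept_token` to filter
+// candidate tokens. `start_pos` is the KV-cache position to resume decoding
+// from (nonzero for multiturn sessions); `g` drives top-K sampling.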
+void GenerateGemma(Gemma &gemma, const InferenceArgs &args,
+                   const std::vector<int> &prompt, size_t start_pos,
+                   hwy::ThreadPool &pool, hwy::ThreadPool &inner_pool,
+                   const StreamFunc &stream_token,
+                   const AcceptFunc &accept_token, std::mt19937 &g,
                    int verbosity);
 
 constexpr int EOS_ID = 1;
 
-}  // namespace gcpp
+} // namespace gcpp
 
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
+#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_