Import from GitHub.

PiperOrigin-RevId: 610595796
This commit is contained in:
Dan Zheng 2024-02-26 19:04:33 -08:00 committed by Copybara-Service
parent 8db89304bd
commit afc354dcb1
4 changed files with 61 additions and 37 deletions

View File

@ -70,3 +70,21 @@ The implementation code is roughly split into 4 layers, from high to low level:
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of 4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
highway) supporting the implementations in (3). highway) supporting the implementations in (3).
## Compile-Time Flags (Advanced)
There are several compile-time flags to be aware of (note these may or may not
be exposed to the build system):
- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
(default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to
enable higher-fidelity (but slower) bfloat16 support. This is defined in
`gemma.h`.
- `GEMMA_MAX_SEQLEN` : Sets maximum sequence length to preallocate for the KV
  Cache. The default is 4096 tokens but can be overridden. This is not exposed
  through `CMakeLists.txt` yet.
In the medium term, both of these will likely be deprecated in favor of handling
options at runtime - allowing for multiple weight compression schemes in a single
build and dynamically resizing the KV cache as needed.

View File

@ -114,8 +114,12 @@ convenient directory location (e.g. the `build/` directory in this repo).
The build system uses [CMake](https://cmake.org/). To build the gemma inference The build system uses [CMake](https://cmake.org/). To build the gemma inference
runtime, create a build directory and generate the build files using `cmake` runtime, create a build directory and generate the build files using `cmake`
from the top-level project directory. For the 8-bit switched floating point from the top-level project directory. Note if you previously ran `cmake` and are
weights (sfp), run cmake with no options: re-running with a different setting, be sure to clean out the `build/` directory
with `rm -rf build/*` (warning: this will delete any other files in the `build/`
directory).
For the 8-bit switched floating point weights (sfp), run cmake with no options:
#### Unix-like Platforms #### Unix-like Platforms
```sh ```sh

View File

@ -18,21 +18,26 @@
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ #ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ #define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
// Allow changing pre-allocated kv cache size as a compiler flag
#ifndef GEMMA_MAX_SEQLEN
#define GEMMA_MAX_SEQLEN 4096
#endif // !GEMMA_MAX_SEQLEN
#include <stddef.h> #include <stddef.h>
namespace gcpp { namespace gcpp {
static constexpr size_t kSeqLen = 7168; static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
struct ConfigGemma7B { struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256128; static constexpr int kVocabSize = 256128;
static constexpr int kLayers = 28; static constexpr int kLayers = 28;
static constexpr int kModelDim = 3072; static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16; static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA, no GQA or MQA static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = 1; static constexpr int kTopK = 1;
}; };
@ -41,13 +46,13 @@ struct ConfigGemma2B {
static constexpr int kVocabSize = 256128; static constexpr int kVocabSize = 256128;
static constexpr int kLayers = 18; static constexpr int kLayers = 18;
static constexpr int kModelDim = 2048; static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8; static constexpr int kHeads = 8;
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = 1; static constexpr int kTopK = 1;
}; };
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ #endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_

51
gemma.h
View File

@ -42,7 +42,7 @@ namespace gcpp {
// float, hwy::bfloat16_t, SfpStream, NuqStream // float, hwy::bfloat16_t, SfpStream, NuqStream
#ifndef GEMMA_WEIGHT_T #ifndef GEMMA_WEIGHT_T
#define GEMMA_WEIGHT_T SfpStream #define GEMMA_WEIGHT_T SfpStream
#endif // !GEMMA_WEIGHT_T #endif // !GEMMA_WEIGHT_T
using WeightT = GEMMA_WEIGHT_T; using WeightT = GEMMA_WEIGHT_T;
using EmbedderInputT = hwy::bfloat16_t; using EmbedderInputT = hwy::bfloat16_t;
@ -51,9 +51,9 @@ constexpr bool kSystemPrompt = false;
struct KVCache { struct KVCache {
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
}; };
// Model variants: see configs.h for details. // Model variants: see configs.h for details.
@ -61,9 +61,9 @@ enum class Model { GEMMA_2B, GEMMA_7B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT }; enum class ModelTraining { GEMMA_IT, GEMMA_PT };
struct LoaderArgs : public ArgsBase<LoaderArgs> { struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } LoaderArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
static std::string ToLower(const std::string& text) { static std::string ToLower(const std::string &text) {
std::string result = text; std::string result = text;
std::transform(begin(result), end(result), begin(result), std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); }); [](unsigned char c) { return std::tolower(c); });
@ -89,7 +89,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
} }
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
const char* Validate() const { const char *Validate() const {
const std::string model_type_lc = ToLower(model_type); const std::string model_type_lc = ToLower(model_type);
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") { model_type_lc != "2b-it" && model_type_lc != "7b-it") {
@ -111,12 +111,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
} }
Path tokenizer; Path tokenizer;
Path model; // uncompressed weights OR Path model; // uncompressed weights OR
Path cache; // compressed weights Path cache; // compressed weights
std::string model_type; std::string model_type;
template <class Visitor> template <class Visitor> void ForEach(const Visitor &visitor) {
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(), visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file. (required)"); "Path name of tokenizer model file. (required)");
visitor( visitor(
@ -139,10 +138,10 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
struct GemmaInterface; struct GemmaInterface;
struct Gemma { struct Gemma {
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); Gemma(const LoaderArgs &args, hwy::ThreadPool &pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined. ~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor& Tokenizer() const; const sentencepiece::SentencePieceProcessor &Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_; std::unique_ptr<GemmaInterface> impl_;
gcpp::ModelTraining model_training; gcpp::ModelTraining model_training;
@ -154,7 +153,7 @@ using StreamFunc = std::function<bool(int, float)>;
using AcceptFunc = std::function<bool(int)>; using AcceptFunc = std::function<bool(int)>;
struct InferenceArgs : public ArgsBase<InferenceArgs> { struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } InferenceArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
size_t max_tokens; size_t max_tokens;
size_t max_generated_tokens; size_t max_generated_tokens;
@ -164,7 +163,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
bool multiturn; bool multiturn;
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
const char* Validate() const { const char *Validate() const {
if (max_tokens > gcpp::kSeqLen) { if (max_tokens > gcpp::kSeqLen) {
return "max_tokens is larger than the maximum sequence length (see " return "max_tokens is larger than the maximum sequence length (see "
"configs.h)."; "configs.h).";
@ -176,8 +175,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
return nullptr; return nullptr;
} }
template <class Visitor> template <class Visitor> void ForEach(const Visitor &visitor) {
void ForEach(const Visitor& visitor) {
visitor(max_tokens, "max_tokens", size_t{3072}, visitor(max_tokens, "max_tokens", size_t{3072},
"Maximum number of tokens in prompt + generation."); "Maximum number of tokens in prompt + generation.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
@ -186,22 +184,21 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
visitor(deterministic, "deterministic", false, visitor(deterministic, "deterministic", false,
"Make top-k sampling deterministic", 2); "Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", true, visitor(multiturn, "multiturn", false,
"Multiturn mode (if 0, this clears the KV cache after every " "Multiturn mode (if 0, this clears the KV cache after every "
"interaction without quitting)", "interaction without quitting)\n Default = 0 (conversation resets every turn)");
2);
} }
}; };
void GenerateGemma(Gemma& gemma, const InferenceArgs& args, void GenerateGemma(Gemma &gemma, const InferenceArgs &args,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int> &prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, hwy::ThreadPool &pool, hwy::ThreadPool &inner_pool,
const StreamFunc& stream_token, const StreamFunc &stream_token,
const AcceptFunc& accept_token, std::mt19937& g, const AcceptFunc &accept_token, std::mt19937 &g,
int verbosity); int verbosity);
constexpr int EOS_ID = 1; constexpr int EOS_ID = 1;
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_