mirror of https://github.com/google/gemma.cpp.git
parent
129e66ada2
commit
8db89304bd
|
|
@ -70,21 +70,3 @@ The implementation code is roughly split into 4 layers, from high to low level:
|
||||||
|
|
||||||
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
|
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
|
||||||
highway) supporting the implementations in (3).
|
highway) supporting the implementations in (3).
|
||||||
|
|
||||||
## Compile-Time Flags (Advanced)
|
|
||||||
|
|
||||||
There are several compile-time flags to be aware of (note these may or may not
|
|
||||||
be exposed to the build system):
|
|
||||||
|
|
||||||
- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
|
|
||||||
WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
|
|
||||||
(default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to
|
|
||||||
enable for higher-fidelity (but slower) bfloat16 support. This is defined in
|
|
||||||
`gemma.h`.
|
|
||||||
- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
|
|
||||||
Cache. The default is 4096 tokens but can be overridden. This is not exposed
|
|
||||||
through `CMakeLists.txt` yet.
|
|
||||||
|
|
||||||
In the medium term both of these will likely be deprecated in favor of handling
|
|
||||||
options at runtime - allowing for multiple weight compression schemes in a single
|
|
||||||
build and dynamically resizes the KV cache as needed.
|
|
||||||
|
|
|
||||||
|
|
@ -114,12 +114,8 @@ convenient directory location (e.g. the `build/` directory in this repo).
|
||||||
|
|
||||||
The build system uses [CMake](https://cmake.org/). To build the gemma inference
|
The build system uses [CMake](https://cmake.org/). To build the gemma inference
|
||||||
runtime, create a build directory and generate the build files using `cmake`
|
runtime, create a build directory and generate the build files using `cmake`
|
||||||
from the top-level project directory. Note if you previous ran `cmake` and are
|
from the top-level project directory. For the 8-bit switched floating point
|
||||||
re-running with a different setting, be sure to clean out the `build/` directory
|
weights (sfp), run cmake with no options:
|
||||||
with `rm -rf build/*` (warning this will delete any other files in the `build/`
|
|
||||||
directory.
|
|
||||||
|
|
||||||
For the 8-bit switched floating point weights (sfp), run cmake with no options:
|
|
||||||
|
|
||||||
#### Unix-like Platforms
|
#### Unix-like Platforms
|
||||||
```sh
|
```sh
|
||||||
|
|
|
||||||
21
configs.h
21
configs.h
|
|
@ -18,26 +18,21 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
||||||
|
|
||||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
|
||||||
#ifndef GEMMA_MAX_SEQLEN
|
|
||||||
#define GEMMA_MAX_SEQLEN 4096
|
|
||||||
#endif // !GEMMA_MAX_SEQLEN
|
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
static constexpr size_t kSeqLen = 7168;
|
||||||
|
|
||||||
struct ConfigGemma7B {
|
struct ConfigGemma7B {
|
||||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||||
static constexpr int kVocabSize = 256128;
|
static constexpr int kVocabSize = 256128;
|
||||||
static constexpr int kLayers = 28;
|
static constexpr int kLayers = 28;
|
||||||
static constexpr int kModelDim = 3072;
|
static constexpr int kModelDim = 3072;
|
||||||
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
|
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
|
||||||
static constexpr int kHeads = 16;
|
static constexpr int kHeads = 16;
|
||||||
static constexpr int kKVHeads = 16; // standard MHA
|
static constexpr int kKVHeads = 16; // standard MHA, no GQA or MQA
|
||||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
static constexpr int kTopK = 1;
|
static constexpr int kTopK = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -46,13 +41,13 @@ struct ConfigGemma2B {
|
||||||
static constexpr int kVocabSize = 256128;
|
static constexpr int kVocabSize = 256128;
|
||||||
static constexpr int kLayers = 18;
|
static constexpr int kLayers = 18;
|
||||||
static constexpr int kModelDim = 2048;
|
static constexpr int kModelDim = 2048;
|
||||||
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
||||||
static constexpr int kHeads = 8;
|
static constexpr int kHeads = 8;
|
||||||
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
|
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
|
||||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
static constexpr int kTopK = 1;
|
static constexpr int kTopK = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
||||||
|
|
|
||||||
59
gemma.h
59
gemma.h
|
|
@ -25,14 +25,14 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "compression/compress.h" // SfpStream/NuqStream
|
#include "compression/compress.h" // SfpStream/NuqStream
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "configs.h" // kSeqLen
|
#include "configs.h" // kSeqLen
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "util/args.h" // ArgsBase
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "util/args.h" // ArgsBase
|
|
||||||
// copybara:import_next_line:sentencepiece
|
// copybara:import_next_line:sentencepiece
|
||||||
#include "src/sentencepiece_processor.h"
|
#include "src/sentencepiece_processor.h"
|
||||||
|
|
||||||
|
|
@ -42,7 +42,7 @@ namespace gcpp {
|
||||||
// float, hwy::bfloat16_t, SfpStream, NuqStream
|
// float, hwy::bfloat16_t, SfpStream, NuqStream
|
||||||
#ifndef GEMMA_WEIGHT_T
|
#ifndef GEMMA_WEIGHT_T
|
||||||
#define GEMMA_WEIGHT_T SfpStream
|
#define GEMMA_WEIGHT_T SfpStream
|
||||||
#endif // !GEMMA_WEIGHT_T
|
#endif // !GEMMA_WEIGHT_T
|
||||||
using WeightT = GEMMA_WEIGHT_T;
|
using WeightT = GEMMA_WEIGHT_T;
|
||||||
|
|
||||||
using EmbedderInputT = hwy::bfloat16_t;
|
using EmbedderInputT = hwy::bfloat16_t;
|
||||||
|
|
@ -51,9 +51,9 @@ constexpr bool kSystemPrompt = false;
|
||||||
|
|
||||||
struct KVCache {
|
struct KVCache {
|
||||||
hwy::AlignedFreeUniquePtr<float[]>
|
hwy::AlignedFreeUniquePtr<float[]>
|
||||||
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
|
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
|
||||||
hwy::AlignedFreeUniquePtr<float[]>
|
hwy::AlignedFreeUniquePtr<float[]>
|
||||||
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
|
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
|
||||||
};
|
};
|
||||||
|
|
||||||
// Model variants: see configs.h for details.
|
// Model variants: see configs.h for details.
|
||||||
|
|
@ -61,9 +61,9 @@ enum class Model { GEMMA_2B, GEMMA_7B };
|
||||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||||
|
|
||||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
LoaderArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
|
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||||
|
|
||||||
static std::string ToLower(const std::string &text) {
|
static std::string ToLower(const std::string& text) {
|
||||||
std::string result = text;
|
std::string result = text;
|
||||||
std::transform(begin(result), end(result), begin(result),
|
std::transform(begin(result), end(result), begin(result),
|
||||||
[](unsigned char c) { return std::tolower(c); });
|
[](unsigned char c) { return std::tolower(c); });
|
||||||
|
|
@ -89,7 +89,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
// Returns error string or nullptr if OK.
|
||||||
const char *Validate() const {
|
const char* Validate() const {
|
||||||
const std::string model_type_lc = ToLower(model_type);
|
const std::string model_type_lc = ToLower(model_type);
|
||||||
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
|
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
|
||||||
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
|
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
|
||||||
|
|
@ -111,11 +111,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
}
|
}
|
||||||
|
|
||||||
Path tokenizer;
|
Path tokenizer;
|
||||||
Path model; // uncompressed weights OR
|
Path model; // uncompressed weights OR
|
||||||
Path cache; // compressed weights
|
Path cache; // compressed weights
|
||||||
std::string model_type;
|
std::string model_type;
|
||||||
|
|
||||||
template <class Visitor> void ForEach(const Visitor &visitor) {
|
template <class Visitor>
|
||||||
|
void ForEach(const Visitor& visitor) {
|
||||||
visitor(tokenizer, "tokenizer", Path(),
|
visitor(tokenizer, "tokenizer", Path(),
|
||||||
"Path name of tokenizer model file. (required)");
|
"Path name of tokenizer model file. (required)");
|
||||||
visitor(
|
visitor(
|
||||||
|
|
@ -138,10 +139,10 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
struct GemmaInterface;
|
struct GemmaInterface;
|
||||||
|
|
||||||
struct Gemma {
|
struct Gemma {
|
||||||
Gemma(const LoaderArgs &args, hwy::ThreadPool &pool);
|
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor &Tokenizer() const;
|
const sentencepiece::SentencePieceProcessor& Tokenizer() const;
|
||||||
|
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
gcpp::ModelTraining model_training;
|
gcpp::ModelTraining model_training;
|
||||||
|
|
@ -153,7 +154,7 @@ using StreamFunc = std::function<bool(int, float)>;
|
||||||
using AcceptFunc = std::function<bool(int)>;
|
using AcceptFunc = std::function<bool(int)>;
|
||||||
|
|
||||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
InferenceArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
|
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||||
|
|
||||||
size_t max_tokens;
|
size_t max_tokens;
|
||||||
size_t max_generated_tokens;
|
size_t max_generated_tokens;
|
||||||
|
|
@ -163,7 +164,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
bool multiturn;
|
bool multiturn;
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
// Returns error string or nullptr if OK.
|
||||||
const char *Validate() const {
|
const char* Validate() const {
|
||||||
if (max_tokens > gcpp::kSeqLen) {
|
if (max_tokens > gcpp::kSeqLen) {
|
||||||
return "max_tokens is larger than the maximum sequence length (see "
|
return "max_tokens is larger than the maximum sequence length (see "
|
||||||
"configs.h).";
|
"configs.h).";
|
||||||
|
|
@ -175,7 +176,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Visitor> void ForEach(const Visitor &visitor) {
|
template <class Visitor>
|
||||||
|
void ForEach(const Visitor& visitor) {
|
||||||
visitor(max_tokens, "max_tokens", size_t{3072},
|
visitor(max_tokens, "max_tokens", size_t{3072},
|
||||||
"Maximum number of tokens in prompt + generation.");
|
"Maximum number of tokens in prompt + generation.");
|
||||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||||
|
|
@ -184,21 +186,22 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
|
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
|
||||||
visitor(deterministic, "deterministic", false,
|
visitor(deterministic, "deterministic", false,
|
||||||
"Make top-k sampling deterministic", 2);
|
"Make top-k sampling deterministic", 2);
|
||||||
visitor(multiturn, "multiturn", false,
|
visitor(multiturn, "multiturn", true,
|
||||||
"Multiturn mode (if 0, this clears the KV cache after every "
|
"Multiturn mode (if 0, this clears the KV cache after every "
|
||||||
"interaction without quitting)\n Default = 0 (conversation resets every turn)");
|
"interaction without quitting)",
|
||||||
|
2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void GenerateGemma(Gemma &gemma, const InferenceArgs &args,
|
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
|
||||||
const std::vector<int> &prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
hwy::ThreadPool &pool, hwy::ThreadPool &inner_pool,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
const StreamFunc &stream_token,
|
const StreamFunc& stream_token,
|
||||||
const AcceptFunc &accept_token, std::mt19937 &g,
|
const AcceptFunc& accept_token, std::mt19937& g,
|
||||||
int verbosity);
|
int verbosity);
|
||||||
|
|
||||||
constexpr int EOS_ID = 1;
|
constexpr int EOS_ID = 1;
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue