mirror of https://github.com/google/gemma.cpp.git
Reduce KV cache preallocation to 4096 and make it compile-time configurable; add an `rm -rf build/*` note to the README; add a note on compile-time options to DEVELOPERS; make multiturn=0 the default
This commit is contained in:
parent 7aeade5c9d
commit 129e66ada2
DEVELOPERS.md
@@ -70,3 +70,21 @@ The implementation code is roughly split into 4 layers, from high to low level:
 4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
    highway) supporting the implementations in (3).
 
+## Compile-Time Flags (Advanced)
+
+There are several compile-time flags to be aware of (note these may or may not
+be exposed to the build system):
+
+- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
+  WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
+  (the default if no flag is specified) for 8-bit SFP, or to `hwy::bfloat16_t`
+  to enable higher-fidelity (but slower) bfloat16 support. This is defined in
+  `gemma.h`.
+- `GEMMA_MAX_SEQLEN` : Sets the maximum sequence length to preallocate for the
+  KV cache. The default is 4096 tokens but can be overridden. This is not
+  exposed through `CMakeLists.txt` yet.
+
+In the medium term both of these will likely be deprecated in favor of handling
+options at runtime - allowing for multiple weight compression schemes in a
+single build and dynamically resizing the KV cache as needed.
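As a rough sketch of how these options might be set at configure time (not part of the commit: the `cmake -B build` form, the `WEIGHT_TYPE` cache variable, and routing `GEMMA_MAX_SEQLEN` through `CMAKE_CXX_FLAGS` are assumptions based on the notes above):

```sh
# Sketch: configure with bfloat16 weights and a smaller preallocated KV cache.
# WEIGHT_TYPE is surfaced by CMakeLists.txt (see above); GEMMA_MAX_SEQLEN is not
# exposed there yet, so it is passed as a raw preprocessor define.
cmake -B build \
  -DWEIGHT_TYPE=hwy::bfloat16_t \
  -DCMAKE_CXX_FLAGS="-DGEMMA_MAX_SEQLEN=2048"
cmake --build build --parallel
```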
README.md
@@ -114,8 +114,12 @@ convenient directory location (e.g. the `build/` directory in this repo).
 The build system uses [CMake](https://cmake.org/). To build the gemma inference
 runtime, create a build directory and generate the build files using `cmake`
-from the top-level project directory. For the 8-bit switched floating point
-weights (sfp), run cmake with no options:
+from the top-level project directory. Note that if you previously ran `cmake` and
+are re-running with a different setting, be sure to clean out the `build/`
+directory with `rm -rf build/*` (warning: this will delete any other files in the
+`build/` directory).
+
+For the 8-bit switched floating point weights (sfp), run cmake with no options:
 
 #### Unix-like Platforms
 ```sh
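Putting the new note together with the existing build steps, a minimal sketch of the clean re-configure flow could look like the following (the `cmake -B build` invocation and `WEIGHT_TYPE` option come from the surrounding docs; treat the exact commands as illustrative):

```sh
# Switching settings (e.g. from sfp to bfloat16 weights): clear the old
# CMake cache and generated files first, then re-run cmake.
rm -rf build/*                                        # deletes everything in build/
cmake -B build                                        # sfp weights (default, no options)
# or: cmake -B build -DWEIGHT_TYPE=hwy::bfloat16_t    # bfloat16 weights
cmake --build build --parallel
```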
configs.h
@@ -18,11 +18,16 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
 #define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
 
+// Allow changing pre-allocated kv cache size as a compiler flag
+#ifndef GEMMA_MAX_SEQLEN
+#define GEMMA_MAX_SEQLEN 4096
+#endif  // !GEMMA_MAX_SEQLEN
+
 #include <stddef.h>
 
 namespace gcpp {
 
-static constexpr size_t kSeqLen = 7168;
+static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 
 struct ConfigGemma7B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -31,7 +36,7 @@ struct ConfigGemma7B {
   static constexpr int kModelDim = 3072;
   static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
   static constexpr int kHeads = 16;
-  static constexpr int kKVHeads = 16;  // standard MHA, no GQA or MQA
+  static constexpr int kKVHeads = 16;  // standard MHA
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = 1;
 };
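Not part of the commit, but a quick way to see the new override in action: preprocess `configs.h` from the top-level directory and check which value `kSeqLen` resolves to. This assumes `configs.h` builds standalone, which the hunk above suggests since it only includes `<stddef.h>`.

```sh
# Hypothetical sanity check, run from the top-level gemma.cpp directory.
g++ -std=c++17 -E -P configs.h | grep "size_t kSeqLen"
#   expected: static constexpr size_t kSeqLen = 4096;
g++ -std=c++17 -E -P -DGEMMA_MAX_SEQLEN=2048 configs.h | grep "size_t kSeqLen"
#   expected: static constexpr size_t kSeqLen = 2048;
```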
gemma.h (13 changed lines)
@@ -29,10 +29,10 @@
 // copybara:import_next_line:gemma_cpp
 #include "configs.h"  // kSeqLen
 // copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // ArgsBase
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"  // hwy::bfloat16_t
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "util/args.h"  // ArgsBase
 // copybara:import_next_line:sentencepiece
 #include "src/sentencepiece_processor.h"
 
@@ -115,8 +115,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   Path cache;  // compressed weights
   std::string model_type;
 
-  template <class Visitor>
-  void ForEach(const Visitor& visitor) {
+  template <class Visitor> void ForEach(const Visitor &visitor) {
     visitor(tokenizer, "tokenizer", Path(),
             "Path name of tokenizer model file. (required)");
     visitor(
@@ -176,8 +175,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     return nullptr;
   }
 
-  template <class Visitor>
-  void ForEach(const Visitor& visitor) {
+  template <class Visitor> void ForEach(const Visitor &visitor) {
     visitor(max_tokens, "max_tokens", size_t{3072},
             "Maximum number of tokens in prompt + generation.");
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
@@ -186,10 +184,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
-    visitor(multiturn, "multiturn", true,
+    visitor(multiturn, "multiturn", false,
             "Multiturn mode (if 0, this clears the KV cache after every "
-            "interaction without quitting)",
-            2);
+            "interaction without quitting)\n    Default = 0 (conversation resets every turn)");
   }
 };
 
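With multiturn now defaulting to 0, the KV cache (and hence the conversation) resets after every turn; the previous always-on behavior can still be requested at runtime. A hypothetical invocation sketch (only `--multiturn` and `--tokenizer` are visible in the diff above; the other flag names and all paths are assumptions that may differ from the actual CLI):

```sh
# Keep conversation state across turns by re-enabling multiturn explicitly.
./gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs \
    --model 2b-it --multiturn 1
```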