mirror of https://github.com/google/gemma.cpp.git
Introduce attention implementation configurability.
PiperOrigin-RevId: 828971705
commit 35e9f9f05f
parent 091b4567c9

@@ -524,6 +524,7 @@ cc_library(
     deps = [
+        ":args",
         ":basics",
         ":configs",
         ":mat",
         "//io",
         "@highway//:hwy",

@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "gemma/configs.h"  // ModelConfig
+#include "gemma/gemma_args.h"  // AttentionImpl
 #include "ops/ops.h"  // CreateInvTimescale
 #include "util/basics.h"  // BF16
 #include "util/mat.h"  // MatStorageT

@@ -179,8 +180,8 @@ struct AttentionActivationsPtrs {
 };
 
 struct Activations {
-  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
-              ThreadingContext& ctx,
+  Activations(const RuntimeConfig& runtime_config, const ModelConfig& config,
+              size_t batch_size, size_t seq_len, ThreadingContext& ctx,
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : layer_config(config.layer_configs[0]),

@@ -199,6 +200,7 @@ struct Activations {
         ffw_out(
             MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
 
+        attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
                           ctx.allocator, row_ptrs),
         attention(config, seq_len, attention_storage) {

@@ -248,6 +250,8 @@ struct Activations {
   MatStorageT<BF16> C2;
   MatStorageT<float> ffw_out;
 
+  AttentionImpl attention_impl;
+
   AttentionActivations attention_storage;
   AttentionActivationsPtrs attention;
 };

@@ -80,6 +80,34 @@ static inline bool EnumValid(LayerAttentionType type) {
   return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
 }
 
+enum class AttentionImpl {
+  kOld,
+  kFlash,
+};
+
+/*
+ * Returns a bitmask of flags to pass to attention functions based on the
+ * attention implementation selected.
+ *
+ * If `hwy_native_dot_bf16` is true, the function will use the old attention
+ * implementation, ignoring `impl`.
+ *
+ * `hwy_native_dot_bf16` needs to be passed in because the HWY_NATIVE_DOT_BF16
+ * macro is not available outside of Highway-instrumented translation units and
+ * cannot be made accessible from .h files.
+ */
+static inline int AttentionImplToFlags(AttentionImpl impl,
+                                       int hwy_native_dot_bf16) {
+  if (hwy_native_dot_bf16) return kAttentionUseOld;
+
+  switch (impl) {
+    case AttentionImpl::kOld:
+      return kAttentionUseOld;
+    case AttentionImpl::kFlash:
+      return 0;
+  }
+}
+
 // Post attention and ffw normalization type.
 enum class PostNormType {
   None,
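
For reference, the mapping above can be summarized with a few example calls. This is a sketch only, not part of the diff; it assumes the header introduced above is included so that AttentionImpl, AttentionImplToFlags and the pre-existing kAttentionUseOld flag are in scope.

  // Sketch only (not part of the commit): expected flag values for each combination.
  const int use_flash = AttentionImplToFlags(AttentionImpl::kFlash, /*hwy_native_dot_bf16=*/0);  // 0: FlashAttention path
  const int use_old = AttentionImplToFlags(AttentionImpl::kOld, /*hwy_native_dot_bf16=*/0);      // kAttentionUseOld
  const int bf16_override = AttentionImplToFlags(AttentionImpl::kFlash, /*hwy_native_dot_bf16=*/1);  // kAttentionUseOld; `impl` is ignored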

@@ -112,8 +112,9 @@ void TestFlashAttention(size_t target_parallelism) {
   RuntimeConfig runtime_config;
   KVCache kv_cache(config, inference_args, ctx.allocator);
   MatMulEnv env(ctx);
-  Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
+  Activations activations(runtime_config, config,
+                          runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
+                          env.ctx, env.row_ptrs);
   std::vector<int> tokens(kOuter);
   std::iota(tokens.begin(), tokens.end(), 1);
   PromptTokens prompt(tokens);

@@ -73,10 +73,12 @@ namespace HWY_NAMESPACE {
 void Attention(LayerAttentionType type, const size_t num_tokens,
                const size_t layer_idx, const LayerWeightsPtrs& layer,
                Activations& activations, QBatch& qbatch, MatMulEnv& env) {
 
   if (type == LayerAttentionType::kGemma) {
     // TODO: remove flag to enable FlashAttention.
-    GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
-                   env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0);
+    GemmaAttention(
+        num_tokens, layer_idx, layer, activations.attention, qbatch, env,
+        AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16));
   }
 }
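
For reference, combined with AttentionImplToFlags the dispatch above reduces to the hand-expanded selection below. This is a sketch only, not code from the commit; it reuses the names already visible in this hunk.

    // Sketch only: equivalent selection logic, expanded for clarity.
    int flags = 0;
    if (HWY_NATIVE_DOT_BF16 || activations.attention_impl == AttentionImpl::kOld) {
      flags = kAttentionUseOld;  // pre-FlashAttention code path
    }
    GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
                   env, flags);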

@@ -573,8 +575,9 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const AesCtrEngine& engine, const WeightsPtrs& weights,
                      KVCache& kv_cache, MatMulEnv& env,
                      TimingInfo& timing_info) {
-  Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
+  Activations activations(runtime_config, config,
+                          runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
+                          env.ctx, env.row_ptrs);
 
   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span<KVCache>(&kv_cache, 1));

@@ -592,7 +595,7 @@ void GenerateBatchT(const ModelConfig& config,
                     TimingInfo& timing_info) {
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
-  Activations activations(config, max_batch_size,
+  Activations activations(runtime_config, config, max_batch_size,
                           all_queries[0].kv_cache.SeqLen(), env.ctx,
                           env.row_ptrs);

@@ -617,8 +620,8 @@ void GenerateImageTokensT(const ModelConfig& config,
   const size_t num_tokens = vit_config.max_seq_len;
   prefill_runtime_config.prefill_tbatch_size =
       num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
-  Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx,
-                                  env.row_ptrs);
+  Activations prefill_activations(runtime_config, vit_config, num_tokens,
+                                  num_tokens, env.ctx, env.row_ptrs);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
              prefill_activations, env);

@@ -24,6 +24,7 @@
 #include <functional>
 #include <string>
 
+#include "gemma/configs.h"
 #include "io/io.h"  // Path
 #include "util/args.h"
 #include "util/basics.h"  // Tristate

@@ -139,6 +140,9 @@ struct RuntimeConfig {
 
   int verbosity;  // Controls verbosity of printed messages.
 
+  // Which attention implementation to use.
+  AttentionImpl attention_impl = AttentionImpl::kFlash;
+
   // Functions operating on the generated tokens.
   StreamFunc stream_token;
   BatchStreamFunc batch_stream_token;
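
For reference, a minimal caller-side sketch of the new knob, assembled from the hunks above. `config`, `kv_cache` and `env` stand in for objects already constructed at the existing call sites (as in the test hunk) and are assumptions here, not new API.

  // Sketch only: opt into the old attention kernel; the default remains kFlash.
  RuntimeConfig runtime_config;
  runtime_config.attention_impl = AttentionImpl::kOld;
  Activations activations(runtime_config, config, runtime_config.prefill_tbatch_size,
                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);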