Introduce attention implementation configurability.

PiperOrigin-RevId: 828971705
commit 35e9f9f05f (parent 091b4567c9)
Author: Martin Stolle, 2025-11-06 08:43:03 -08:00
Committed by: Copybara-Service
6 changed files with 52 additions and 11 deletions
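
The diff below introduces an AttentionImpl enum, a RuntimeConfig::attention_impl field (defaulting to kFlash), and threads the choice through Activations to the GemmaAttention call. A minimal usage sketch, assuming config, kv_cache, and env are set up as at the call sites shown in the diff:

    // Sketch only: RuntimeConfig, AttentionImpl, and the new Activations
    // constructor come from this commit; config, kv_cache, and env are assumed
    // to exist as in gemma.cc and the flash attention test below.
    RuntimeConfig runtime_config;
    runtime_config.attention_impl = AttentionImpl::kOld;  // default: kFlash
    Activations activations(runtime_config, config,
                            runtime_config.prefill_tbatch_size,
                            kv_cache.SeqLen(), env.ctx, env.row_ptrs);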

View File

@@ -524,6 +524,7 @@ cc_library(
     deps = [
         ":args",
         ":basics",
+        ":configs",
         ":mat",
         "//io",
        "@highway//:hwy",

View File

@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "gemma/configs.h"     // ModelConfig
+#include "gemma/gemma_args.h"  // AttentionImpl
 #include "ops/ops.h"           // CreateInvTimescale
 #include "util/basics.h"       // BF16
 #include "util/mat.h"          // MatStorageT
@@ -179,8 +180,8 @@ struct AttentionActivationsPtrs {
 };
 
 struct Activations {
-  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
-              ThreadingContext& ctx,
+  Activations(const RuntimeConfig& runtime_config, const ModelConfig& config,
+              size_t batch_size, size_t seq_len, ThreadingContext& ctx,
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : layer_config(config.layer_configs[0]),
@@ -199,6 +200,7 @@ struct Activations {
         ffw_out(
             MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
+        attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
                           ctx.allocator, row_ptrs),
         attention(config, seq_len, attention_storage) {
@@ -248,6 +250,8 @@ struct Activations {
   MatStorageT<BF16> C2;
   MatStorageT<float> ffw_out;
+
+  AttentionImpl attention_impl;
   AttentionActivations attention_storage;
   AttentionActivationsPtrs attention;
 };

View File

@@ -80,6 +80,34 @@ static inline bool EnumValid(LayerAttentionType type) {
   return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
 }
 
+enum class AttentionImpl {
+  kOld,
+  kFlash,
+};
+
+/*
+ * Returns a bitmask of flags to pass to attention functions based on the
+ * attention implementation selected.
+ *
+ * If `hwy_native_dot_bf16` is true, the function will use the old attention
+ * implementation, ignoring `impl`.
+ *
+ * `hwy_native_dot_bf16` needs to be passed in, because the HWY_NATIVE_DOT_BF16
+ * macro is not available outside of highway instrumented translation units and
+ * cannot be made accessible from .h files.
+ */
+static inline int AttentionImplToFlags(AttentionImpl impl,
+                                       int hwy_native_dot_bf16) {
+  if (hwy_native_dot_bf16) return kAttentionUseOld;
+  switch (impl) {
+    case AttentionImpl::kOld:
+      return kAttentionUseOld;
+    case AttentionImpl::kFlash:
+      return 0;
+  }
+}
+
 // Post attention and ffw normalization type.
 enum class PostNormType {
   None,
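
A short sketch of how the mapping above resolves (this restates the switch; the actual call site in gemma.cc passes HWY_NATIVE_DOT_BF16 from a Highway-instrumented translation unit):

    //   hwy_native_dot_bf16 != 0  -> kAttentionUseOld, regardless of impl
    //   AttentionImpl::kOld       -> kAttentionUseOld
    //   AttentionImpl::kFlash     -> 0 (FlashAttention path)
    const int flags = AttentionImplToFlags(AttentionImpl::kFlash,
                                           /*hwy_native_dot_bf16=*/0);  // == 0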

View File

@@ -112,8 +112,9 @@ void TestFlashAttention(size_t target_parallelism) {
   RuntimeConfig runtime_config;
   KVCache kv_cache(config, inference_args, ctx.allocator);
   MatMulEnv env(ctx);
-  Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
+  Activations activations(runtime_config, config,
+                          runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
+                          env.ctx, env.row_ptrs);
   std::vector<int> tokens(kOuter);
   std::iota(tokens.begin(), tokens.end(), 1);
   PromptTokens prompt(tokens);

View File

@@ -73,10 +73,12 @@ namespace HWY_NAMESPACE {
 void Attention(LayerAttentionType type, const size_t num_tokens,
                const size_t layer_idx, const LayerWeightsPtrs& layer,
                Activations& activations, QBatch& qbatch, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
     // TODO: remove flag to enable FlashAttention.
-    GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
-                   env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0);
+    GemmaAttention(
+        num_tokens, layer_idx, layer, activations.attention, qbatch, env,
+        AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16));
   }
 }
@@ -573,8 +575,9 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const AesCtrEngine& engine, const WeightsPtrs& weights,
                      KVCache& kv_cache, MatMulEnv& env,
                      TimingInfo& timing_info) {
-  Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
+  Activations activations(runtime_config, config,
+                          runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
+                          env.ctx, env.row_ptrs);
 
   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span<KVCache>(&kv_cache, 1));
@@ -592,7 +595,7 @@ void GenerateBatchT(const ModelConfig& config,
                     TimingInfo& timing_info) {
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
-  Activations activations(config, max_batch_size,
+  Activations activations(runtime_config, config, max_batch_size,
                           all_queries[0].kv_cache.SeqLen(), env.ctx,
                           env.row_ptrs);
@@ -617,8 +620,8 @@ void GenerateImageTokensT(const ModelConfig& config,
   const size_t num_tokens = vit_config.max_seq_len;
   prefill_runtime_config.prefill_tbatch_size =
       num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
-  Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx,
-                                  env.row_ptrs);
+  Activations prefill_activations(runtime_config, vit_config, num_tokens,
+                                  num_tokens, env.ctx, env.row_ptrs);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
              prefill_activations, env);

View File

@@ -24,6 +24,7 @@
 #include <functional>
 #include <string>
 
+#include "gemma/configs.h"
 #include "io/io.h"  // Path
 #include "util/args.h"
 #include "util/basics.h"  // Tristate
@@ -139,6 +140,9 @@ struct RuntimeConfig {
   int verbosity;  // Controls verbosity of printed messages.
 
+  // Which attention implementation to use.
+  AttentionImpl attention_impl = AttentionImpl::kFlash;
+
   // Functions operating on the generated tokens.
   StreamFunc stream_token;
   BatchStreamFunc batch_stream_token;