diff --git a/BUILD.bazel b/BUILD.bazel
index aae230e..2a33707 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -524,6 +524,7 @@ cc_library(
     deps = [
         ":args",
         ":basics",
+        ":configs",
         ":mat",
         "//io",
         "@highway//:hwy",
diff --git a/gemma/activations.h b/gemma/activations.h
index a96c305..20d938c 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -24,6 +24,7 @@
 #include
 
 #include "gemma/configs.h"     // ModelConfig
+#include "gemma/gemma_args.h"  // AttentionImpl
 #include "ops/ops.h"           // CreateInvTimescale
 #include "util/basics.h"       // BF16
 #include "util/mat.h"          // MatStorageT
@@ -179,8 +180,8 @@ struct AttentionActivationsPtrs {
 };
 
 struct Activations {
-  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
-              ThreadingContext& ctx,
+  Activations(const RuntimeConfig& runtime_config, const ModelConfig& config,
+              size_t batch_size, size_t seq_len, ThreadingContext& ctx,
               std::vector>& row_ptrs)
       : layer_config(config.layer_configs[0]),
@@ -199,6 +200,7 @@ struct Activations {
         ffw_out(
             MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
+        attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
                           ctx.allocator, row_ptrs),
         attention(config, seq_len, attention_storage) {
@@ -248,6 +250,8 @@ struct Activations {
   MatStorageT C2;
   MatStorageT ffw_out;
 
+  AttentionImpl attention_impl;
+
   AttentionActivations attention_storage;
   AttentionActivationsPtrs attention;
 };
diff --git a/gemma/configs.h b/gemma/configs.h
index d774481..e2cb5e2 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -80,6 +80,34 @@ static inline bool EnumValid(LayerAttentionType type) {
   return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
 }
+
+enum class AttentionImpl {
+  kOld,
+  kFlash,
+};
+
+/*
+ * Returns a bitmask of flags to pass to attention functions, based on the
+ * selected attention implementation.
+ *
+ * If `hwy_native_dot_bf16` is nonzero, the old attention implementation is
+ * selected regardless of `impl`.
+ *
+ * `hwy_native_dot_bf16` must be passed in because the HWY_NATIVE_DOT_BF16
+ * macro is only available inside Highway-instrumented translation units and
+ * cannot be made accessible from .h files.
+ */
+static inline int AttentionImplToFlags(AttentionImpl impl,
+                                       int hwy_native_dot_bf16) {
+  if (hwy_native_dot_bf16) return kAttentionUseOld;
+
+  switch (impl) {
+    case AttentionImpl::kOld:
+      return kAttentionUseOld;
+    case AttentionImpl::kFlash:
+      return 0;
+  }
+}
+
 // Post attention and ffw normalization type.
 enum class PostNormType {
   None,
diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc
index b33f52f..4a7d319 100644
--- a/gemma/flash_attention_test.cc
+++ b/gemma/flash_attention_test.cc
@@ -112,8 +112,9 @@ void TestFlashAttention(size_t target_parallelism) {
   RuntimeConfig runtime_config;
   KVCache kv_cache(config, inference_args, ctx.allocator);
   MatMulEnv env(ctx);
-  Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
+  Activations activations(runtime_config, config,
+                          runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
+                          env.ctx, env.row_ptrs);
   std::vector tokens(kOuter);
   std::iota(tokens.begin(), tokens.end(), 1);
   PromptTokens prompt(tokens);
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 2f342bf..ffe7c47 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -73,10 +73,12 @@ namespace HWY_NAMESPACE {
 void Attention(LayerAttentionType type, const size_t num_tokens,
                const size_t layer_idx, const LayerWeightsPtrs& layer,
                Activations& activations, QBatch& qbatch, MatMulEnv& env) {
+
   if (type == LayerAttentionType::kGemma) {
     // TODO: remove flag to enable FlashAttention.
-    GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
-                   env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0);
+    GemmaAttention(
+        num_tokens, layer_idx, layer, activations.attention, qbatch, env,
+        AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16));
   }
@@ -573,8 +575,9 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const AesCtrEngine& engine, const WeightsPtrs& weights,
                      KVCache& kv_cache, MatMulEnv& env,
                      TimingInfo& timing_info) {
-  Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
+  Activations activations(runtime_config, config,
+                          runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
+                          env.ctx, env.row_ptrs);
   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span(&kv_cache, 1));
@@ -592,7 +595,7 @@ void GenerateBatchT(const ModelConfig& config,
                     TimingInfo& timing_info) {
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
-  Activations activations(config, max_batch_size,
+  Activations activations(runtime_config, config, max_batch_size,
                           all_queries[0].kv_cache.SeqLen(), env.ctx,
                           env.row_ptrs);
@@ -617,8 +620,8 @@ void GenerateImageTokensT(const ModelConfig& config,
     const size_t num_tokens = vit_config.max_seq_len;
     prefill_runtime_config.prefill_tbatch_size =
         num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
-    Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx,
-                                    env.row_ptrs);
+    Activations prefill_activations(runtime_config, vit_config, num_tokens,
+                                    num_tokens, env.ctx, env.row_ptrs);
     // Weights are for the full PaliGemma model, not just the ViT part.
     PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
                prefill_activations, env);
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index 3135f50..8536a78 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -24,6 +24,7 @@
 #include
 #include
 
+#include "gemma/configs.h"
 #include "io/io.h"  // Path
 #include "util/args.h"
 #include "util/basics.h"  // Tristate
@@ -139,6 +140,9 @@ struct RuntimeConfig {
   int verbosity;  // Controls verbosity of printed messages.
 
+  // Which attention implementation to use.
+  AttentionImpl attention_impl = AttentionImpl::kFlash;
+
   // Functions operating on the generated tokens.
   StreamFunc stream_token;
   BatchStreamFunc batch_stream_token;
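
For readers who want to see the selection logic in isolation, the snippet below is a minimal standalone sketch of the AttentionImplToFlags() dispatch introduced in gemma/configs.h. It re-declares the enum and function so it compiles on its own; the numeric value given to kAttentionUseOld is an assumed placeholder (the real flag constant is defined elsewhere in the repository), and the final return after the switch exists only to keep -Wreturn-type quiet.

// Standalone sketch of the dispatch added in gemma/configs.h.
// Assumption: kAttentionUseOld's value here is a placeholder for illustration;
// the real constant lives elsewhere in the repository.
#include <cstdio>

enum class AttentionImpl { kOld, kFlash };

constexpr int kAttentionUseOld = 1;  // Placeholder value, illustration only.

static inline int AttentionImplToFlags(AttentionImpl impl,
                                       int hwy_native_dot_bf16) {
  // Targets with native BF16 dot products keep the old path regardless of
  // the requested implementation.
  if (hwy_native_dot_bf16) return kAttentionUseOld;
  switch (impl) {
    case AttentionImpl::kOld:
      return kAttentionUseOld;
    case AttentionImpl::kFlash:
      return 0;  // No flag set: the FlashAttention path is used.
  }
  return kAttentionUseOld;  // Unreachable for valid enum values.
}

int main() {
  // kFlash without native BF16 dot support -> flags == 0 (flash path).
  std::printf("%d\n", AttentionImplToFlags(AttentionImpl::kFlash, 0));
  // kFlash with native BF16 dot support -> the old path is forced.
  std::printf("%d\n", AttentionImplToFlags(AttentionImpl::kFlash, 1));
  // kOld always selects the old path.
  std::printf("%d\n", AttentionImplToFlags(AttentionImpl::kOld, 0));
  return 0;
}

In the patched code itself, a caller opts into an implementation by setting runtime_config.attention_impl (default AttentionImpl::kFlash) before constructing Activations, as the updated constructor call in flash_attention_test.cc shows; Attention() in gemma.cc then derives the flags per call via AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16).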