From 6e5e4123f1777f02d00399a0b801560c47237df7 Mon Sep 17 00:00:00 2001
From: Krzysztof Rymski
Date: Fri, 28 Nov 2025 02:36:36 -0800
Subject: [PATCH] Internal changes

PiperOrigin-RevId: 837775282
---
 gemma/flash_attention.cc |  3 ++
 gemma/flash_attention.h  | 65 ++++++++++++++++++++--------------
 ops/ops-inl.h            |  1 +
 3 files changed, 37 insertions(+), 32 deletions(-)

diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index a463dfe..8a9757b 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -17,13 +17,16 @@
 #include
 #include
+#include
 #include
+#include
 #include

 #include "compression/types.h"  // GEMMA_DISABLED_TARGETS
 #include "gemma/flash_structs.h"
 #include "util/threading_context.h"
 #include "util/zones.h"
+#include "hwy/base.h"

 #ifndef HWY_DISABLED_TARGETS
 #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
 #endif  // HWY_DISABLED_TARGETS
diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h
index 236c7dc..b8a70ea 100644
--- a/gemma/flash_attention.h
+++ b/gemma/flash_attention.h
@@ -29,38 +29,39 @@
 namespace gcpp {

 // Passed to HWY_VISIT_TARGETS; declares for one target.
-#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
-  namespace NAMESPACE { \
-  void RMSNormAndPositionalEncoding( \
-      size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \
-      const MatPtr& query_norm_scale, size_t layer_idx, \
-      const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
-  \
-  void SingleFlashAttention(size_t start_pos, size_t last_pos, \
-                            const float* HWY_RESTRICT q, \
-                            const MatPtrT& k, const MatPtrT& v, \
-                            size_t layer_idx, \
-                            const AttentionActivationsPtrs& activations, \
-                            float* HWY_RESTRICT att_out, \
-                            ThreadingContext& ctx, size_t worker); \
-  \
-  Tile4FlashState TileFlashAttention4( \
-      const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, \
-      const MatPtrT& k, size_t start_pos, \
-      const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
-      size_t max_last_pos, const MatPtrT& v, size_t layer_idx, \
-      const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
-      MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
-      ThreadingContext& ctx, const size_t worker); \
-  \
-  size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
-                      size_t total_tasks, size_t target_parallelism); \
-  \
-  void FlashAttention(size_t num_tokens, size_t target_parallelism, \
-                      size_t layer_idx, const MatPtr& query_norm_scale, \
-                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
-                      ThreadingContext& ctx); \
-  /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
+#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
+  namespace NAMESPACE { \
+  void RMSNormAndPositionalEncoding( \
+      size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \
+      const MatPtr& query_norm_scale, size_t layer_idx, \
+      const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
+  \
+  void SingleFlashAttention(size_t start_pos, size_t last_pos, \
+                            const BF16* HWY_RESTRICT q, \
+                            const MatPtrT& k, const MatPtrT& v, \
+                            size_t layer_idx, \
+                            const AttentionActivationsPtrs& activations, \
+                            float* HWY_RESTRICT att_out, \
+                            ThreadingContext& ctx, size_t worker); \
+  \
+  Tile4FlashState TileFlashAttention4( \
+      const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, \
+      const MatPtrT& k, size_t start_pos, \
+      const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
+      size_t max_last_pos, const MatPtrT& v, size_t layer_idx, \
+      const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
+      MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
+      ThreadingContext& ctx, const size_t worker); \
+  \
+  size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
+                      size_t total_tasks, size_t target_parallelism); \
+  \
+  void FlashAttention(size_t num_tokens, size_t target_parallelism, \
+                      size_t layer_idx, const MatPtr& query_norm_scale, \
+                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
+                      ThreadingContext& ctx); \
+  \
+  /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
   }  // namespace NAMESPACE

 // Function declarations for each SIMD target. Allows direct call from the
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 3b41ff3..0eeec31 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -25,6 +25,7 @@
 #include
 #include
 #include  // std::enable_if_t
+#include
 #include

 #include "ops/matmul.h"
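
Note (not part of the patch): per the header comment above, GEMMA_DECL_FLASH_ATTENTION is passed to Highway's HWY_VISIT_TARGETS so the declarations are emitted once per enabled SIMD target. A minimal usage sketch follows; the placement inside namespace gcpp and the N_AVX3 example are assumptions based on typical Highway usage, not taken from this diff.

namespace gcpp {

// Sketch: for each enabled target, declare the functions above inside that
// target's namespace (e.g. gcpp::N_AVX3). Callers can then invoke them
// directly from same-target SIMD code or via HWY_DYNAMIC_DISPATCH.
// Assumes the header providing HWY_VISIT_TARGETS is already included.
HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION)

}  // namespace gcpp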