From 7efeb4fe06d53352da3a98604746cb9fff0a5aea Mon Sep 17 00:00:00 2001 From: Krzysztof Rymski Date: Tue, 24 Mar 2026 10:02:02 -0700 Subject: [PATCH] Internal changes PiperOrigin-RevId: 888724073 --- gemma/tiled_attention.h | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/gemma/tiled_attention.h b/gemma/tiled_attention.h index f35c4f3..9b1a2ce 100644 --- a/gemma/tiled_attention.h +++ b/gemma/tiled_attention.h @@ -15,26 +15,27 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - AttentionActivationsPtrs& activations, QBatch& qbatch, \ - MatMulEnv& env, int flags); \ - void TransposeStridedQueries(hwy::Span queries, int qkv_dim, \ - hwy::Span transposed_queries); \ - void LocalAttentionForAllHeadsTokensAndBatch( \ - AttentionImpl attention_impl, const size_t num_tokens, \ - const size_t layer_idx, const LayerWeightsPtrs& layer, \ - AttentionActivationsPtrs& activations, QBatch& qbatch, \ - ThreadingContext& ctx); \ - \ - template \ - std::tuple>, \ - std::vector, AlignedFloatVector> \ - TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span queries_ptrs, \ - int qkv_dim, size_t group_size); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + MatMulEnv& env, int flags); \ + void TransposeStridedQueries(hwy::Span queries, int qkv_dim, \ + hwy::Span transposed_queries); \ + void LocalAttentionForAllHeadsTokensAndBatch( \ + AttentionImpl attention_impl, const size_t num_tokens, \ + const size_t layer_idx, const LayerWeightsPtrs& layer, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + ThreadingContext& ctx); \ + \ + template \ + std::tuple>, \ + std::vector, AlignedFloatVector> \ + TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span queries_ptrs, \ + int qkv_dim, size_t group_size); \ + \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. Allows direct call from the