#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_ #include #include #include #include #include "gemma/gemma.h" #include "util/allocator.h" #include "hwy/aligned_allocator.h" #include "hwy/highway.h" namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. #define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \ namespace NAMESPACE { \ void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \ size_t layer_idx, const LayerWeightsPtrs& layer, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \ MatMulEnv& env, int flags); \ void TransposeStridedQueries(hwy::Span queries, int qkv_dim, \ hwy::Span transposed_queries); \ void LocalAttentionForAllHeadsTokensAndBatch( \ AttentionImpl attention_impl, const size_t num_tokens, \ const size_t layer_idx, const LayerWeightsPtrs& layer, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \ ThreadingContext& ctx); \ \ template \ std::tuple>, \ std::vector, AlignedFloatVector> \ TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span queries_ptrs, \ int qkv_dim, size_t group_size); \ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. Allows direct call from the // per-target namespace. We may later replace this with dynamic dispatch if // the overhead is acceptable. HWY_VISIT_TARGETS(GEMMA_DECL_TILED_ATTENTION) #undef GEMMA_DECL_TILED_ATTENTION } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_