From 5a500872b8fec896ab4e0079da23a6ad4c0f9810 Mon Sep 17 00:00:00 2001
From: Martin Stolle
Date: Fri, 21 Nov 2025 01:17:06 -0800
Subject: [PATCH] Internal change

PiperOrigin-RevId: 835115693
---
 gemma/activations.h           | 6 ++++--
 gemma/flash_attention_test.cc | 3 ++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/gemma/activations.h b/gemma/activations.h
index 28e48ca..60c26c2 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       :  // `vocab_size == 0` means it is for Vit part, VitAttention is still
          // MHA and does not use an external KV cache.
@@ -217,7 +218,8 @@ struct Activations {
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          ctx.allocator, row_ptrs),
+                          runtime_config.attention_impl, ctx.allocator,
+                          row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc
index 4a7d319..944277f 100644
--- a/gemma/flash_attention_test.cc
+++ b/gemma/flash_attention_test.cc
@@ -124,7 +124,8 @@ void TestFlashAttention(size_t target_parallelism) {
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
   AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, ctx.allocator, row_ptrs);
+                                         kOuter, AttentionImpl::kFlash,
+                                         ctx.allocator, row_ptrs);
   AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);