From 4bd465ffd31fa75bd5106244524d44f1b503ffc3 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 28 Oct 2025 06:47:56 -0700 Subject: [PATCH] Also update attention.h to type-erased query_norm_scale PiperOrigin-RevId: 825014334 --- gemma/attention.cc | 4 ++-- gemma/attention.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gemma/attention.cc b/gemma/attention.cc index b7099d1..6d81c74 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -130,7 +130,7 @@ static HWY_INLINE void WeightedSumV( void SingleDotSoftmaxWeightedSum( const size_t pos, const size_t start_pos, const size_t last_pos, float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, - const MatPtrT& query_norm_scale, const size_t layer_idx, + const MatPtr& query_norm_scale, const size_t layer_idx, const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) { const float att_cap = activations.config.att_cap; @@ -169,7 +169,7 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) { } void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, - const MatPtrT& query_norm_scale, + const MatPtr& query_norm_scale, AttentionActivationsPtrs& activations, QBatch& qbatch, ThreadingContext& ctx) { GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive); diff --git a/gemma/attention.h b/gemma/attention.h index 491a0b0..60e6823 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -38,12 +38,12 @@ namespace gcpp { void SingleDotSoftmaxWeightedSum( \ const size_t pos, const size_t start_pos, const size_t last_pos, \ float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ - const MatPtrT& query_norm_scale, size_t layer_idx, \ + const MatPtr& query_norm_scale, size_t layer_idx, \ const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \ float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \ \ void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ - const MatPtrT& query_norm_scale, \ + const MatPtr& query_norm_scale, \ AttentionActivationsPtrs& activations, \ QBatch& qbatch, ThreadingContext& ctx); \ \