diff --git a/gemma/attention.cc b/gemma/attention.cc
index b7099d1..6d81c74 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -130,7 +130,7 @@ static HWY_INLINE void WeightedSumV(
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
+    const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
@@ -169,7 +169,7 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }
 
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const MatPtrT<float>& query_norm_scale,
+                           const MatPtr& query_norm_scale,
                            AttentionActivationsPtrs& activations,
                            QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
diff --git a/gemma/attention.h b/gemma/attention.h
index 491a0b0..60e6823 100644
--- a/gemma/attention.h
+++ b/gemma/attention.h
@@ -38,12 +38,12 @@ namespace gcpp {
   void SingleDotSoftmaxWeightedSum(                                          \
       const size_t pos, const size_t start_pos, const size_t last_pos,       \
       float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-      const MatPtrT<float>& query_norm_scale, size_t layer_idx,              \
+      const MatPtr& query_norm_scale, size_t layer_idx,                      \
      const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,   \
       float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker);    \
                                                                              \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,      \
-                             const MatPtrT<float>& query_norm_scale,         \
+                             const MatPtr& query_norm_scale,                 \
                              AttentionActivationsPtrs& activations,          \
                              QBatch& qbatch, ThreadingContext& ctx);         \
                                                                              \
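The widening from MatPtrT<float> to MatPtr is source-compatible for callers: a MatPtrT<float> argument still binds to a const MatPtr& parameter through the ordinary derived-to-base conversion, while the declarations stop committing to a particular element type for the scale. A minimal sketch of the relationship this relies on, assuming MatPtrT<T> derives from a type-erased MatPtr base; the class internals and the UseScale helper below are illustrative assumptions, not the real gemma.cpp definitions:

// Sketch: MatPtrT<T> layers a typed view over a type-erased MatPtr base,
// so a function taking const MatPtr& accepts any MatPtrT<T> unchanged.
#include <cstddef>

class MatPtr {  // type-erased view: raw pointer + shape, no element type
 public:
  MatPtr(void* data, size_t rows, size_t cols)
      : data_(data), rows_(rows), cols_(cols) {}
  void* RawData() const { return data_; }
  size_t Rows() const { return rows_; }
  size_t Cols() const { return cols_; }

 private:
  void* data_;
  size_t rows_, cols_;
};

template <typename T>
class MatPtrT : public MatPtr {  // typed view on top of the erased base
 public:
  MatPtrT(T* data, size_t rows, size_t cols) : MatPtr(data, rows, cols) {}
  T* Row(size_t r) const { return static_cast<T*>(RawData()) + r * Cols(); }
};

// Hypothetical consumer mirroring the diff: the parameter is the erased
// base type, so one non-template signature serves any stored element type.
void UseScale(const MatPtr& query_norm_scale) {
  (void)query_norm_scale.Rows();  // shape is available without the type
}

int main() {
  float scale[8] = {};
  MatPtrT<float> typed(scale, 1, 8);
  UseScale(typed);  // derived-to-base conversion; no template needed
}

A practical upside of accepting the base type in the macro-generated per-target declarations is that the signature no longer names the scale's element type, so the same declaration can serve scales stored in whatever format the weights use.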