Also update attention.h to use a type-erased query_norm_scale

PiperOrigin-RevId: 825014334
Jan Wassenberg, 2025-10-28 06:47:56 -07:00 (committed by Copybara-Service)
parent 3cc0139ebb
commit 4bd465ffd3

2 changed files with 4 additions and 4 deletions
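
The change below replaces the concrete `MatPtrT<float>&` parameter with the type-erased base `MatPtr&`, so the attention entry points no longer bake the scale's element type into their signatures. A minimal sketch of the pattern with made-up class contents (the real gemma.cpp definitions live elsewhere and differ in detail):

```cpp
#include <cassert>
#include <cstddef>

enum class Type { kF32 };

// Type-erased handle: stores an untyped pointer plus a runtime type tag,
// so it can appear in function signatures without a template parameter.
class MatPtr {
 public:
  MatPtr() = default;
  MatPtr(void* ptr, Type type, size_t rows, size_t cols)
      : ptr_(ptr), type_(type), rows_(rows), cols_(cols) {}
  bool HasPtr() const { return ptr_ != nullptr; }
  Type GetType() const { return type_; }

 protected:
  void* ptr_ = nullptr;
  Type type_ = Type::kF32;
  size_t rows_ = 0;
  size_t cols_ = 0;
};

// Typed view over the erased handle; verifies the tag at runtime.
template <typename T>
class MatPtrT : public MatPtr {
 public:
  explicit MatPtrT(const MatPtr& base) : MatPtr(base) {
    assert(!HasPtr() || GetType() == Type::kF32);  // sketch: float only
  }
  const T* Row(size_t r) const {
    return static_cast<const T*>(ptr_) + r * cols_;
  }
};
```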

attention.cc

@@ -130,7 +130,7 @@ static HWY_INLINE void WeightedSumV(
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
+    const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
@@ -169,7 +169,7 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }
 
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const MatPtrT<float>& query_norm_scale,
+                           const MatPtr& query_norm_scale,
                            AttentionActivationsPtrs& activations,
                            QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
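
Inside the definitions, a callee that previously read elements through the typed reference now rebuilds a typed view (or branches on the runtime tag) before touching elements. Continuing the sketch above with purely illustrative names (`ScaleQ` is hypothetical, not the actual gemma.cpp code):

```cpp
// Hypothetical callee: re-establish a typed view before element access.
void ScaleQ(float* q, const MatPtr& query_norm_scale, size_t layer_idx,
            size_t qkv_dim) {
  if (!query_norm_scale.HasPtr()) return;        // model without query norm
  const MatPtrT<float> scale(query_norm_scale);  // checked typed view
  for (size_t i = 0; i < qkv_dim; ++i) {
    q[i] *= scale.Row(layer_idx)[i];
  }
}
```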

attention.h

@@ -38,12 +38,12 @@ namespace gcpp {
   void SingleDotSoftmaxWeightedSum(                                           \
       const size_t pos, const size_t start_pos, const size_t last_pos,        \
       float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,  \
-      const MatPtrT<float>& query_norm_scale, size_t layer_idx,               \
+      const MatPtr& query_norm_scale, size_t layer_idx,                       \
       const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,   \
       float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker);     \
                                                                               \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,       \
-                             const MatPtrT<float>& query_norm_scale,          \
+                             const MatPtr& query_norm_scale,                  \
                              AttentionActivationsPtrs& activations,           \
                              QBatch& qbatch, ThreadingContext& ctx);          \
                                                                               \
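
The header declarations must mirror the .cc definitions exactly, since the macro stamps out one declaration per SIMD target and a mismatched parameter type would declare a different overload than the one defined. Caller code keeps compiling unchanged because the typed view derives from the erased base; a usage sketch tying the hypothetical pieces above together:

```cpp
int main() {
  float storage[16 * 4] = {};  // 16 layers x 4-wide scale rows (made up)
  const MatPtrT<float> scale(
      MatPtr(storage, Type::kF32, /*rows=*/16, /*cols=*/4));
  float q[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  // A typed MatPtrT<float> binds to the const MatPtr& parameter directly,
  // because MatPtrT<T> derives from MatPtr.
  ScaleQ(q, scale, /*layer_idx=*/0, /*qkv_dim=*/4);
}
```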