Clarify indices

PiperOrigin-RevId: 836235539
Martin Stolle authored 2025-11-24 08:27:23 -08:00, committed by Copybara-Service
parent 37a25c9ffe
commit 8696f6dd17
1 changed file with 8 additions and 7 deletions


@@ -119,7 +119,7 @@ static HWY_INLINE void WeightedSumV(
 // Calculates the attention outputs for a single q, which may be updated
 // in place for RMSNorm.
 void SingleDotSoftmaxWeightedSum(
-    const size_t pos, const size_t start_pos, const size_t last_pos,
+    const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
     const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
@@ -128,7 +128,7 @@ void SingleDotSoftmaxWeightedSum(
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   // --seq_len must be large enough to avoid wraparound.
-  HWY_DASSERT(last_pos < activations.SeqLen());
+  HWY_DASSERT(kv_last_pos < activations.SeqLen());
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
 
   // Apply rope and scaling to Q.
@@ -139,18 +139,19 @@ void SingleDotSoftmaxWeightedSum(
     });
   }
 
-  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
+  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos,
                        query_scale);
 
-  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
+  QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx,
+        worker);
 
   // SoftMax with optional SoftCap yields "probabilities" in att.
-  const Logits logits(att, last_pos + 1);
+  const Logits logits(att, kv_last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
   Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);
 
-  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
-               ctx, worker);
+  WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v,
+               att_out, ctx, worker);
 }
 
 // The attention window usually starts at 0 unless `pos` is larger than
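
The rename separates two kinds of indices that the old names conflated: `q_pos` is the query token's absolute position (fed to `PositionalEncodingQK` for RoPE), while `kv_start_pos` and `kv_last_pos` bound the inclusive KV-cache range that `QDotK` and `WeightedSumV` iterate over, which is why the logits span `kv_last_pos + 1` entries. As a rough illustration under assumptions not stated in this commit, the sketch below shows how a caller might derive the KV bounds for a sliding attention window; `KVRange`, `AttentionWindow`, and `window_size` are hypothetical names, not gemma.cpp API.

#include <cstddef>

// Hypothetical sketch, not part of this commit: the inclusive KV-cache
// range a query at absolute position q_pos attends to under a sliding
// window of window_size tokens.
struct KVRange {
  size_t kv_start_pos;  // first KV position included
  size_t kv_last_pos;   // last KV position included (inclusive)
};

KVRange AttentionWindow(size_t q_pos, size_t window_size) {
  // The window starts at 0 unless q_pos is large enough that the range
  // [q_pos + 1 - window_size, q_pos] already holds window_size entries;
  // causal attention never looks past q_pos itself.
  const size_t kv_start_pos =
      (q_pos + 1 <= window_size) ? 0 : q_pos + 1 - window_size;
  return KVRange{kv_start_pos, /*kv_last_pos=*/q_pos};
}

In this reading, the renamed signature makes the contract checkable at a glance: kv_start_pos and kv_last_pos come from the window computation, while q_pos is independent of it and only positions the query for rotary encoding.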