Clarify indices

PiperOrigin-RevId: 836235539
This commit is contained in:
Martin Stolle 2025-11-24 08:27:23 -08:00 committed by Copybara-Service
parent 37a25c9ffe
commit 8696f6dd17
1 changed files with 8 additions and 7 deletions

View File

@ -119,7 +119,7 @@ static HWY_INLINE void WeightedSumV(
// Calculates the attention outputs for a single q, which may be updated
// in place for RMSNorm.
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos,
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const MatPtr& query_norm_scale, const size_t layer_idx,
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
@ -128,7 +128,7 @@ void SingleDotSoftmaxWeightedSum(
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
// --seq_len must be large enough to avoid wraparound.
HWY_DASSERT(last_pos < activations.SeqLen());
HWY_DASSERT(kv_last_pos < activations.SeqLen());
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
// Apply rope and scaling to Q.
@ -139,18 +139,19 @@ void SingleDotSoftmaxWeightedSum(
});
}
PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos,
query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx,
worker);
// SoftMax with optional SoftCap yields "probabilities" in att.
const Logits logits(att, last_pos + 1);
const Logits logits(att, kv_last_pos + 1);
MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
ctx, worker);
WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v,
att_out, ctx, worker);
}
// The attention window usually starts at 0 unless `pos` is larger than