diff --git a/gemma/attention.cc b/gemma/attention.cc
index 67542ae..78eb496 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -119,7 +119,7 @@ static HWY_INLINE void WeightedSumV(
 // Calculates the attention outputs for a single q, which may be updated
 // in place for RMSNorm.
 void SingleDotSoftmaxWeightedSum(
-    const size_t pos, const size_t start_pos, const size_t last_pos,
+    const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos,
     float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v,
     const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
@@ -128,7 +128,7 @@ void SingleDotSoftmaxWeightedSum(
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   // --seq_len must be large enough to avoid wraparound.
-  HWY_DASSERT(last_pos < activations.SeqLen());
+  HWY_DASSERT(kv_last_pos < activations.SeqLen());
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
 
   // Apply rope and scaling to Q.
@@ -139,18 +139,19 @@ void SingleDotSoftmaxWeightedSum(
     });
   }
 
-  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
+  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos,
                        query_scale);
 
-  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
+  QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx,
+        worker);
 
   // SoftMax with optional SoftCap yields "probabilities" in att.
-  const Logits logits(att, last_pos + 1);
+  const Logits logits(att, kv_last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
   Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);
 
-  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
-               ctx, worker);
+  WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v,
+               att_out, ctx, worker);
 }
 
 // The attention window usually starts at 0 unless `pos` is larger than