mirror of https://github.com/google/gemma.cpp.git
parent
37a25c9ffe
commit
8696f6dd17
|
|
@ -119,7 +119,7 @@ static HWY_INLINE void WeightedSumV(
|
|||
// Calculates the attention outputs for a single q, which may be updated
|
||||
// in place for RMSNorm.
|
||||
void SingleDotSoftmaxWeightedSum(
|
||||
const size_t pos, const size_t start_pos, const size_t last_pos,
|
||||
const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos,
|
||||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
|
||||
const MatPtr& query_norm_scale, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
|
||||
|
|
@ -128,7 +128,7 @@ void SingleDotSoftmaxWeightedSum(
|
|||
const float att_cap = activations.config.att_cap;
|
||||
const float query_scale = activations.query_scale;
|
||||
// --seq_len must be large enough to avoid wraparound.
|
||||
HWY_DASSERT(last_pos < activations.SeqLen());
|
||||
HWY_DASSERT(kv_last_pos < activations.SeqLen());
|
||||
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
|
||||
|
||||
// Apply rope and scaling to Q.
|
||||
|
|
@ -139,18 +139,19 @@ void SingleDotSoftmaxWeightedSum(
|
|||
});
|
||||
}
|
||||
|
||||
PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
|
||||
PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos,
|
||||
query_scale);
|
||||
|
||||
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
|
||||
QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx,
|
||||
worker);
|
||||
|
||||
// SoftMax with optional SoftCap yields "probabilities" in att.
|
||||
const Logits logits(att, last_pos + 1);
|
||||
const Logits logits(att, kv_last_pos + 1);
|
||||
MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
|
||||
Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);
|
||||
|
||||
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
|
||||
ctx, worker);
|
||||
WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v,
|
||||
att_out, ctx, worker);
|
||||
}
|
||||
|
||||
// The attention window usually starts at 0 unless `pos` is larger than
|
||||
|
|
|
|||
Loading…
Reference in New Issue