mirror of https://github.com/google/gemma.cpp.git
parent 37a25c9ffe
commit 8696f6dd17
@@ -119,7 +119,7 @@ static HWY_INLINE void WeightedSumV(
 // Calculates the attention outputs for a single q, which may be updated
 // in place for RMSNorm.
 void SingleDotSoftmaxWeightedSum(
-    const size_t pos, const size_t start_pos, const size_t last_pos,
+    const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
     const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
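
Note on the renamed parameters: `q_pos` is the absolute position of the query token (used for the positional encoding), while `kv_start_pos` and `kv_last_pos` bound the inclusive range of KV-cache positions the query attends to. Below is a minimal sketch of how a caller might derive such a range for causal, sliding-window attention; the struct and helper names are illustrative, not part of gemma.cpp:

// Illustrative only: derives the KV range for one query token under a
// sliding attention window. Names (QueryKVRange, ComputeKVRange,
// window_size) are hypothetical, not from the repository.
#include <cstddef>

struct QueryKVRange {
  size_t kv_start_pos;  // first KV position attended to (inclusive)
  size_t kv_last_pos;   // last KV position attended to (inclusive)
};

QueryKVRange ComputeKVRange(size_t q_pos, size_t window_size) {
  // Causal attention: never look past the query's own position.
  const size_t last = q_pos;
  // Window usually starts at 0 unless q_pos exceeds the window size.
  const size_t start = (q_pos + 1 > window_size) ? q_pos + 1 - window_size : 0;
  return QueryKVRange{start, last};
}
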
@@ -128,7 +128,7 @@ void SingleDotSoftmaxWeightedSum(
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   // --seq_len must be large enough to avoid wraparound.
-  HWY_DASSERT(last_pos < activations.SeqLen());
+  HWY_DASSERT(kv_last_pos < activations.SeqLen());
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];

   // Apply rope and scaling to Q.
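
For reference, "apply rope and scaling to Q" means a rotary positional encoding of the query followed by multiplication with `query_scale`. A plain scalar sketch of standard RoPE follows; the actual `PositionalEncodingQK` is vectorized and its details (base frequency, which dimensions are paired) may differ:

// Scalar illustration of rotary positional encoding plus query scaling.
// Interleaved pairing and base 10000 are assumptions for the sketch only.
#include <cmath>
#include <cstddef>

void RopeAndScaleQ(float* q, size_t dim, size_t pos, float query_scale,
                   float base = 10000.0f) {
  for (size_t i = 0; i < dim / 2; ++i) {
    const float freq = std::pow(base, -2.0f * i / dim);
    const float theta = pos * freq;
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = q[2 * i], x1 = q[2 * i + 1];
    q[2 * i] = (x0 * c - x1 * s) * query_scale;
    q[2 * i + 1] = (x0 * s + x1 * c) * query_scale;
  }
}
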
@@ -139,18 +139,19 @@ void SingleDotSoftmaxWeightedSum(
     });
   }

-  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
+  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos,
                        query_scale);

-  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
+  QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx,
+        worker);

   // SoftMax with optional SoftCap yields "probabilities" in att.
-  const Logits logits(att, last_pos + 1);
+  const Logits logits(att, kv_last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
   Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);

-  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
-               ctx, worker);
+  WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v,
+               att_out, ctx, worker);
 }

 // The attention window usually starts at 0 unless `pos` is larger than
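
For orientation, the `QDotK` -> `MaybeLogitsSoftCap` -> `Softmax` -> `WeightedSumV` sequence computes single-query attention over the KV range. The scalar sketch below assumes row-major K/V with row length `dim` and that the soft-cap is `cap * tanh(logit / cap)` when `att_cap > 0`; the real code is vectorized, runs across `ctx`/`worker` threads, and wraps positions into the KV ring buffer via `div_seq_len`:

// Scalar sketch of single-query attention: dot products, optional soft-cap,
// softmax, then a weighted sum of V rows. The function name and memory
// layout are assumptions for illustration, not the repository's API.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void SingleQueryAttentionSketch(const float* q, const float* k, const float* v,
                                size_t dim, size_t kv_start_pos,
                                size_t kv_last_pos, float att_cap,
                                float* att_out) {
  std::vector<float> att(kv_last_pos + 1, 0.0f);
  // QDotK: logits for positions [kv_start_pos, kv_last_pos].
  for (size_t p = kv_start_pos; p <= kv_last_pos; ++p) {
    float dot = 0.0f;
    for (size_t d = 0; d < dim; ++d) dot += q[d] * k[p * dim + d];
    att[p] = dot;
  }
  // MaybeLogitsSoftCap: assumed to be cap * tanh(logit / cap) when cap > 0.
  if (att_cap > 0.0f) {
    for (size_t p = kv_start_pos; p <= kv_last_pos; ++p)
      att[p] = att_cap * std::tanh(att[p] / att_cap);
  }
  // Softmax over the attended range (subtract max for numerical stability).
  float max_logit = att[kv_start_pos];
  for (size_t p = kv_start_pos; p <= kv_last_pos; ++p)
    max_logit = std::max(max_logit, att[p]);
  float sum = 0.0f;
  for (size_t p = kv_start_pos; p <= kv_last_pos; ++p) {
    att[p] = std::exp(att[p] - max_logit);
    sum += att[p];
  }
  // WeightedSumV: accumulate att[p] * v[p] into att_out.
  for (size_t d = 0; d < dim; ++d) att_out[d] = 0.0f;
  for (size_t p = kv_start_pos; p <= kv_last_pos; ++p) {
    const float w = att[p] / sum;
    for (size_t d = 0; d < dim; ++d) att_out[d] += w * v[p * dim + d];
  }
}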