mirror of https://github.com/google/gemma.cpp.git
Exposes `GemmaAttention::DotSoftmaxWeightedSum` for experimentation.

Also in this change:
* The computation for a single `q` is factored out and exposed.
* Strided `ConstMat` views into the KV caches are introduced to enable
  experimentation with various KV cache layouts.

PiperOrigin-RevId: 756339313
commit d834c07042
parent a0ff98ea60
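The views introduced in the diff below map the interleaved KV cache onto a [seq_len, qkv_dim] row-strided matrix per (layer, kv head). The following standalone sketch illustrates that idea; `StridedView` is a simplified stand-in rather than the repository's `ConstMat`, and the toy cache geometry is an assumption that loosely mirrors the offsets visible in the diff. It checks that a view with base offset `layer * cache_layer_size + head_offset` and row stride `cache_pos_size` addresses the same K row as the old manual arithmetic, with V being the same view shifted by `qkv_dim`.

// Standalone sketch of the row-strided view idea. `StridedView` is a
// simplified stand-in, not gemma.cpp's ConstMat, and the toy cache geometry
// below is an assumption that loosely mirrors the offsets in the diff.
#include <cassert>
#include <cstddef>
#include <vector>

// Read-only row-strided view: row r starts at ptr + r * stride.
struct StridedView {
  const float* ptr;
  size_t rows;    // seq_len
  size_t cols;    // qkv_dim
  size_t stride;  // elements between consecutive rows
  const float* Row(size_t r) const { return ptr + r * stride; }
};

int main() {
  // Toy geometry: per position, [layers][kv_heads][K(qkv_dim) | V(qkv_dim)].
  const size_t seq_len = 4, layers = 2, kv_heads = 3, qkv_dim = 8;
  const size_t cache_layer_size = kv_heads * qkv_dim * 2;
  const size_t cache_pos_size = layers * cache_layer_size;
  std::vector<float> kv_cache(seq_len * cache_pos_size, 0.0f);

  const size_t layer = 1, kv_head = 2;
  const size_t head_offset = kv_head * qkv_dim * 2;
  const size_t kv_head_offset = layer * cache_layer_size + head_offset;

  // K and V views for this (layer, kv_head); V is the K view shifted by
  // qkv_dim, with the same row stride.
  const StridedView k{kv_cache.data() + kv_head_offset, seq_len, qkv_dim,
                      cache_pos_size};
  const StridedView v{kv_cache.data() + kv_head_offset + qkv_dim, seq_len,
                      qkv_dim, cache_pos_size};

  // The views reproduce the manual offset arithmetic of the old QDotK /
  // WeightedSumV for every position.
  for (size_t pos = 0; pos < seq_len; ++pos) {
    const size_t kv_offset =
        pos * cache_pos_size + layer * cache_layer_size + head_offset;
    assert(k.Row(pos) == kv_cache.data() + kv_offset);
    assert(v.Row(pos) == kv_cache.data() + kv_offset + qkv_dim);
  }
  return 0;
}

With the addressing expressed this way, trying a different KV cache layout only means constructing the `k`/`v` views with a different base pointer and stride; `QDotK` and `WeightedSumV` need not change.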
@@ -341,40 +341,36 @@ class GemmaAttention {
         });
   }

-  // Computes Q.K scores, which are "logits" (or scores) stored to head_att.
+  // Computes Q.K scores, which are "logits" (or scores) stored to att.
+  // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
-                        const size_t head_offset, const float* HWY_RESTRICT q,
-                        const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
+                        const float* HWY_RESTRICT q, const ConstMat<float>& k,
+                        float* HWY_RESTRICT att) {
     const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos] = score;
+        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(pos);
+        const float score = Dot(q, k_ptr, qkv_dim);
+        att[pos] = score;
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos % activations_.seq_len] = score;
+        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(cache_pos);
+        const float score = Dot(q, k_ptr, qkv_dim);
+        att[pos % activations_.seq_len] = score;
       }
     }
   }

-  // Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into
+  // Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
   // `att_out`. Equivalent in gemma/modules.py:
   // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
+  // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos,
-                               const float* HWY_RESTRICT head_att,
-                               const size_t layer, const size_t head_offset,
-                               const hwy::Divisor& div_seq_len,
-                               const KVCache& kv_cache,
+                               const float* HWY_RESTRICT att,
+                               const ConstMat<float>& v,
                                float* HWY_RESTRICT att_out) const {
     const size_t qkv_dim = layer_config_.qkv_dim;
     hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -382,25 +378,44 @@ class GemmaAttention {
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
+        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(pos);
+        MulByConstAndAdd(att[pos], v_ptr, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t cache_pos = div_seq_len.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
+        const size_t cache_pos = div_seq_len_.Remainder(pos);
+        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(cache_pos);
+        MulByConstAndAdd(att[pos % activations_.seq_len], v_ptr, att_out,
                          qkv_dim);
       }
     }
   }

+ public:
+  // Calculates the attention outputs for a single q.
+  HWY_INLINE void SingleDotSoftmaxWeightedSum(
+      float* HWY_RESTRICT q, const ConstMat<float>& k, const ConstMat<float>& v,
+      float* HWY_RESTRICT att, float* HWY_RESTRICT att_out,
+      const float query_scale, const size_t pos, const size_t start_pos,
+      const size_t last_pos) {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+
+    // Apply rope and scaling to Q.
+    if (layer_weights_.query_norm_scale.HasPtr()) {
+      RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q, qkv_dim);
+    }
+    PositionalEncodingQK(q, pos, layer_, query_scale);
+
+    QDotK(start_pos, last_pos, q, k, att);
+
+    // SoftMax with optional SoftCap yields "probabilities" in att.
+    const size_t att_len = std::min(last_pos + 1, activations_.seq_len);
+    MaybeLogitsSoftCap(activations_.weights_config.att_cap, att, att_len);
+    Softmax(att, att_len);
+
+    WeightedSumV(start_pos, last_pos, att, v, att_out);
+  }
+
   HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t num_interleaved) {
     PROFILER_ZONE("Gen.Attention.DotSoftmax");
     const float query_scale = ChooseQueryScale(activations_.weights_config);
@@ -418,18 +433,32 @@ class GemmaAttention {
           const size_t batch_idx = interleaved_idx / num_queries_;
           const size_t qkv_dim = layer_config_.qkv_dim;
           const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
-          KVCache& kv_cache = kv_caches_[query_idx];
+
           float* HWY_RESTRICT q =
              activations_.q.Batch(interleaved_idx) + head * q_stride_;
+          float* HWY_RESTRICT att =
+              activations_.att.Batch(interleaved_idx) +
+              head * activations_.seq_len;
+          float* HWY_RESTRICT att_out =
+              activations_.att_out.Batch(interleaved_idx) +
+              head * qkv_dim;

-          // Apply rope and scaling to Q.
+          // Make strided views into the kv cache entries for the current
+          // query and head.
+          KVCache& kv_cache = kv_caches_[query_idx];
+          const size_t kv_head_offset =
+              layer_ * cache_layer_size_ + head_offset;
+          ConstMat<float> k(kv_cache.kv_cache.get() + kv_head_offset,
+                            Extents2D(kv_cache.seq_len, qkv_dim),
+                            /*stride=*/cache_pos_size_);
+          ConstMat<float> v(
+              kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+              Extents2D(kv_cache.seq_len, qkv_dim),
+              /*stride=*/cache_pos_size_);
+
+          // Find the token position in the query and calculate the range
+          // of cache positions to attend to.
           const size_t pos = queries_pos_[query_idx] + batch_idx;
-          if (layer_weights_.query_norm_scale.HasPtr()) {
-            RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q,
-                           qkv_dim);
-          }
-          PositionalEncodingQK(q, pos, layer_, query_scale);
-
           const size_t start_pos = StartPos(pos, layer_);
           size_t last_pos = pos;
           const size_t prefix_end = queries_prefix_end_[query_idx];
@@ -438,26 +467,13 @@ class GemmaAttention {
            last_pos = prefix_end - 1;
          }

-          float* HWY_RESTRICT head_att =
-              activations_.att.Batch(interleaved_idx) +
-              head * activations_.seq_len;
-          QDotK(start_pos, last_pos, head_offset, q, kv_cache, head_att);
-          // SoftMax with optional SoftCap yields "probabilities" in
-          // head_att.
-          const size_t head_att_len =
-              std::min(last_pos + 1, activations_.seq_len);
-          MaybeLogitsSoftCap(activations_.weights_config.att_cap,
-                             head_att, head_att_len);
-          Softmax(head_att, head_att_len);
+          SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
+                                      pos, start_pos, last_pos);

-          float* HWY_RESTRICT att_out =
-              activations_.att_out.Batch(interleaved_idx) +
-              head * qkv_dim;
-          WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
-                       div_seq_len_, kv_cache, att_out);
         });
   }

+ private:
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
   // head_dim (`qkv_dim`) into output (`layer_out`).
   HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
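For experimentation against the newly exposed single-`q` path, a scalar reference of the dot, optional soft cap, softmax, and weighted-sum sequence can serve as a comparison baseline. The sketch below is standalone plain C++ (no SIMD, and it omits the RMSNorm/RoPE step applied to `q`); the `*Ref` names are illustrative, and the soft-cap form `cap * tanh(score / cap)` is an assumption.

// Scalar reference for the dot -> (optional soft cap) -> softmax -> weighted
// sum sequence. Comparison baseline only; not gemma.cpp's implementation.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

using Rows = std::vector<std::vector<float>>;  // [seq_len][qkv_dim]

// att[pos] = dot(q, k[pos]) for pos in [0, att_len).
void QDotKRef(const std::vector<float>& q, const Rows& k, size_t att_len,
              std::vector<float>& att) {
  for (size_t pos = 0; pos < att_len; ++pos) {
    float score = 0.0f;
    for (size_t d = 0; d < q.size(); ++d) score += q[d] * k[pos][d];
    att[pos] = score;
  }
}

// If cap > 0: att[i] = cap * tanh(att[i] / cap) (assumed soft-cap form).
void MaybeSoftCapRef(float cap, std::vector<float>& att, size_t len) {
  if (cap <= 0.0f) return;
  for (size_t i = 0; i < len; ++i) att[i] = cap * std::tanh(att[i] / cap);
}

// In-place softmax over att[0, len), shifted by the max for stability.
void SoftmaxRef(std::vector<float>& att, size_t len) {
  const float max_v = *std::max_element(att.begin(), att.begin() + len);
  float sum = 0.0f;
  for (size_t i = 0; i < len; ++i) {
    att[i] = std::exp(att[i] - max_v);
    sum += att[i];
  }
  for (size_t i = 0; i < len; ++i) att[i] /= sum;
}

// att_out[d] = sum over pos of att[pos] * v[pos][d].
void WeightedSumVRef(const std::vector<float>& att, const Rows& v,
                     size_t att_len, std::vector<float>& att_out) {
  std::fill(att_out.begin(), att_out.end(), 0.0f);
  for (size_t pos = 0; pos < att_len; ++pos)
    for (size_t d = 0; d < att_out.size(); ++d)
      att_out[d] += att[pos] * v[pos][d];
}

int main() {
  const size_t seq_len = 3, qkv_dim = 4;
  std::vector<float> q = {0.1f, 0.2f, 0.3f, 0.4f};
  Rows k(seq_len, std::vector<float>(qkv_dim, 0.5f));
  Rows v(seq_len, std::vector<float>(qkv_dim, 1.0f));
  std::vector<float> att(seq_len), att_out(qkv_dim);

  QDotKRef(q, k, seq_len, att);
  MaybeSoftCapRef(/*cap=*/50.0f, att, seq_len);
  SoftmaxRef(att, seq_len);
  WeightedSumVRef(att, v, seq_len, att_out);
  std::printf("att_out[0] = %f\n", att_out[0]);  // Probabilities sum to 1 here.
  return 0;
}

Given the post-RoPE `q` and the K/V rows for positions [start_pos, last_pos], this should produce the same `att` and `att_out` as the exposed path up to floating-point rounding.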