Exposes `GemmaAttention::DotSoftmaxWeightedSum` for experimentation.

Also in this change:
* The computation for a single `q` is factored out and exposed as `SingleDotSoftmaxWeightedSum` (a scalar sketch of this path follows the diff below).
* Strided `ConstMat` views into the KV caches are introduced to enable experimentation with various KV cache layouts (see the layout sketch after this list).
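
To make the second bullet concrete, here is a minimal, self-contained sketch of the strided-view idea. `ConstMatView` is a hypothetical stand-in for the repo's `ConstMat<T>`: the `ptr` field and the convention that `Row(r)` returns an element offset are inferred from the diff's `k.ptr + k.Row(pos)` usage, and the single-layer, single-head K-then-V interleaving is a simplification of the actual cache layout.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the repo's ConstMat<T>: a read-only 2D view whose
// rows are `stride` elements apart. Row(r) returns an element offset, matching
// the `k.ptr + k.Row(pos)` usage in the diff below.
template <typename T>
struct ConstMatView {
  const T* ptr;   // element (0, 0) of the view
  size_t rows;    // here: seq_len
  size_t cols;    // here: qkv_dim
  size_t stride;  // elements between consecutive rows (>= cols)
  size_t Row(size_t r) const { return r * stride; }
};

int main() {
  // Toy interleaved layout, one layer and one KV head: each cache position
  // stores a K vector immediately followed by a V vector (qkv_dim each).
  const size_t seq_len = 4, qkv_dim = 3;
  const size_t cache_pos_size = 2 * qkv_dim;  // K then V per position
  std::vector<float> kv_cache(seq_len * cache_pos_size);
  for (size_t i = 0; i < kv_cache.size(); ++i) {
    kv_cache[i] = static_cast<float>(i);
  }

  // Two strided [seq_len, qkv_dim] views over the same buffer: K starts at
  // offset 0, V at offset qkv_dim, both advancing cache_pos_size per row.
  ConstMatView<float> k{kv_cache.data(), seq_len, qkv_dim, cache_pos_size};
  ConstMatView<float> v{kv_cache.data() + qkv_dim, seq_len, qkv_dim,
                        cache_pos_size};

  // Row `pos` of each view is that position's K or V vector.
  for (size_t pos = 0; pos < seq_len; ++pos) {
    const float* k_row = k.ptr + k.Row(pos);
    const float* v_row = v.ptr + v.Row(pos);
    std::printf("pos %zu: k[0]=%g v[0]=%g\n", pos, k_row[0], v_row[0]);
  }
  return 0;
}

Because only the base pointer and stride differ between the `k` and `v` views, experimenting with a different KV cache layout amounts to constructing the views differently; the dot/softmax/weighted-sum code itself is unchanged.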

PiperOrigin-RevId: 756339313
Authored by Biruk Mammo on 2025-05-08 09:18:02 -07:00, committed by Copybara-Service.
parent a0ff98ea60
commit d834c07042
1 changed file with 70 additions and 54 deletions

@@ -341,40 +341,36 @@ class GemmaAttention {
       });
   }
 
-  // Computes Q.K scores, which are "logits" (or scores) stored to head_att.
+  // Computes Q.K scores, which are "logits" (or scores) stored to att.
+  // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
-                        const size_t head_offset, const float* HWY_RESTRICT q,
-                        const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
+                        const float* HWY_RESTRICT q, const ConstMat<float>& k,
+                        float* HWY_RESTRICT att) {
     const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos] = score;
+        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(pos);
+        const float score = Dot(q, k_ptr, qkv_dim);
+        att[pos] = score;
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos % activations_.seq_len] = score;
+        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(cache_pos);
+        const float score = Dot(q, k_ptr, qkv_dim);
+        att[pos % activations_.seq_len] = score;
       }
     }
   }
 
-  // Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into
+  // Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
   // `att_out`. Equivalent in gemma/modules.py:
   // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
+  // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos,
-                               const float* HWY_RESTRICT head_att,
-                               const size_t layer, const size_t head_offset,
-                               const hwy::Divisor& div_seq_len,
-                               const KVCache& kv_cache,
+                               const float* HWY_RESTRICT att,
+                               const ConstMat<float>& v,
                                float* HWY_RESTRICT att_out) const {
     const size_t qkv_dim = layer_config_.qkv_dim;
     hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -382,25 +378,44 @@ class GemmaAttention {
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
+        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(pos);
+        MulByConstAndAdd(att[pos], v_ptr, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t cache_pos = div_seq_len.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
+        const size_t cache_pos = div_seq_len_.Remainder(pos);
+        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(cache_pos);
+        MulByConstAndAdd(att[pos % activations_.seq_len], v_ptr, att_out,
                          qkv_dim);
       }
     }
   }
 
+ public:
+  // Calculates the attention outputs for a single q.
+  HWY_INLINE void SingleDotSoftmaxWeightedSum(
+      float* HWY_RESTRICT q, const ConstMat<float>& k, const ConstMat<float>& v,
+      float* HWY_RESTRICT att, float* HWY_RESTRICT att_out,
+      const float query_scale, const size_t pos, const size_t start_pos,
+      const size_t last_pos) {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+
+    // Apply rope and scaling to Q.
+    if (layer_weights_.query_norm_scale.HasPtr()) {
+      RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q, qkv_dim);
+    }
+    PositionalEncodingQK(q, pos, layer_, query_scale);
+
+    QDotK(start_pos, last_pos, q, k, att);
+
+    // SoftMax with optional SoftCap yields "probabilities" in att.
+    const size_t att_len = std::min(last_pos + 1, activations_.seq_len);
+    MaybeLogitsSoftCap(activations_.weights_config.att_cap, att, att_len);
+    Softmax(att, att_len);
+
+    WeightedSumV(start_pos, last_pos, att, v, att_out);
+  }
+
   HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t num_interleaved) {
     PROFILER_ZONE("Gen.Attention.DotSoftmax");
     const float query_scale = ChooseQueryScale(activations_.weights_config);
@@ -418,18 +433,32 @@ class GemmaAttention {
           const size_t batch_idx = interleaved_idx / num_queries_;
           const size_t qkv_dim = layer_config_.qkv_dim;
           const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
-          KVCache& kv_cache = kv_caches_[query_idx];
           float* HWY_RESTRICT q =
              activations_.q.Batch(interleaved_idx) + head * q_stride_;
+          float* HWY_RESTRICT att =
+              activations_.att.Batch(interleaved_idx) +
+              head * activations_.seq_len;
+          float* HWY_RESTRICT att_out =
+              activations_.att_out.Batch(interleaved_idx) +
+              head * qkv_dim;
 
-          // Apply rope and scaling to Q.
+          // Make strided views into the kv cache entries for the current
+          // query and head.
+          KVCache& kv_cache = kv_caches_[query_idx];
+          const size_t kv_head_offset =
+              layer_ * cache_layer_size_ + head_offset;
+          ConstMat<float> k(kv_cache.kv_cache.get() + kv_head_offset,
+                            Extents2D(kv_cache.seq_len, qkv_dim),
+                            /*stride=*/cache_pos_size_);
+          ConstMat<float> v(
+              kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+              Extents2D(kv_cache.seq_len, qkv_dim),
+              /*stride=*/cache_pos_size_);
+
+          // Find the token position in the query and calculate the range
+          // of cache positions to attend to.
           const size_t pos = queries_pos_[query_idx] + batch_idx;
-          if (layer_weights_.query_norm_scale.HasPtr()) {
-            RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q,
-                           qkv_dim);
-          }
-          PositionalEncodingQK(q, pos, layer_, query_scale);
           const size_t start_pos = StartPos(pos, layer_);
           size_t last_pos = pos;
           const size_t prefix_end = queries_prefix_end_[query_idx];
@@ -438,26 +467,13 @@ class GemmaAttention {
             last_pos = prefix_end - 1;
           }
 
-          float* HWY_RESTRICT head_att =
-              activations_.att.Batch(interleaved_idx) +
-              head * activations_.seq_len;
-          QDotK(start_pos, last_pos, head_offset, q, kv_cache, head_att);
-
-          // SoftMax with optional SoftCap yields "probabilities" in
-          // head_att.
-          const size_t head_att_len =
-              std::min(last_pos + 1, activations_.seq_len);
-          MaybeLogitsSoftCap(activations_.weights_config.att_cap,
-                             head_att, head_att_len);
-          Softmax(head_att, head_att_len);
-
-          float* HWY_RESTRICT att_out =
-              activations_.att_out.Batch(interleaved_idx) +
-              head * qkv_dim;
-          WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
-                       div_seq_len_, kv_cache, att_out);
+          SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
+                                      pos, start_pos, last_pos);
         });
   }
 
  private:
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
   // head_dim (`qkv_dim`) into output (`layer_out`).
   HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
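
For reference, a scalar sketch of the pipeline that `SingleDotSoftmaxWeightedSum` encapsulates: Q.K logits, optional logit soft-capping, softmax, and the probability-weighted sum over V. This is not the repo's implementation: RoPE, RMSNorm, cache wraparound, and the Highway SIMD kernels (`Dot`, `MulByConstAndAdd`) are elided, `k`/`v` are dense arrays rather than strided views, and the `cap * tanh(x / cap)` soft-cap formula is an assumption about what `MaybeLogitsSoftCap` computes.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Scalar sketch of the single-q attention path. `att` needs room for
// last_pos + 1 scores; `att_out` receives qkv_dim accumulated values.
void SingleDotSoftmaxWeightedSumScalar(const float* q, const float* k,
                                       const float* v, size_t qkv_dim,
                                       size_t start_pos, size_t last_pos,
                                       float att_cap, float* att,
                                       float* att_out) {
  // Q.K scores ("logits") for each attended cache position.
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    float dot = 0.0f;
    for (size_t i = 0; i < qkv_dim; ++i) dot += q[i] * k[pos * qkv_dim + i];
    // Assumed soft-cap, Gemma-2-style logit capping: x -> cap * tanh(x / cap).
    att[pos] = att_cap > 0.0f ? att_cap * std::tanh(dot / att_cap) : dot;
  }

  // Numerically stable softmax over the attended range.
  const float max_logit =
      *std::max_element(att + start_pos, att + last_pos + 1);
  float sum = 0.0f;
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    att[pos] = std::exp(att[pos] - max_logit);
    sum += att[pos];
  }
  for (size_t pos = start_pos; pos <= last_pos; ++pos) att[pos] /= sum;

  // Probability-weighted sum of V rows, i.e. the
  // jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) cited in the comments.
  std::fill(att_out, att_out + qkv_dim, 0.0f);
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    for (size_t i = 0; i < qkv_dim; ++i) {
      att_out[i] += att[pos] * v[pos * qkv_dim + i];
    }
  }
}

Factoring this path out per query head is what lets the strided `k`/`v` views be swapped for alternative KV cache layouts without touching the softmax logic.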