mirror of https://github.com/google/gemma.cpp.git
Exposes `GemmaAttention::DotSoftmaxWeightedSum` for experimentation.
Also in this change:
* The computation for a single `q` is factored out and exposed.
* Strided `ConstMat` views into the KV caches are introduced to enable experimentation with various KV cache layouts.

PiperOrigin-RevId: 756339313
This commit is contained in:
parent a0ff98ea60
commit d834c07042
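The diff below replaces explicit KV-cache offset arithmetic in QDotK and WeightedSumV with strided `ConstMat<float>` views whose base pointer already includes the layer and head offsets and whose stride is `cache_pos_size_`. The following is a minimal sketch of that layout equivalence, using a hypothetical `StridedView` stand-in rather than the real `ConstMat` API and made-up sizes; `Row(pos)` is assumed to return `pos * stride`, as the diff's `k.ptr + k.Row(pos)` usage suggests.

// Minimal sketch (hypothetical `StridedView`, not the gemma.cpp `ConstMat`)
// of the layout assumption behind the new strided views: the KV cache is one
// flat float buffer, the K (or V) rows for a fixed layer and head are
// `cache_pos_size_` floats apart, and Row(pos) returns pos * stride.
#include <cassert>
#include <cstddef>
#include <vector>

struct StridedView {
  const float* ptr;  // first row for this layer/head
  size_t stride;     // floats between consecutive cache positions
  size_t Row(size_t pos) const { return pos * stride; }
};

int main() {
  // Made-up sizes; the real values come from the model config.
  const size_t seq_len = 4, qkv_dim = 8, num_layers = 2;
  const size_t cache_layer_size = qkv_dim * 2;  // K then V (one head)
  const size_t cache_pos_size = num_layers * cache_layer_size;
  std::vector<float> kv_cache(seq_len * cache_pos_size, 0.0f);

  const size_t layer = 1, head_offset = 0;
  const size_t kv_head_offset = layer * cache_layer_size + head_offset;
  const StridedView k{kv_cache.data() + kv_head_offset, cache_pos_size};

  for (size_t pos = 0; pos < seq_len; ++pos) {
    // Old addressing in QDotK ...
    const float* old_k = kv_cache.data() + pos * cache_pos_size +
                         layer * cache_layer_size + head_offset;
    // ... matches the new strided-view addressing.
    const float* new_k = k.ptr + k.Row(pos);
    assert(old_k == new_k);
  }
  return 0;
}

Because the view carries the base pointer and stride, QDotK and WeightedSumV no longer need the `KVCache` reference or the layer/head offset arguments, and alternative KV cache layouts can be tried by constructing views with different strides.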
@@ -341,40 +341,36 @@ class GemmaAttention {
         });
   }
 
-  // Computes Q.K scores, which are "logits" (or scores) stored to head_att.
+  // Computes Q.K scores, which are "logits" (or scores) stored to att.
+  // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
-                        const size_t head_offset, const float* HWY_RESTRICT q,
-                        const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
+                        const float* HWY_RESTRICT q, const ConstMat<float>& k,
+                        float* HWY_RESTRICT att) {
     const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos] = score;
+        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(pos);
+        const float score = Dot(q, k_ptr, qkv_dim);
+        att[pos] = score;
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos % activations_.seq_len] = score;
+        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(cache_pos);
+        const float score = Dot(q, k_ptr, qkv_dim);
+        att[pos % activations_.seq_len] = score;
       }
     }
   }
 
-  // Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into
+  // Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
   // `att_out`. Equivalent in gemma/modules.py:
   // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
+  // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos,
-                               const float* HWY_RESTRICT head_att,
-                               const size_t layer, const size_t head_offset,
-                               const hwy::Divisor& div_seq_len,
-                               const KVCache& kv_cache,
+                               const float* HWY_RESTRICT att,
+                               const ConstMat<float>& v,
                                float* HWY_RESTRICT att_out) const {
     const size_t qkv_dim = layer_config_.qkv_dim;
     hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
 
@@ -382,25 +378,44 @@ class GemmaAttention {
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
+        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(pos);
+        MulByConstAndAdd(att[pos], v_ptr, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t cache_pos = div_seq_len.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
+        const size_t cache_pos = div_seq_len_.Remainder(pos);
+        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(cache_pos);
+        MulByConstAndAdd(att[pos % activations_.seq_len], v_ptr, att_out,
                          qkv_dim);
       }
     }
   }
 
+ public:
+  // Calculates the attention outputs for a single q.
+  HWY_INLINE void SingleDotSoftmaxWeightedSum(
+      float* HWY_RESTRICT q, const ConstMat<float>& k, const ConstMat<float>& v,
+      float* HWY_RESTRICT att, float* HWY_RESTRICT att_out,
+      const float query_scale, const size_t pos, const size_t start_pos,
+      const size_t last_pos) {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+
+    // Apply rope and scaling to Q.
+    if (layer_weights_.query_norm_scale.HasPtr()) {
+      RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q, qkv_dim);
+    }
+    PositionalEncodingQK(q, pos, layer_, query_scale);
+
+    QDotK(start_pos, last_pos, q, k, att);
+
+    // SoftMax with optional SoftCap yields "probabilities" in att.
+    const size_t att_len = std::min(last_pos + 1, activations_.seq_len);
+    MaybeLogitsSoftCap(activations_.weights_config.att_cap, att, att_len);
+    Softmax(att, att_len);
+
+    WeightedSumV(start_pos, last_pos, att, v, att_out);
+  }
+
   HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t num_interleaved) {
     PROFILER_ZONE("Gen.Attention.DotSoftmax");
     const float query_scale = ChooseQueryScale(activations_.weights_config);
 
@@ -418,18 +433,32 @@ class GemmaAttention {
           const size_t batch_idx = interleaved_idx / num_queries_;
           const size_t qkv_dim = layer_config_.qkv_dim;
           const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
-          KVCache& kv_cache = kv_caches_[query_idx];
           float* HWY_RESTRICT q =
               activations_.q.Batch(interleaved_idx) + head * q_stride_;
+          float* HWY_RESTRICT att =
+              activations_.att.Batch(interleaved_idx) +
+              head * activations_.seq_len;
+          float* HWY_RESTRICT att_out =
+              activations_.att_out.Batch(interleaved_idx) +
+              head * qkv_dim;
 
-          // Apply rope and scaling to Q.
+          // Make strided views into the kv cache entries for the current
+          // query and head.
+          KVCache& kv_cache = kv_caches_[query_idx];
+          const size_t kv_head_offset =
+              layer_ * cache_layer_size_ + head_offset;
+          ConstMat<float> k(kv_cache.kv_cache.get() + kv_head_offset,
+                            Extents2D(kv_cache.seq_len, qkv_dim),
+                            /*stride=*/cache_pos_size_);
+          ConstMat<float> v(
+              kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+              Extents2D(kv_cache.seq_len, qkv_dim),
+              /*stride=*/cache_pos_size_);
+
+          // Find the token position in the query and calculate the range
+          // of cache positions to attend to.
           const size_t pos = queries_pos_[query_idx] + batch_idx;
-          if (layer_weights_.query_norm_scale.HasPtr()) {
-            RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q,
-                           qkv_dim);
-          }
-          PositionalEncodingQK(q, pos, layer_, query_scale);
-
           const size_t start_pos = StartPos(pos, layer_);
           size_t last_pos = pos;
           const size_t prefix_end = queries_prefix_end_[query_idx];
 
@@ -438,26 +467,13 @@ class GemmaAttention {
             last_pos = prefix_end - 1;
           }
 
-          float* HWY_RESTRICT head_att =
-              activations_.att.Batch(interleaved_idx) +
-              head * activations_.seq_len;
-          QDotK(start_pos, last_pos, head_offset, q, kv_cache, head_att);
-          // SoftMax with optional SoftCap yields "probabilities" in
-          // head_att.
-          const size_t head_att_len =
-              std::min(last_pos + 1, activations_.seq_len);
-          MaybeLogitsSoftCap(activations_.weights_config.att_cap,
-                             head_att, head_att_len);
-          Softmax(head_att, head_att_len);
-
-          float* HWY_RESTRICT att_out =
-              activations_.att_out.Batch(interleaved_idx) +
-              head * qkv_dim;
-          WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
-                       div_seq_len_, kv_cache, att_out);
+          SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
+                                      pos, start_pos, last_pos);
         });
   }
 
+ private:
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
   // head_dim (`qkv_dim`) into output (`layer_out`).
   HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
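For orientation, the newly exposed `SingleDotSoftmaxWeightedSum` chains QDotK, an optional logits soft-cap, Softmax, and WeightedSumV for a single query vector. Below is a scalar sketch of that flow under simplifying assumptions (plain nested loops instead of the SIMD helpers, no RoPE/RMSNorm/soft-cap, no cache wraparound); it is illustrative, not the gemma.cpp implementation.

// Scalar sketch of the single-q path that SingleDotSoftmaxWeightedSum factors
// out (QDotK -> Softmax -> WeightedSumV), with plain loops standing in for the
// SIMD Dot / Softmax / MulByConstAndAdd helpers. Names and shapes are
// illustrative only.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void SingleQAttentionSketch(const std::vector<float>& q,               // [qkv_dim]
                            const std::vector<std::vector<float>>& k,  // [num_pos][qkv_dim]
                            const std::vector<std::vector<float>>& v,  // [num_pos][qkv_dim]
                            std::vector<float>& att_out) {             // [qkv_dim]
  const size_t num_pos = k.size();
  const size_t qkv_dim = q.size();
  att_out.assign(qkv_dim, 0.0f);
  if (num_pos == 0) return;

  // QDotK: logits att[pos] = Dot(q, k[pos]).
  std::vector<float> att(num_pos, 0.0f);
  for (size_t pos = 0; pos < num_pos; ++pos) {
    for (size_t d = 0; d < qkv_dim; ++d) att[pos] += q[d] * k[pos][d];
  }

  // Softmax over the attended positions (numerically stabilized).
  float max_logit = att[0];
  for (float a : att) max_logit = std::max(max_logit, a);
  float sum = 0.0f;
  for (float& a : att) { a = std::exp(a - max_logit); sum += a; }
  for (float& a : att) a /= sum;

  // WeightedSumV: att_out[d] = sum over pos of att[pos] * v[pos][d].
  for (size_t pos = 0; pos < num_pos; ++pos) {
    for (size_t d = 0; d < qkv_dim; ++d) att_out[d] += att[pos] * v[pos][d];
  }
}

The real member function additionally applies `RMSNormInplace` and `PositionalEncodingQK` to `q`, calls `MaybeLogitsSoftCap` before `Softmax`, and reads `k`/`v` through the strided cache views constructed in `DotSoftmaxWeightedSum`.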