From d834c07042e2df108c91b49f3a7681a7ba05add9 Mon Sep 17 00:00:00 2001
From: Biruk Mammo
Date: Thu, 8 May 2025 09:18:02 -0700
Subject: [PATCH] Exposes `GemmaAttention::DotSoftmaxWeightedSum` for
 experimentation.

Also in this change:

* The computation for a single `q` is factored out and exposed.
* Strided `ConstMat` views into the KV caches are introduced to enable
  experimentation with various KV cache layouts.

PiperOrigin-RevId: 756339313
---
 gemma/gemma-inl.h | 124 ++++++++++++++++++++++++++--------------------
 1 file changed, 70 insertions(+), 54 deletions(-)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 0a36035..3345ecb 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -341,40 +341,36 @@ class GemmaAttention {
     });
   }
 
-  // Computes Q.K scores, which are "logits" (or scores) stored to head_att.
+  // Computes Q.K scores, which are "logits" (or scores) stored to att.
+  // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
-                        const size_t head_offset, const float* HWY_RESTRICT q,
-                        const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
+                        const float* HWY_RESTRICT q, const ConstMat<float>& k,
+                        float* HWY_RESTRICT att) {
     const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos] = score;
+        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(pos);
+        const float score = Dot(q, k_ptr, qkv_dim);
+        att[pos] = score;
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos % activations_.seq_len] = score;
+        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(cache_pos);
+        const float score = Dot(q, k_ptr, qkv_dim);
+        att[pos % activations_.seq_len] = score;
       }
     }
   }
 
-  // Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into
+  // Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
   // `att_out`. Equivalent in gemma/modules.py:
   // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
+  // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos,
-                               const float* HWY_RESTRICT head_att,
-                               const size_t layer, const size_t head_offset,
-                               const hwy::Divisor& div_seq_len,
-                               const KVCache& kv_cache,
+                               const float* HWY_RESTRICT att,
+                               const ConstMat<float>& v,
                                float* HWY_RESTRICT att_out) const {
     const size_t qkv_dim = layer_config_.qkv_dim;
     hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -382,25 +378,44 @@ class GemmaAttention {
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
      for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
+        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(pos);
+        MulByConstAndAdd(att[pos], v_ptr, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t cache_pos = div_seq_len.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
+        const size_t cache_pos = div_seq_len_.Remainder(pos);
+        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(cache_pos);
+        MulByConstAndAdd(att[pos % activations_.seq_len], v_ptr, att_out,
                          qkv_dim);
       }
     }
   }
 
+ public:
+  // Calculates the attention outputs for a single q.
+  HWY_INLINE void SingleDotSoftmaxWeightedSum(
+      float* HWY_RESTRICT q, const ConstMat<float>& k,
+      const ConstMat<float>& v, float* HWY_RESTRICT att,
+      float* HWY_RESTRICT att_out, const float query_scale, const size_t pos,
+      const size_t start_pos, const size_t last_pos) {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+
+    // Apply rope and scaling to Q.
+    if (layer_weights_.query_norm_scale.HasPtr()) {
+      RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q, qkv_dim);
+    }
+    PositionalEncodingQK(q, pos, layer_, query_scale);
+
+    QDotK(start_pos, last_pos, q, k, att);
+
+    // SoftMax with optional SoftCap yields "probabilities" in att.
+    const size_t att_len = std::min(last_pos + 1, activations_.seq_len);
+    MaybeLogitsSoftCap(activations_.weights_config.att_cap, att, att_len);
+    Softmax(att, att_len);
+
+    WeightedSumV(start_pos, last_pos, att, v, att_out);
+  }
+
   HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t num_interleaved) {
     PROFILER_ZONE("Gen.Attention.DotSoftmax");
     const float query_scale = ChooseQueryScale(activations_.weights_config);
@@ -418,18 +433,32 @@ class GemmaAttention {
           const size_t batch_idx = interleaved_idx / num_queries_;
           const size_t qkv_dim = layer_config_.qkv_dim;
           const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
-          KVCache& kv_cache = kv_caches_[query_idx];
+
           float* HWY_RESTRICT q =
-              activations_.q.Batch(interleaved_idx) + head * q_stride_;
+              activations_.q.Batch(interleaved_idx) + head * q_stride_;
+          float* HWY_RESTRICT att =
+              activations_.att.Batch(interleaved_idx) +
+              head * activations_.seq_len;
+          float* HWY_RESTRICT att_out =
+              activations_.att_out.Batch(interleaved_idx) +
+              head * qkv_dim;
 
-          // Apply rope and scaling to Q.
+          // Make strided views into the kv cache entries for the current
+          // query and head.
+          KVCache& kv_cache = kv_caches_[query_idx];
+          const size_t kv_head_offset =
+              layer_ * cache_layer_size_ + head_offset;
+          ConstMat<float> k(kv_cache.kv_cache.get() + kv_head_offset,
+                            Extents2D(kv_cache.seq_len, qkv_dim),
+                            /*stride=*/cache_pos_size_);
+          ConstMat<float> v(
+              kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+              Extents2D(kv_cache.seq_len, qkv_dim),
+              /*stride=*/cache_pos_size_);
+
+          // Find the token position in the query and calculate the range
+          // of cache positions to attend to.
           const size_t pos = queries_pos_[query_idx] + batch_idx;
-          if (layer_weights_.query_norm_scale.HasPtr()) {
-            RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q,
-                           qkv_dim);
-          }
-          PositionalEncodingQK(q, pos, layer_, query_scale);
-
           const size_t start_pos = StartPos(pos, layer_);
           size_t last_pos = pos;
           const size_t prefix_end = queries_prefix_end_[query_idx];
@@ -438,26 +467,13 @@ class GemmaAttention {
             last_pos = prefix_end - 1;
           }
 
-          float* HWY_RESTRICT head_att =
-              activations_.att.Batch(interleaved_idx) +
-              head * activations_.seq_len;
-          QDotK(start_pos, last_pos, head_offset, q, kv_cache, head_att);
-          // SoftMax with optional SoftCap yields "probabilities" in
-          // head_att.
-          const size_t head_att_len =
-              std::min(last_pos + 1, activations_.seq_len);
-          MaybeLogitsSoftCap(activations_.weights_config.att_cap,
-                             head_att, head_att_len);
-          Softmax(head_att, head_att_len);
+          SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
+                                      pos, start_pos, last_pos);
 
-          float* HWY_RESTRICT att_out =
-              activations_.att_out.Batch(interleaved_idx) +
-              head * qkv_dim;
-          WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
-                       div_seq_len_, kv_cache, att_out);
         });
   }
 
+ private:
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
   // head_dim (`qkv_dim`) into output (`layer_out`).
   HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
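
A note for readers who want to try other KV cache layouts: the `k` and `v`
arguments above are not copies but [seq_len, qkv_dim] windows into the single
interleaved cache allocation, and all layout knowledge is now confined to the
view construction in `DotSoftmaxWeightedSum`. The standalone sketch below
illustrates the same offset arithmetic in isolation; it is an illustration
under stated assumptions, not gemma.cpp code: `View` stands in for
`ConstMat<float>`, the layout constants are hypothetical stand-ins for
`cache_layer_size_`/`cache_pos_size_`, and the KV head is indexed directly
rather than derived via `head / kHeadGroups`.

#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for ConstMat<float>: a base pointer plus a row stride, where
// Row(r) returns the element offset of row r and the caller adds it to ptr,
// as in the patch.
struct View {
  const float* ptr;
  size_t stride;
  size_t Row(size_t r) const { return r * stride; }
};

int main() {
  // Hypothetical cache shape: for each position, every layer stores K then V
  // (each qkv_dim floats) for every KV head.
  const size_t seq_len = 4, layers = 2, kv_heads = 3, qkv_dim = 8;
  const size_t cache_layer_size = kv_heads * qkv_dim * 2;   // one layer's K+V
  const size_t cache_pos_size = layers * cache_layer_size;  // one position
  std::vector<float> cache(seq_len * cache_pos_size);

  // View selection for (layer=1, kv_head=2), mirroring
  // kv_head_offset = layer_ * cache_layer_size_ + head_offset.
  const size_t layer = 1, kv_head = 2;
  const size_t head_offset = kv_head * qkv_dim * 2;
  const size_t kv_head_offset = layer * cache_layer_size + head_offset;

  // K rows start at the head's K vector; V rows sit qkv_dim floats later.
  // The stride is a full position, because positions interleave all layers
  // and heads.
  const View k{cache.data() + kv_head_offset, cache_pos_size};
  const View v{cache.data() + kv_head_offset + qkv_dim, cache_pos_size};

  for (size_t pos = 0; pos < seq_len; ++pos) {
    const float* k_row = k.ptr + k.Row(pos);  // same pattern as QDotK
    const float* v_row = v.ptr + v.Row(pos);  // same pattern as WeightedSumV
    std::printf("pos %zu: k at +%td, v at +%td\n", pos,
                k_row - cache.data(), v_row - cache.data());
  }
  return 0;
}

Because `QDotK` and `WeightedSumV` only ever touch `ptr + Row(pos)`, swapping
in, say, a per-head contiguous layout reduces to constructing the views with a
different base pointer and stride.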
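
Similarly, a scalar golden model of the newly exposed
`SingleDotSoftmaxWeightedSum` can help validate such experiments. The sketch
below is a simplified reference, not the Highway-vectorized implementation: it
assumes no wraparound (all positions below seq_len), omits the RoPE and
optional RMSNorm that the real function applies to `q` first, folds
`query_scale` into the dot product (the real code applies it inside
`PositionalEncodingQK`, which is mathematically equivalent), and runs the
softmax over [start_pos, last_pos] only.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar golden model: given strided K/V rows (indexed as pos * stride + d,
// matching the views above), computes logits = scale * q.k, applies optional
// logit soft-capping, softmax, then the probability-weighted sum of V.
void SingleDotSoftmaxWeightedSumRef(const float* q, const float* k,
                                    const float* v, size_t stride,
                                    size_t start_pos, size_t last_pos,
                                    size_t qkv_dim, float query_scale,
                                    float att_cap, float* att_out) {
  std::vector<float> att(last_pos + 1, 0.0f);
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    float logit = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) {
      logit += query_scale * q[d] * k[pos * stride + d];
    }
    // MaybeLogitsSoftCap: cap * tanh(logit / cap) when a cap is set.
    att[pos] = (att_cap > 0.0f) ? att_cap * std::tanh(logit / att_cap) : logit;
  }
  // Numerically stable softmax over the attended range.
  float max_logit = att[start_pos];
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    max_logit = std::max(max_logit, att[pos]);
  }
  float sum = 0.0f;
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    att[pos] = std::exp(att[pos] - max_logit);
    sum += att[pos];
  }
  // WeightedSumV: att_out[d] = sum over pos of prob[pos] * V[pos][d].
  std::fill(att_out, att_out + qkv_dim, 0.0f);
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    const float prob = att[pos] / sum;
    for (size_t d = 0; d < qkv_dim; ++d) {
      att_out[d] += prob * v[pos * stride + d];
    }
  }
}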