diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 247d953..18b4ac5 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -218,38 +218,45 @@ HWY_NOINLINE void Attention(
   constexpr size_t kSeqLen = TConfig::kSeqLen;
   GEMMA_CONSTEXPR_SQRT const float kQueryScale =
       1.0f / Sqrt(static_cast<float>(kQKVDim));
-  constexpr bool kIsMHA = TActivations::kIsMHA;  // Multi-Head Attention
+  // Multi-Head Attention a.k.a. "use_qkv_einsum".
+  constexpr bool kIsMHA = TActivations::kIsMHA;
+  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
   const size_t batch_start = batch_and_query_start / num_queries;
   const size_t num_tokens_and_queries = num_tokens * num_queries;

+  // For the computation of Q, K, and V, it is useful to remember that
+  // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kQKVDim, kModelDim]
+  // and kQStride = kQKVDim * (kIsMHA ? 3 : 1).
+  //
+  // Compute Q only, or QKV if MHA.
   // If MHA, this also computes KV, which we copy to the KV cache below.
-  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
   MatMul_4x4_Batch(
       num_tokens_and_queries, activations.pre_att_rms_out.data(),
       layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);

-  for (size_t batch_and_query_idx = 0;
-       batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
-    const float* x = activations.pre_att_rms_out.data() + batch_and_query_idx
-                     * kModelDim;
-    const size_t query_idx = batch_and_query_idx % num_queries;
-    const size_t batch_idx = batch_and_query_idx / num_queries;
-    KVCache& kv_cache = *kv_caches[query_idx];
-    // QKV projections:
-    if constexpr (!kIsMHA) {
+  // Compute KV if not MHA.
+  if constexpr (!kIsMHA) {
+    for (size_t batch_and_query_idx = 0;
+         batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
+      const float* x =
+          activations.pre_att_rms_out.data() + batch_and_query_idx * kModelDim;
+      const size_t query_idx = batch_and_query_idx % num_queries;
+      const size_t batch_idx = batch_and_query_idx / num_queries;
+      KVCache& kv_cache = *kv_caches[query_idx];
       const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
       // TODO: requires MatMul support for offsets.
-      MatVec(
+      MatVec(
           layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
           activations.even_odd.data(), kv, pool);
     }
   }

-  // Positional encodings for kv:
+  // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
       0, kKVHeads * num_tokens_and_queries,
       [&](uint64_t task, size_t thread) HWY_ATTR {
@@ -264,19 +271,21 @@ HWY_NOINLINE void Attention(
         KVCache& kv_cache = *kv_caches[query_idx];
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
         if constexpr (kIsMHA) {
-          // For MHA, copy kv into the KV cache from scratch space (see above).
+          // For MHA, copy KV into the KV cache from scratch space (see above).
           const float* HWY_RESTRICT q = activations.q.data() +
               (batch_and_query_idx * kHeads + head) * kQStride;
           // Skip past the Q part of `q`, and copy KV to `kv`.
           memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
         }
+        // Apply rope to K.
         Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
       });

   static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
   constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  // For each head (token, query), compute Q.K, softmax, and weighted V.
   pool.Run(0, kHeads * num_tokens_and_queries,
            [&](uint64_t task, size_t thread) HWY_ATTR {
              const size_t head = task % kHeads;
@@ -288,16 +297,15 @@ HWY_NOINLINE void Attention(
              float* HWY_RESTRICT q =
                  activations.q.data() +
                  (batch_and_query_idx * kHeads + head) * kQStride;
+             // Apply rope and scaling to Q.
              const size_t pos = batch_start + batch_idx;
-             // Calculate scores
-             float* HWY_RESTRICT head_att =
-                 activations.att.data() + head * kSeqLen
-                 + batch_and_query_idx * kHeads * kSeqLen;
-
              Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
              MulByConst(kQueryScale, q, kQKVDim);

-             // Compute Q dot K scores
+             // Compute Q.K scores, yielding "logits" (or scores) in head_att.
+             float* HWY_RESTRICT head_att =
+                 activations.att.data() + head * kSeqLen +
+                 batch_and_query_idx * kHeads * kSeqLen;
              const size_t start_pos =
                  pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
              for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
@@ -308,13 +316,17 @@ HWY_NOINLINE void Attention(
                const float score = Dot(q, k2, kQKVDim);
                head_att[pos2 % kSeqLen] = score;
              }
+
+             // SoftMax. May be preceded by SoftCap. Yields "probabilities" in head_att.
              const size_t head_att_len = std::min(pos + 1, kSeqLen);
              if constexpr (TConfig::kAttCap > 0.0f) {
                LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
              }
              Softmax(head_att, head_att_len);
-             // Weighted summation
+             // Summation of v (kv_cache) weighted by probs (head_att)
+             // into "encoded" (att_out). Compare gemma/modules.py:
+             // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
              float* HWY_RESTRICT att_out =
                  activations.att_out.data() + head * kQKVDim +
                  batch_and_query_idx * kHeads * kQKVDim;
              hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
@@ -327,6 +339,9 @@ HWY_NOINLINE void Attention(
              }
            });

+  // Sum encoded (att_out) over num_heads and head_dim (kQKVDim)
+  // into output (layer_out). Compare gemma/modules.py:
+  // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
   for (size_t batch_and_query_idx = 0;
        batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
@@ -335,10 +350,13 @@ HWY_NOINLINE void Attention(
         activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT layer_out =
         activations.att_post2.data() + batch_and_query_idx * kModelDim;
+    // Head 0 (and potentially biases) -> layer_out.
+    // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
     MatVecT(
         layer_weights->attn_vec_einsum_w, 0, att_out,
         layer_weights->attention_output_biases.data(),
         activations.even_odd.data(), layer_out, pool);
+    // Head 1 and following are added to layer_out.
     for (size_t head = 1; head < kHeads; ++head) {
       // TODO(patrickms): Check this calculation
       float* HWY_RESTRICT head_out =
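
Reviewer note, not part of the patch: as a cross-check for the new comments, here is a minimal scalar sketch of the per-head computation they describe (Q.K logits over the attention window, optional soft-cap, softmax, then V weighted by the resulting probabilities). All names here (SingleHeadAttention, window_size, att_cap) are illustrative only and are not gemma.cpp APIs; the real code operates on the flat KV cache with Highway kernels (Dot, LogitsSoftCap, Softmax).

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Computes one head of causal attention for the query at position `pos`.
    // `q` is assumed to be already rope'd and scaled by 1/sqrt(dim); `keys`
    // and `values` hold one dim-sized vector per cached position (0..pos).
    std::vector<float> SingleHeadAttention(
        const std::vector<float>& q,
        const std::vector<std::vector<float>>& keys,
        const std::vector<std::vector<float>>& values,
        size_t pos, size_t window_size, float att_cap) {
      const size_t dim = q.size();
      // Same window as start_pos in the patch: max(0, pos - (window_size - 1)).
      const size_t start = pos + 1 > window_size ? pos + 1 - window_size : 0;

      // Q.K logits for every position inside the window.
      std::vector<float> probs;
      for (size_t p = start; p <= pos; ++p) {
        float logit = 0.0f;
        for (size_t d = 0; d < dim; ++d) logit += q[d] * keys[p][d];
        // Optional soft-cap, analogous to LogitsSoftCap(kAttCap, ...).
        if (att_cap > 0.0f) logit = att_cap * std::tanh(logit / att_cap);
        probs.push_back(logit);
      }

      // Softmax over the window (subtract the max for numerical stability).
      const float max_logit = *std::max_element(probs.begin(), probs.end());
      float sum = 0.0f;
      for (float& p : probs) {
        p = std::exp(p - max_logit);
        sum += p;
      }
      for (float& p : probs) p /= sum;

      // Weighted sum of V: the per-head "encoded" output.
      std::vector<float> out(dim, 0.0f);
      for (size_t p = start; p <= pos; ++p) {
        for (size_t d = 0; d < dim; ++d) out[d] += probs[p - start] * values[p][d];
      }
      return out;
    }

The final attn_vec_einsum step in the patch then corresponds to a per-head [kQKVDim, kModelDim] matrix-vector product whose results are accumulated over all heads into layer_out.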