Add more comments to attention computation (and some small restructuring).

PiperOrigin-RevId: 650929097
Authored by Daniel Keysers, 2024-07-10 02:38:18 -07:00; committed by Copybara-Service
parent cf76f0a401
commit 063bbaa683
1 changed file with 39 additions and 21 deletions


@@ -218,38 +218,45 @@ HWY_NOINLINE void Attention(
   constexpr size_t kSeqLen = TConfig::kSeqLen;
   GEMMA_CONSTEXPR_SQRT const float kQueryScale =
       1.0f / Sqrt(static_cast<float>(kQKVDim));
-  constexpr bool kIsMHA = TActivations::kIsMHA;  // Multi-Head Attention
+  // Multi-Head Attention a.k.a. "use_qkv_einsum".
+  constexpr bool kIsMHA = TActivations::kIsMHA;
+  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
   const size_t batch_start = batch_and_query_start / num_queries;
   const size_t num_tokens_and_queries = num_tokens * num_queries;
+  // For the computation of Q, K, and V, it is useful to remember that
+  // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kQKVDim, kModelDim]
+  // and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
+  //
+  // Compute Q only or QKV (if MHA).
   // If MHA, this also computes KV, which we copy to the KV cache below.
-  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
   MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
       num_tokens_and_queries, activations.pre_att_rms_out.data(),
       layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
-  for (size_t batch_and_query_idx = 0;
-       batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
-    const float* x = activations.pre_att_rms_out.data() + batch_and_query_idx
-                     * kModelDim;
-    const size_t query_idx = batch_and_query_idx % num_queries;
-    const size_t batch_idx = batch_and_query_idx / num_queries;
-    KVCache& kv_cache = *kv_caches[query_idx];
-    // QKV projections:
-    if constexpr (!kIsMHA) {
+  // Compute KV if not MHA.
+  if constexpr (!kIsMHA) {
+    for (size_t batch_and_query_idx = 0;
+         batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
+      const float* x =
+          activations.pre_att_rms_out.data() + batch_and_query_idx * kModelDim;
+      const size_t query_idx = batch_and_query_idx % num_queries;
+      const size_t batch_idx = batch_and_query_idx / num_queries;
+      KVCache& kv_cache = *kv_caches[query_idx];
       const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset =
           cache_pos * kCachePosSize + layer * kCacheLayerSize;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
       // TODO: requires MatMul support for offsets.
-      MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
+      MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
           layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
           activations.even_odd.data(), kv, pool);
     }
   }
-  // Positional encodings for kv:
+  // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
       0, kKVHeads * num_tokens_and_queries,
       [&](uint64_t task, size_t thread) HWY_ATTR {
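The new comments pin down the packed layout: qkv_einsum_w stacks kHeads Q projections followed by kKVHeads interleaved (K, V) pairs, and kQStride is the per-head spacing within activations.q. The following standalone sketch illustrates only that offset arithmetic, with toy dimensions and printf in place of the real MatMul_4x4_Batch / MatVec calls; all names and values here are illustrative, not the library's API.

#include <cstddef>
#include <cstdio>

// Toy dimensions, for illustration only.
constexpr size_t kHeads = 4, kKVHeads = 1, kQKVDim = 8, kModelDim = 16;
constexpr bool kIsMHA = (kHeads == kKVHeads);   // false here: GQA/MQA path
constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);

int main() {
  // qkv_einsum_w stacks (kHeads + 2 * kKVHeads) blocks of [kQKVDim, kModelDim]:
  // Q projections for every head first, then (K, V) pairs per KV head.
  const size_t q_rows = kHeads * kQKVDim;         // consumed by the Q MatMul
  const size_t kv_rows = kKVHeads * 2 * kQKVDim;  // consumed by the KV MatVec
  const size_t kv_row_start = q_rows;  // matches the kHeads * kQKVDim offset
  std::printf("Q rows [0, %zu), KV rows [%zu, %zu)\n", q_rows, kv_row_start,
              kv_row_start + kv_rows);
  // Per-(token, head) offset into activations.q; with MHA, kQStride leaves
  // room for K and V right after each head's Q.
  const size_t batch_and_query_idx = 2, head = 1;
  std::printf("q offset: %zu floats\n",
              (batch_and_query_idx * kHeads + head) * kQStride);
  return 0;
}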
@@ -264,19 +271,21 @@ HWY_NOINLINE void Attention(
         KVCache& kv_cache = *kv_caches[query_idx];
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
         if constexpr (kIsMHA) {
-          // For MHA, copy kv into the KV cache from scratch space (see above).
+          // For MHA, copy KV into the KV cache from scratch space (see above).
           const float* HWY_RESTRICT q =
               activations.q.data() + (batch_and_query_idx * kHeads
               + head) * kQStride;
           // Skip past the Q part of `q`, and copy KV to `kv`.
           memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
         }
+        // Apply rope to K.
         Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
       });
   static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
   constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  // For each head (token, query), compute Q.K, softmax, and weighted V.
   pool.Run(0, kHeads * num_tokens_and_queries,
            [&](uint64_t task, size_t thread) HWY_ATTR {
     const size_t head = task % kHeads;
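Rope rotates pairs of vector elements by a position-dependent angle. Below is a scalar sketch of that computation, assuming the Gemma-style pairing of element i with i + dim/2 and a base frequency of 10000 (both stated assumptions, not a copy of the library's vectorized Rope). Passing kQKVDim / 2, as kUseHalfRope does above, rotates only the first half of the vector.

#include <cmath>
#include <cstddef>

// Scalar sketch of RoPE as applied to q or kv above.
void RopeSketch(float* x, size_t dim_qkv, size_t pos) {
  const size_t half = dim_qkv / 2;
  for (size_t i = 0; i < half; ++i) {
    const float timescale = std::pow(10000.0f, 2.0f * i / dim_qkv);
    const float theta = static_cast<float>(pos) / timescale;
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[i], x1 = x[i + half];
    x[i] = x0 * c - x1 * s;  // rotate the (x0, x1) pair by theta
    x[i + half] = x0 * s + x1 * c;
  }
}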
@@ -288,16 +297,15 @@ HWY_NOINLINE void Attention(
     float* HWY_RESTRICT q =
         activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride;
+    // Apply rope and scaling to Q.
     const size_t pos = batch_start + batch_idx;
-    // Calculate scores
-    float* HWY_RESTRICT head_att =
-        activations.att.data() + head * kSeqLen
-        + batch_and_query_idx * kHeads * kSeqLen;
     Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
     MulByConst(kQueryScale, q, kQKVDim);
-    // Compute Q dot K scores
+    // Compute Q.K scores, yielding "logits" (or scores) in head_att.
+    float* HWY_RESTRICT head_att =
+        activations.att.data() + head * kSeqLen
+        + batch_and_query_idx * kHeads * kSeqLen;
     const size_t start_pos =
         pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
     for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
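The score loop dots q against every cached key inside the layer's attention window and writes results into a ring buffer of length kSeqLen. A standalone sketch of the same loop, with a plain vector-of-vectors standing in for the real KV-cache layout (a deliberate simplification):

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch of the score loop: dot(q, k) for each in-window cached position.
void ScoresSketch(const float* q, const std::vector<std::vector<float>>& keys,
                  size_t qkv_dim, size_t pos, size_t window, size_t seq_len,
                  float* head_att) {
  const size_t start_pos = pos - std::min(window - 1, pos);
  for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
    float score = 0.0f;  // Dot(q, k2, kQKVDim) in the real code
    for (size_t d = 0; d < qkv_dim; ++d) score += q[d] * keys[pos2][d];
    head_att[pos2 % seq_len] = score;  // ring-buffer write, as in the diff
  }
}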
@@ -308,13 +316,17 @@ HWY_NOINLINE void Attention(
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2 % kSeqLen] = score;
     }
+    // SoftMax. May be preceded by SoftCap. Yields "probabilities" in head_att.
     const size_t head_att_len = std::min(pos + 1, kSeqLen);
     if constexpr (TConfig::kAttCap > 0.0f) {
       LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
     }
     Softmax(head_att, head_att_len);
-    // Weighted summation
+    // Summation of v (kv_cache) weighted by probs (head_att)
+    // into "encoded" (att_out). Compare gemma/modules.py:
+    // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
     float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                   batch_and_query_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
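LogitsSoftCap and Softmax are the two normalization steps the new comment names. Scalar sketches of both follow, assuming the usual cap * tanh(x / cap) soft-capping and a max-subtracted softmax; the repository's versions compute the same quantities in vectorized form.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Soft-capping squashes scores into (-cap, cap).
void SoftCapSketch(float cap, float* x, size_t n) {
  for (size_t i = 0; i < n; ++i) x[i] = cap * std::tanh(x[i] / cap);
}

// Softmax turns scores into probabilities. Assumes n >= 1.
void SoftmaxSketch(float* x, size_t n) {
  float max = x[0];
  for (size_t i = 1; i < n; ++i) max = std::max(max, x[i]);
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    x[i] = std::exp(x[i] - max);  // subtract max for numerical stability
    sum += x[i];
  }
  for (size_t i = 0; i < n; ++i) x[i] /= sum;  // entries now sum to 1
}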
@@ -327,6 +339,9 @@ HWY_NOINLINE void Attention(
     }
   });
+  // Sum encoded (att_out) over num_heads and head_dim (kQKVDim)
+  // into output (layer_out). Compare gemma/modules.py:
+  // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
   for (size_t batch_and_query_idx = 0;
        batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
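Per head, the 'BTNS,BSNH->BTNH' einsum in the previous hunk reduces to accumulating probability-weighted V rows. A sketch of that accumulation, where v_at is a hypothetical accessor standing in for the kv_cache offset computation:

#include <cstddef>

// att_out[d] = sum over window positions of prob(pos2) * v(pos2)[d].
void WeightedVSketch(const float* head_att, size_t start_pos, size_t pos,
                     size_t seq_len, size_t qkv_dim,
                     const float* (*v_at)(size_t), float* att_out) {
  for (size_t d = 0; d < qkv_dim; ++d) att_out[d] = 0.0f;  // cf. ZeroBytes
  for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
    const float prob = head_att[pos2 % seq_len];  // softmax output
    const float* v = v_at(pos2);                  // hypothetical cache lookup
    for (size_t d = 0; d < qkv_dim; ++d) att_out[d] += prob * v[d];
  }
}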
@@ -335,10 +350,13 @@ HWY_NOINLINE void Attention(
         activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT layer_out =
         activations.att_post2.data() + batch_and_query_idx * kModelDim;
+    // Head 0 (and potentially biases) -> layer_out.
+    // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
     MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
         layer_weights->attn_vec_einsum_w, 0, att_out,
         layer_weights->attention_output_biases.data(),
         activations.even_odd.data(), layer_out, pool);
+    // Head 1 and following are added to layer_out.
     for (size_t head = 1; head < kHeads; ++head) {
       // TODO(patrickms): Check this calculation
       float* HWY_RESTRICT head_out =
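The closing loop implements 'BTNH,NHD->BTD' as one MatVecT per head: head 0 initializes layer_out (optionally adding biases), and the remaining heads accumulate into it. A plain-loop sketch of the same contraction, taking the flat [kHeads, kQKVDim, kModelDim] layout from the comment in the diff (the flat-array indexing is an assumption for illustration):

#include <cstddef>

// layer_out = sum_h attn_vec_einsum_w[h]^T . att_out[h] (+ optional biases).
void AttnVecEinsumSketch(const float* w,        // [heads, qkv_dim, model_dim]
                         const float* att_out,  // [heads, qkv_dim]
                         const float* biases,   // [model_dim] or nullptr
                         size_t heads, size_t qkv_dim, size_t model_dim,
                         float* layer_out) {
  for (size_t d = 0; d < model_dim; ++d)
    layer_out[d] = biases ? biases[d] : 0.0f;  // head 0's MatVecT adds biases
  for (size_t h = 0; h < heads; ++h) {
    for (size_t k = 0; k < qkv_dim; ++k) {
      const float a = att_out[h * qkv_dim + k];
      const float* w_row = w + (h * qkv_dim + k) * model_dim;
      for (size_t d = 0; d < model_dim; ++d) layer_out[d] += a * w_row[d];
    }
  }
}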