mirror of https://github.com/google/gemma.cpp.git
Add more comments to attention computation (and some small restructuring).
PiperOrigin-RevId: 650929097
parent cf76f0a401
commit 063bbaa683
@@ -218,38 +218,45 @@ HWY_NOINLINE void Attention(
   constexpr size_t kSeqLen = TConfig::kSeqLen;
   GEMMA_CONSTEXPR_SQRT const float kQueryScale =
       1.0f / Sqrt(static_cast<float>(kQKVDim));
-  constexpr bool kIsMHA = TActivations::kIsMHA;  // Multi-Head Attention
+  // Multi-Head Attention a.k.a. "use_qkv_einsum".
+  constexpr bool kIsMHA = TActivations::kIsMHA;
+  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
   const size_t batch_start = batch_and_query_start / num_queries;
   const size_t num_tokens_and_queries = num_tokens * num_queries;

+  // For the computation of Q, K, and V, it is useful to remember that
+  // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kQKVDim, kModelDim]
+  // and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
+  //
+  // Compute Q only or QKV (if MHA).
+  // If MHA, this also computes KV, which we copy to the KV cache below.
-  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
   MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
       num_tokens_and_queries, activations.pre_att_rms_out.data(),
       layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);

+  // Compute KV if not MHA.
+  if constexpr (!kIsMHA) {
     for (size_t batch_and_query_idx = 0;
          batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
-      const float* x = activations.pre_att_rms_out.data() + batch_and_query_idx
-                       * kModelDim;
+      const float* x =
+          activations.pre_att_rms_out.data() + batch_and_query_idx * kModelDim;
       const size_t query_idx = batch_and_query_idx % num_queries;
       const size_t batch_idx = batch_and_query_idx / num_queries;
       KVCache& kv_cache = *kv_caches[query_idx];
-      // QKV projections:
-      if constexpr (!kIsMHA) {
       const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset =
           cache_pos * kCachePosSize + layer * kCacheLayerSize;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
+      // TODO: requires MatMul support for offsets.
-      MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
+      MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
           layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
           activations.even_odd.data(), kv, pool);
     }
   }

-  // Positional encodings for kv:
+  // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
       0, kKVHeads * num_tokens_and_queries,
       [&](uint64_t task, size_t thread) HWY_ATTR {
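Side note on the weight layout mentioned in the new comment: the standalone sketch below (illustrative constants only, not taken from any particular Gemma config) shows how the Q rows and the trailing KV rows of qkv_einsum_w are laid out, and why the MatVec call above skips kHeads * kQKVDim * kModelDim elements to reach the KV block. kQStride is the per-head stride within activations.q: 3 * kQKVDim when MHA stores interleaved [q, k, v], otherwise just kQKVDim.

#include <cstddef>
#include <cstdio>

// Illustrative constants only (not from a real model config).
constexpr size_t kHeads = 8;      // query heads
constexpr size_t kKVHeads = 1;    // key/value heads (< kHeads means GQA/MQA)
constexpr size_t kQKVDim = 256;   // per-head dimension
constexpr size_t kModelDim = 2048;
constexpr bool kIsMHA = (kHeads == kKVHeads);
// Per-head stride in activations.q: interleaved [q, k, v] for MHA, else just q.
constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);

int main() {
  // qkv_einsum_w is [(kHeads + kKVHeads * 2), kQKVDim, kModelDim]: the first
  // kHeads * kQKVDim rows project Q, the remaining kKVHeads * 2 * kQKVDim rows
  // project K and V.
  const size_t q_rows = kHeads * kQKVDim;
  const size_t kv_rows = kKVHeads * 2 * kQKVDim;
  // Element offset of the KV block, matching the offset passed to MatVec above.
  const size_t kv_weight_offset = kHeads * kQKVDim * kModelDim;
  std::printf("Q rows %zu, KV rows %zu, KV offset %zu, kQStride %zu\n",
              q_rows, kv_rows, kv_weight_offset, kQStride);
  return 0;
}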
@@ -264,19 +271,21 @@ HWY_NOINLINE void Attention(
         KVCache& kv_cache = *kv_caches[query_idx];
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
         if constexpr (kIsMHA) {
-          // For MHA, copy kv into the KV cache from scratch space (see above).
+          // For MHA, copy KV into the KV cache from scratch space (see above).
           const float* HWY_RESTRICT q =
               activations.q.data() + (batch_and_query_idx * kHeads
                                       + head) * kQStride;
+          // Skip past the Q part of `q`, and copy KV to `kv`.
           memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
         }
+        // Apply rope to K.
         Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
       });

   static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
   constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  // For each head (token, query), compute Q.K, softmax, and weighted V.
   pool.Run(0, kHeads * num_tokens_and_queries,
            [&](uint64_t task, size_t thread) HWY_ATTR {
              const size_t head = task % kHeads;
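For orientation, here is a hedged sketch of the KV-cache indexing this hunk relies on. The per-position and per-layer strides (kCachePosSize, kCacheLayerSize) are not visible in the diff, so the definitions below are assumptions chosen only to match the [k, v, k, v, ...] layout described in the new comment; likewise, the query-head to KV-head grouping via kGroupHeads is shown as the straightforward head / kGroupHeads mapping, which the surrounding (unshown) code is assumed to use.

#include <cstddef>
#include <cstdio>

// Illustrative constants; real values come from TConfig.
constexpr size_t kLayers = 2;
constexpr size_t kHeads = 8;
constexpr size_t kKVHeads = 2;
constexpr size_t kQKVDim = 16;
constexpr size_t kSeqLen = 32;
constexpr size_t kPrefillBatchSize = 4;
// Assumed strides: each layer stores kKVHeads (k, v) pairs per position.
constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
constexpr size_t kCachePosSize = kLayers * kCacheLayerSize;
constexpr size_t kGroupHeads = kHeads / kKVHeads;

int main() {
  const size_t pos = 100, layer = 1, head = 5;
  // Positions wrap modulo the cache capacity (ring buffer).
  const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
  // Start of this layer's [k, v, k, v, ...] block for this position.
  const size_t kv_offset = cache_pos * kCachePosSize + layer * kCacheLayerSize;
  // Which (k, v) pair a given query head reads (grouped-query attention).
  const size_t kv_head = head / kGroupHeads;
  std::printf("cache_pos %zu, kv_offset %zu, query head %zu -> kv head %zu\n",
              cache_pos, kv_offset, head, kv_head);
  return 0;
}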
@@ -288,16 +297,15 @@ HWY_NOINLINE void Attention(
              float* HWY_RESTRICT q =
                  activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride;

+             // Apply rope and scaling to Q.
              const size_t pos = batch_start + batch_idx;
-             // Calculate scores
-             float* HWY_RESTRICT head_att =
-                 activations.att.data() + head * kSeqLen
-                 + batch_and_query_idx * kHeads * kSeqLen;
-
              Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
              MulByConst(kQueryScale, q, kQKVDim);

-             // Compute Q dot K scores
+             // Compute Q.K scores, yielding "logits" (or scores) in head_att.
+             float* HWY_RESTRICT head_att =
+                 activations.att.data() + head * kSeqLen
+                 + batch_and_query_idx * kHeads * kSeqLen;
              const size_t start_pos =
                  pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
              for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
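This hunk applies Rope and the 1/sqrt(kQKVDim) scale to Q, then clamps the attended range to the layer's window before the Q.K loop. Below is a scalar sketch of those three steps, under the assumption that Rope is the usual split-half rotary embedding with base 10000; it is a reference illustration with made-up dimensions, not the library routine.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>

// Scalar split-half RoPE with base 10000 (assumed to match what Rope() does).
void RopeScalar(float* x, size_t dim, size_t pos) {
  const size_t half = dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float theta = pos / std::pow(10000.0f, 2.0f * i / dim);
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[i], x1 = x[i + half];
    x[i] = x0 * c - x1 * s;
    x[i + half] = x0 * s + x1 * c;
  }
}

int main() {
  constexpr size_t kQKVDim = 8;   // illustrative
  constexpr size_t kWindow = 4;   // illustrative attention window
  float q[kQKVDim] = {1, 0, 0, 0, 0, 0, 0, 0};
  const size_t pos = 10;
  RopeScalar(q, kQKVDim, pos);
  // Query scaling, as in MulByConst(kQueryScale, q, kQKVDim).
  const float kQueryScale = 1.0f / std::sqrt(static_cast<float>(kQKVDim));
  for (float& v : q) v *= kQueryScale;
  // Sliding window: attend to at most kWindow positions ending at (and
  // including) pos, clamped at the start of the sequence.
  const size_t start_pos = pos - std::min(kWindow - 1, pos);
  std::printf("q[0]=%f, attends to positions [%zu, %zu]\n", q[0], start_pos, pos);
  return 0;
}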
@@ -308,13 +316,17 @@ HWY_NOINLINE void Attention(
                const float score = Dot(q, k2, kQKVDim);
                head_att[pos2 % kSeqLen] = score;
              }

+             // SoftMax. May be preceded by SoftCap. Yields "probabilities" in head_att.
              const size_t head_att_len = std::min(pos + 1, kSeqLen);
              if constexpr (TConfig::kAttCap > 0.0f) {
                LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
              }
              Softmax(head_att, head_att_len);

-             // Weighted summation
+             // Summation of v (kv_cache) weighted by probs (head_att)
+             // into "encoded" (att_out). Compare gemma/modules.py:
+             // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
              float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                            batch_and_query_idx * kHeads * kQKVDim;
              hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
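The surrounding code caps the logits (for configs with kAttCap > 0), applies Softmax over the attended positions, and then accumulates V weighted by the resulting probabilities into att_out. Here is a scalar reference of that pipeline, assuming the soft cap has the published Gemma 2 form cap * tanh(x / cap); the vectorized LogitsSoftCap/Softmax helpers are assumed to compute the same quantities in place.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for: soft cap -> softmax -> weighted sum of V.
// att holds one score per attended position; v[t] is the cached V row for
// position t; out accumulates the kQKVDim-wide result (att_out).
void CapSoftmaxWeightedV(float cap, std::vector<float>& att,
                         const std::vector<std::vector<float>>& v,
                         std::vector<float>& out) {
  if (cap > 0.0f) {
    // Soft cap (assumed form): squashes scores into (-cap, cap) monotonically.
    for (float& a : att) a = cap * std::tanh(a / cap);
  }
  // Numerically stable softmax over the attended positions.
  const float max = *std::max_element(att.begin(), att.end());
  float sum = 0.0f;
  for (float& a : att) { a = std::exp(a - max); sum += a; }
  for (float& a : att) a /= sum;
  // out[d] = sum_t att[t] * v[t][d], i.e. einsum('S,SH->H', probs, value).
  std::fill(out.begin(), out.end(), 0.0f);
  for (size_t t = 0; t < att.size(); ++t) {
    for (size_t d = 0; d < out.size(); ++d) out[d] += att[t] * v[t][d];
  }
}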
@@ -327,6 +339,9 @@ HWY_NOINLINE void Attention(
              }
            });

+  // Sum encoded (att_out) over num_heads and head_dim (kQKVDim)
+  // into output (layer_out). Compare gemma/modules.py:
+  // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
   for (size_t batch_and_query_idx = 0;
        batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
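The loop that follows (next hunk) performs this 'BTNH,NHD->BTD' contraction with one MatVecT per head: head 0 writes layer_out (optionally adding biases) and later heads accumulate into it. Below is a plain-loop sketch of the same contraction for a single (batch, token) row, using the [kHeads, kQKVDim, kModelDim] weight shape noted in the comment; the dimensions are function parameters purely for illustration, and biases are omitted.

#include <algorithm>
#include <cstddef>
#include <vector>

// layer_out[d] = sum over heads n and head dims h of
//   encoded[n * kQKVDim + h] * w[(n * kQKVDim + h) * kModelDim + d],
// i.e. attn_output = einsum('NH,NHD->D', encoded, attn_vec_einsum_w)
// for one (batch, token) position.
void AttnVecEinsumRow(const std::vector<float>& encoded,  // kHeads * kQKVDim
                      const std::vector<float>& w,        // kHeads * kQKVDim * kModelDim
                      size_t kHeads, size_t kQKVDim, size_t kModelDim,
                      std::vector<float>& layer_out) {    // kModelDim
  std::fill(layer_out.begin(), layer_out.end(), 0.0f);
  for (size_t n = 0; n < kHeads; ++n) {    // head 0 writes, later heads add
    for (size_t h = 0; h < kQKVDim; ++h) {
      const float e = encoded[n * kQKVDim + h];
      const float* row = &w[(n * kQKVDim + h) * kModelDim];
      for (size_t d = 0; d < kModelDim; ++d) layer_out[d] += e * row[d];
    }
  }
}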
@@ -335,10 +350,13 @@ HWY_NOINLINE void Attention(
         activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT layer_out =
         activations.att_post2.data() + batch_and_query_idx * kModelDim;
+    // Head 0 (and potentially biases) -> layer_out.
+    // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
     MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
         layer_weights->attn_vec_einsum_w, 0, att_out,
         layer_weights->attention_output_biases.data(),
         activations.even_odd.data(), layer_out, pool);
+    // Head 1 and following are added to layer_out.
     for (size_t head = 1; head < kHeads; ++head) {
       // TODO(patrickms): Check this calculation
       float* HWY_RESTRICT head_out =