mirror of https://github.com/google/gemma.cpp.git
Add more comments to attention computation (and some small restructuring).
PiperOrigin-RevId: 650929097
commit 063bbaa683
parent cf76f0a401
@@ -218,38 +218,45 @@ HWY_NOINLINE void Attention(
   constexpr size_t kSeqLen = TConfig::kSeqLen;
   GEMMA_CONSTEXPR_SQRT const float kQueryScale =
       1.0f / Sqrt(static_cast<float>(kQKVDim));
-  constexpr bool kIsMHA = TActivations::kIsMHA;  // Multi-Head Attention
+  // Multi-Head Attention a.k.a. "use_qkv_einsum".
+  constexpr bool kIsMHA = TActivations::kIsMHA;
+  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
   const size_t batch_start = batch_and_query_start / num_queries;
   const size_t num_tokens_and_queries = num_tokens * num_queries;
 
+  // For the computation of Q, K, and V, it is useful to remember that
+  // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kQKVDim, kModelDim]
+  // and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
+  //
+  // Compute Q only or QKV (if MHA).
   // If MHA, this also computes KV, which we copy to the KV cache below.
-  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
   MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
       num_tokens_and_queries, activations.pre_att_rms_out.data(),
       layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
 
-  for (size_t batch_and_query_idx = 0;
-       batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
-    const float* x = activations.pre_att_rms_out.data() + batch_and_query_idx
-                     * kModelDim;
-    const size_t query_idx = batch_and_query_idx % num_queries;
-    const size_t batch_idx = batch_and_query_idx / num_queries;
-    KVCache& kv_cache = *kv_caches[query_idx];
-    // QKV projections:
-    if constexpr (!kIsMHA) {
+  // Compute KV if not MHA.
+  if constexpr (!kIsMHA) {
+    for (size_t batch_and_query_idx = 0;
+         batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
+      const float* x =
+          activations.pre_att_rms_out.data() + batch_and_query_idx * kModelDim;
+      const size_t query_idx = batch_and_query_idx % num_queries;
+      const size_t batch_idx = batch_and_query_idx / num_queries;
+      KVCache& kv_cache = *kv_caches[query_idx];
       const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset =
           cache_pos * kCachePosSize + layer * kCacheLayerSize;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
       // TODO: requires MatMul support for offsets.
-      MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
+      MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
           layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
           activations.even_odd.data(), kv, pool);
     }
   }
 
-  // Positional encodings for kv:
+  // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
       0, kKVHeads * num_tokens_and_queries,
       [&](uint64_t task, size_t thread) HWY_ATTR {
@@ -264,19 +271,21 @@ HWY_NOINLINE void Attention(
         KVCache& kv_cache = *kv_caches[query_idx];
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
         if constexpr (kIsMHA) {
-          // For MHA, copy kv into the KV cache from scratch space (see above).
+          // For MHA, copy KV into the KV cache from scratch space (see above).
           const float* HWY_RESTRICT q =
               activations.q.data() + (batch_and_query_idx * kHeads
                                       + head) * kQStride;
           // Skip past the Q part of `q`, and copy KV to `kv`.
           memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
         }
+        // Apply rope to K.
         Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
       });
 
   static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
   constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  // For each head (token, query), compute Q.K, softmax, and weighted V.
   pool.Run(0, kHeads * num_tokens_and_queries,
            [&](uint64_t task, size_t thread) HWY_ATTR {
     const size_t head = task % kHeads;
@@ -288,16 +297,15 @@ HWY_NOINLINE void Attention(
     float* HWY_RESTRICT q =
         activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride;
 
+    // Apply rope and scaling to Q.
     const size_t pos = batch_start + batch_idx;
-    // Calculate scores
-    float* HWY_RESTRICT head_att =
-        activations.att.data() + head * kSeqLen
-        + batch_and_query_idx * kHeads * kSeqLen;
-
     Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
     MulByConst(kQueryScale, q, kQKVDim);
 
-    // Compute Q dot K scores
+    // Compute Q.K scores, yielding "logits" (or scores) in head_att.
+    float* HWY_RESTRICT head_att =
+        activations.att.data() + head * kSeqLen
+        + batch_and_query_idx * kHeads * kSeqLen;
     const size_t start_pos =
         pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
     for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
@@ -308,13 +316,17 @@ HWY_NOINLINE void Attention(
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2 % kSeqLen] = score;
     }
+
+    // SoftMax. May be preceded by SoftCap. Yields "probabilities" in head_att.
     const size_t head_att_len = std::min(pos + 1, kSeqLen);
     if constexpr (TConfig::kAttCap > 0.0f) {
       LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
     }
     Softmax(head_att, head_att_len);
 
-    // Weighted summation
+    // Summation of v (kv_cache) weighted by probs (head_att)
+    // into "encoded" (att_out). Compare gemma/modules.py:
+    // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
     float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                   batch_and_query_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
@@ -327,6 +339,9 @@ HWY_NOINLINE void Attention(
     }
   });
 
+  // Sum encoded (att_out) over num_heads and head_dim (kQKVDim)
+  // into output (layer_out). Compare gemma/modules.py:
+  // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
   for (size_t batch_and_query_idx = 0;
        batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
@@ -335,10 +350,13 @@ HWY_NOINLINE void Attention(
         activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT layer_out =
         activations.att_post2.data() + batch_and_query_idx * kModelDim;
+    // Head 0 (and potentially biases) -> layer_out.
+    // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
     MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
         layer_weights->attn_vec_einsum_w, 0, att_out,
         layer_weights->attention_output_biases.data(),
         activations.even_odd.data(), layer_out, pool);
+    // Head 1 and following are added to layer_out.
     for (size_t head = 1; head < kHeads; ++head) {
       // TODO(patrickms): Check this calculation
       float* HWY_RESTRICT head_out =
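The comments added above tie the packed projection weight to the per-head strides: qkv_einsum_w stores kHeads blocks of Q rows followed by kKVHeads interleaved (k, v) pairs, and kQStride is kQKVDim * 3 when Q, K, V are interleaved (MHA) versus kQKVDim otherwise. As a rough standalone illustration of that bookkeeping (hypothetical sizes, not code from this repository), the following sketch prints the row counts and per-head offsets implied by such a layout:

// Standalone sketch of the packed QKV layout described in the comments above.
// All sizes here are made up for illustration; only the arithmetic mirrors
// the shape [(kHeads + kKVHeads * 2), kQKVDim, kModelDim] and
// kQStride = kQKVDim * (kIsMHA ? 3 : 1).
#include <cstddef>
#include <cstdio>

constexpr size_t kModelDim = 16;  // hypothetical
constexpr size_t kQKVDim = 4;     // hypothetical head dimension
constexpr size_t kHeads = 4;      // hypothetical number of query heads
constexpr size_t kKVHeads = 1;    // hypothetical number of KV heads (GQA)
constexpr bool kIsMHA = (kHeads == kKVHeads);
constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);

int main() {
  // First kHeads * kQKVDim rows of the packed weight produce Q per head;
  // the remaining kKVHeads * 2 * kQKVDim rows produce [k, v] pairs.
  const size_t q_rows = kHeads * kQKVDim;
  const size_t kv_rows = kKVHeads * 2 * kQKVDim;
  std::printf("rows: %zu (Q) + %zu (K/V) = %zu, each of width %zu\n", q_rows,
              kv_rows, q_rows + kv_rows, kModelDim);
  // Offset of each head's Q vector within one token's projected output.
  for (size_t head = 0; head < kHeads; ++head) {
    std::printf("head %zu: q at offset %zu (stride %zu)\n", head,
                head * kQStride, kQStride);
  }
  return 0;
}

In the MHA case, that same stride arithmetic is what lets the memcpy in the diff pull K and V out of the interleaved per-head block (q + kQKVDim).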
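The per-head sequence the new comments walk through (Q.K "logits", optional soft cap, softmax into "probabilities", then a sum of V weighted by those probabilities) can also be read in isolation. Below is a minimal scalar sketch of that sequence with made-up data and a hypothetical cap value; it is not gemma.cpp code and omits the KV cache, sliding window, and RoPE:

// Minimal scalar sketch of one attention head: Q.K "logits", optional
// soft cap, softmax -> "probabilities", then V weighted by those
// probabilities, i.e. encoded = einsum('BTNS,BSNH->BTNH', probs, value).
// Data and the cap value are hypothetical; this is not gemma.cpp code.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  constexpr size_t kQKVDim = 4;     // hypothetical head dimension
  constexpr size_t kPositions = 3;  // hypothetical number of cached positions
  constexpr float kAttCap = 50.0f;  // hypothetical soft cap; 0 disables it
  const float kQueryScale = 1.0f / std::sqrt(static_cast<float>(kQKVDim));

  std::vector<float> q(kQKVDim, 0.5f);                // query
  std::vector<float> k(kPositions * kQKVDim, 0.25f);  // cached keys
  std::vector<float> v(kPositions * kQKVDim, 1.0f);   // cached values
  for (float& x : q) x *= kQueryScale;                // scale Q

  // Q.K scores ("logits"), one per past position, optionally soft-capped.
  std::vector<float> att(kPositions, 0.0f);
  for (size_t pos = 0; pos < kPositions; ++pos) {
    for (size_t d = 0; d < kQKVDim; ++d) {
      att[pos] += q[d] * k[pos * kQKVDim + d];
    }
    if (kAttCap > 0.0f) att[pos] = kAttCap * std::tanh(att[pos] / kAttCap);
  }

  // Softmax over positions -> "probabilities".
  const float max = *std::max_element(att.begin(), att.end());
  float sum = 0.0f;
  for (float& a : att) {
    a = std::exp(a - max);
    sum += a;
  }
  for (float& a : att) a /= sum;

  // Weighted sum of V -> one head's "encoded" output (att_out).
  std::vector<float> att_out(kQKVDim, 0.0f);
  for (size_t pos = 0; pos < kPositions; ++pos) {
    for (size_t d = 0; d < kQKVDim; ++d) {
      att_out[d] += att[pos] * v[pos * kQKVDim + d];
    }
  }
  for (size_t d = 0; d < kQKVDim; ++d) {
    std::printf("att_out[%zu] = %f\n", d, att_out[d]);
  }
  return 0;
}

The final projection in the diff (attn_vec_einsum_w, shape [kHeads, kQKVDim, kModelDim]) then reduces these per-head outputs into the kModelDim-wide layer output: head 0 (plus optional biases) initializes layer_out, and the remaining heads are accumulated into it.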