mirror of https://github.com/google/gemma.cpp.git
Simplify Attention.
Shared kIsMHA, reuse from Activations, inline Attn lambda, use kQStride as the stride between successive Q.
PiperOrigin-RevId: 644343854
parent 2ac47e4a06
commit 15135f5b3d
gemma/gemma.cc (173 changed lines)
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -30,7 +30,6 @@
 #ifndef GEMMA_ONCE
 #define GEMMA_ONCE
 
-#include <math.h>  // sqrtf
 #include <stddef.h>
 #include <stdio.h>
 #include <stdlib.h>
@@ -38,7 +37,6 @@
 
 #include <algorithm>
 #include <array>
-#include <cmath>
 #include <memory>
 #include <string>
 #include <utility>
@@ -73,11 +71,14 @@ struct Activations {
   static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
       TConfig::kGemmaLayers * kCacheLayerSize;
-  static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;
+  static constexpr bool kIsMHA = kHeads == kKVHeads;  // Multi-Head Attention
+  // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
+  // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
+  static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);
 
   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
-  std::array<float, kBatchSize * kHeads * kQDim> q;  // query vector
+  std::array<float, kBatchSize * kHeads * kQStride> q;  // query vector
   std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
       att;  // attention vector
   std::array<float, kBatchSize * kHeads * kQKVDim> att_out;  // attention output
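The hunk above is the heart of the change: kQDim is replaced by kIsMHA plus kQStride, so each head's slot in activations.q holds either Q alone or interleaved Q,K,V. The following standalone sketch (illustration only, not gemma.cpp code; the head counts and kQKVDim are made-up values) shows how the stride and the per-head query pointer fall out of those two constants.

// Minimal sketch of the kIsMHA / kQStride layout idea, with hypothetical dims.
#include <cstddef>
#include <cstdio>
#include <vector>

template <size_t kHeads, size_t kKVHeads, size_t kQKVDim>
struct LayoutSketch {
  static constexpr bool kIsMHA = kHeads == kKVHeads;
  // For MHA, each head's slot is [Q | K | V], so the stride triples.
  static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);

  // Pointer to head `head`'s query within a per-token buffer of
  // kHeads * kQStride floats, mirroring the activations.q indexing.
  static float* QueryForHead(float* q_buffer, size_t head) {
    return q_buffer + head * kQStride;
  }
};

int main() {
  using MHA = LayoutSketch</*kHeads=*/8, /*kKVHeads=*/8, /*kQKVDim=*/4>;
  using GQA = LayoutSketch</*kHeads=*/8, /*kKVHeads=*/1, /*kQKVDim=*/4>;
  std::vector<float> mha_buf(8 * MHA::kQStride), gqa_buf(8 * GQA::kQStride);
  // MHA: stride 12 floats (Q,K,V interleaved); GQA: stride 4 floats (Q only).
  std::printf("MHA stride=%zu  GQA stride=%zu\n", MHA::kQStride, GQA::kQStride);
  std::printf("head 2 offset: MHA=%td  GQA=%td\n",
              MHA::QueryForHead(mha_buf.data(), 2) - mha_buf.data(),
              GQA::QueryForHead(gqa_buf.data(), 2) - gqa_buf.data());
  return 0;
}

With kHeads == kKVHeads the stride triples, which is what lets the later KV-cache copy find K and V immediately after each head's Q.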
@@ -242,7 +243,7 @@ HWY_NOINLINE void GriffinRecurrent(
   using D = hn::ScalableTag<float>;
   HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim =
-      gcpp::Activations<TConfig, kBatchSize>::kModelDim;
+      Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
 
@@ -370,71 +371,29 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
   HWY_DASSERT(num_tokens <= kBatchSize);
-  static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
-  static constexpr size_t kCachePosSize =
-      gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
-  static constexpr size_t kCacheLayerSize =
-      gcpp::Activations<TConfig, kBatchSize>::kCacheLayerSize;
-  static constexpr size_t kModelDim =
-      gcpp::Activations<TConfig, kBatchSize>::kModelDim;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static constexpr size_t kKVHeads = TConfig::kKVHeads;
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static const float kQueryScale =
-      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
+  using TActivations = Activations<TConfig, kBatchSize>;
+  constexpr size_t kQKVDim = TActivations::kQKVDim;
+  constexpr size_t kQStride = TActivations::kQStride;
+  constexpr size_t kCachePosSize = TActivations::kCachePosSize;
+  constexpr size_t kCacheLayerSize = TActivations::kCacheLayerSize;
+  constexpr size_t kModelDim = TActivations::kModelDim;
+  constexpr size_t kHeads = TConfig::kHeads;
+  constexpr size_t kKVHeads = TConfig::kKVHeads;
+  constexpr size_t kSeqLen = TConfig::kSeqLen;
+  GEMMA_CONSTEXPR_SQRT const float kQueryScale =
+      1.0f / Sqrt(static_cast<float>(kQKVDim));
+  constexpr bool kIsMHA = TActivations::kIsMHA;  // Multi-Head Attention
 
-  auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
-                  size_t thread) HWY_ATTR {
-    const size_t pos = batch_start + batch_idx;
-    // Calculate scores
-    float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * kSeqLen +
-                                   batch_idx * kHeads * kSeqLen;
-
-    Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-    MulByConst(kQueryScale, q, kQKVDim);
-
-    // Compute Q dot K scores
-    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
-    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
-      const size_t kv_offset = cache_pos * kCachePosSize +
-                               layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
-      const float score = Dot(q, k2, kQKVDim);
-      head_att[pos2 % kSeqLen] = score;
-    }
-    Softmax(head_att, std::min(pos + 1, kSeqLen));
-
-    // Weighted summation
-    float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
-                                  batch_idx * kHeads * kQKVDim;
-    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
-      const size_t kv_offset = cache_pos * kCachePosSize +
-                               layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
-      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
-    }
-  };
-
-  if constexpr (kHeads == kKVHeads) {
-    // Multi-Head Attention calculates qkv using q as scratch space.
-    static_assert(TConfig::kInterleaveQKV);
-    MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim * 3>(
-        num_tokens, activations.pre_att_rms_out.data(),
-        layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
-  } else {
-    MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim>(
-        num_tokens, activations.pre_att_rms_out.data(),
-        layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
-  }
+  // If MHA, this also computes KV, which we copy to the KV cache below.
+  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
+  MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
+      num_tokens, activations.pre_att_rms_out.data(),
+      layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
 
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
     // QKV projections:
-    if constexpr (kHeads != kKVHeads) {
+    if constexpr (!kIsMHA) {
       const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset =
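With kQStride folded into the template argument, the former if constexpr/else pair collapses into the single MatMul_4x4_Batch call above. The naive-loop sketch below (plain C++, not the actual MatMul_4x4_Batch kernel; all dimensions and the row-major weight layout are hypothetical stand-ins) shows why one output width of kHeads * kQStride covers both cases.

// Naive illustration of the unified QKV projection width, hypothetical sizes.
#include <cstddef>
#include <vector>

// C[num_tokens x out_dim] = A[num_tokens x model_dim] * B[model_dim x out_dim]
void NaiveMatMul(const std::vector<float>& a, const std::vector<float>& b,
                 std::vector<float>& c, size_t num_tokens, size_t model_dim,
                 size_t out_dim) {
  for (size_t t = 0; t < num_tokens; ++t) {
    for (size_t o = 0; o < out_dim; ++o) {
      float sum = 0.0f;
      for (size_t m = 0; m < model_dim; ++m) {
        sum += a[t * model_dim + m] * b[m * out_dim + o];
      }
      c[t * out_dim + o] = sum;
    }
  }
}

int main() {
  constexpr size_t kHeads = 2, kQKVDim = 4, kModelDim = 8, kNumTokens = 3;
  constexpr bool kIsMHA = true;                        // heads == kv heads
  constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);
  constexpr size_t kOutDim = kHeads * kQStride;        // one projection call
  std::vector<float> x(kNumTokens * kModelDim, 0.5f);  // pre_att_rms_out stand-in
  std::vector<float> w(kModelDim * kOutDim, 0.25f);    // qkv_einsum_w stand-in
  std::vector<float> q(kNumTokens * kOutDim);          // activations.q stand-in
  NaiveMatMul(x, w, q, kNumTokens, kModelDim, kOutDim);
  // For MHA, q[token][head * kQStride .. +kQKVDim) is Q, the next kQKVDim is K,
  // and the last kQKVDim is V, which the following loop copies into the cache.
  return 0;
}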
@@ -447,37 +406,67 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
     }
   }
 
-  // Positional encodings for k:
-  const size_t num_kv_tasks = kKVHeads * num_tokens;
-  pool.Run(0, num_kv_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+  // Positional encodings for kv:
+  pool.Run(
+      0, kKVHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
     const size_t head = task % kKVHeads;
     const size_t batch_idx = task / kKVHeads;
     const size_t pos = batch_start + batch_idx;
     const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
     const size_t kv_offset = cache_pos * kCachePosSize +
                              layer * kCacheLayerSize + head * kQKVDim * 2;
    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-    if constexpr (kHeads == kKVHeads) {
+    if constexpr (kIsMHA) {
       // For MHA, copy kv into the KV cache from scratch space (see above).
       const float* HWY_RESTRICT q =
-          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+          activations.q.data() + (batch_idx * kHeads + head) * kQStride;
+      // Skip past the Q part of `q`, and copy KV to `kv`.
       memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
     }
     Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
   });
 
-  static_assert((TConfig::kHeads % TConfig::kKVHeads) == 0,
+  static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
-  static constexpr size_t kGroupHeads = TConfig::kHeads / TConfig::kKVHeads;
-  static constexpr size_t kQOffsetScale = (kHeads == kKVHeads) ? 3 : 1;
-  const size_t num_q_tasks = kHeads * num_tokens;
-  pool.Run(0, num_q_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+  static constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  pool.Run(0, kHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
     const size_t head = task % kHeads;
     const size_t batch_idx = task / kHeads;
     const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
-    float* HWY_RESTRICT q = activations.q.data() + (batch_idx * kHeads + head) *
-                             kQKVDim * kQOffsetScale;
-    Attn(q, head, head_offset, batch_idx, thread);
+    float* HWY_RESTRICT q =
+        activations.q.data() + (batch_idx * kHeads + head) * kQStride;
+
+    const size_t pos = batch_start + batch_idx;
+    // Calculate scores
+    float* HWY_RESTRICT head_att =
+        activations.att.data() + head * kSeqLen + batch_idx * kHeads * kSeqLen;
+
+    Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+    MulByConst(kQueryScale, q, kQKVDim);
+
+    // Compute Q dot K scores
+    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset =
+          cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
+      const float score = Dot(q, k2, kQKVDim);
+      head_att[pos2 % kSeqLen] = score;
+    }
+    Softmax(head_att, std::min(pos + 1, kSeqLen));
+
+    // Weighted summation
+    float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
+                                  batch_idx * kHeads * kQKVDim;
+    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset =
+          cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
+    }
   });
 
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
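The second hunk inlines the former Attn lambda into the per-(head, token) pool.Run body: rotate and scale the query, dot it against cached K, softmax the scores, then accumulate V. The scalar sketch below (plain C++, no Highway; Rope is omitted and the ring buffer is simplified to a single seq_len modulus rather than kSeqLen + kPrefillBatchSize) mirrors that per-head loop.

// Scalar per-head attention over a rolling K/V cache; sizes are placeholders.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void SoftmaxInPlace(float* x, size_t n) {
  const float max = *std::max_element(x, x + n);
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) { x[i] = std::exp(x[i] - max); sum += x[i]; }
  for (size_t i = 0; i < n; ++i) x[i] /= sum;
}

// One head, one token at `pos`. `kv` holds [K | V] pairs of length qkv_dim per
// cached position, like the kv_cache slots in the commit.
void AttendOneHead(const float* q, const std::vector<float>& kv, size_t pos,
                   size_t seq_len, size_t qkv_dim, float* out) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(qkv_dim));
  std::vector<float> att(seq_len, 0.0f);
  const size_t start_pos = pos - std::min(seq_len - 1, pos);
  for (size_t p2 = start_pos; p2 <= pos; ++p2) {
    const float* k = kv.data() + (p2 % seq_len) * 2 * qkv_dim;
    float score = 0.0f;
    for (size_t i = 0; i < qkv_dim; ++i) score += scale * q[i] * k[i];
    att[p2 % seq_len] = score;  // Q dot K, query scaling folded in
  }
  SoftmaxInPlace(att.data(), std::min(pos + 1, seq_len));
  std::fill(out, out + qkv_dim, 0.0f);
  for (size_t p2 = start_pos; p2 <= pos; ++p2) {
    const float* v = kv.data() + (p2 % seq_len) * 2 * qkv_dim + qkv_dim;
    for (size_t i = 0; i < qkv_dim; ++i) out[i] += att[p2 % seq_len] * v[i];
  }
}

int main() {
  constexpr size_t kQKVDim = 4, kSeqLen = 8;
  std::vector<float> kv(kSeqLen * 2 * kQKVDim, 0.1f);  // toy K/V contents
  std::vector<float> q(kQKVDim, 1.0f), out(kQKVDim);
  AttendOneHead(q.data(), kv, /*pos=*/5, kSeqLen, kQKVDim, out.data());
  return 0;
}

In the real code, head_offset = (head / kGroupHeads) * kQKVDim * 2 points each group of query heads at a shared K/V slot, which is the grouped-query case when kIsMHA is false.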
@@ -1012,7 +1001,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
   HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
   // Both pre-trained and instruction-tuned require BOS as first token.
   if (pos == 0) {
-    tokens.insert(tokens.begin(), gcpp::BOS_ID);
+    tokens.insert(tokens.begin(), BOS_ID);
   }
   return tokens;
 }