mirror of https://github.com/google/gemma.cpp.git
Use more parallelism in the QKV projections in MQA mode.

Instead of MatVecLoop, we use MatVec, and we store k and v contiguously
in a single vector of length 2 * kQKVDim so that the K and V projections
can be computed with one MatVec operation.

Benchmark results (summarization with 1600 prefill tokens and
essay writing with 500 generated tokens):
```
                 Prefill speed            Generation speed
Num threads     BEFORE      AFTER        BEFORE      AFTER
     4        9.81 t/s    9.96 t/s     8.39 t/s    8.46 t/s
    18       31.50 t/s   36.67 t/s    23.10 t/s   25.83 t/s
    32       45.36 t/s   58.91 t/s    27.60 t/s   31.25 t/s
    64       57.72 t/s   80.64 t/s    35.40 t/s   39.76 t/s
```
commit afaca4efa8 (parent befe9fb07e)
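To illustrate the idea behind the change: stacking the K and V projection rows lets a single matrix-vector product of shape (2 * kQKVDim) x kModelDim write k and v into one contiguous buffer, instead of running two separate kQKVDim x kModelDim products. Below is a minimal self-contained sketch in plain C++; the `MatVecRows` helper is illustrative only, not the Highway-based `MatVec` used in gemma.cpp.

```cpp
// Minimal sketch (plain C++, not the gemma.cpp API): one matvec over the
// stacked K and V weight rows produces k and v in a single contiguous buffer,
// equivalent to two separate kQKVDim x kModelDim products.
#include <cstddef>
#include <vector>

// Hypothetical helper: out[r] = dot(row r of w, x) for a rows x cols matrix.
void MatVecRows(const float* w, const float* x, float* out,
                std::size_t rows, std::size_t cols) {
  for (std::size_t r = 0; r < rows; ++r) {
    float sum = 0.0f;
    for (std::size_t c = 0; c < cols; ++c) sum += w[r * cols + c] * x[c];
    out[r] = sum;
  }
}

int main() {
  constexpr std::size_t kModelDim = 8, kQKVDim = 4;
  // K rows followed by V rows (toy weights), analogous to the K/V block of
  // qkv_einsum_w that follows the Q rows.
  std::vector<float> w_kv(2 * kQKVDim * kModelDim, 0.5f);
  std::vector<float> x(kModelDim, 1.0f);

  // One call writes k into kv[0, kQKVDim) and v into kv[kQKVDim, 2*kQKVDim),
  // which is also the per-slot layout of the combined KV cache.
  std::vector<float> kv(2 * kQKVDim);
  MatVecRows(w_kv.data(), x.data(), kv.data(), 2 * kQKVDim, kModelDim);

  const float* k = kv.data();
  const float* v = kv.data() + kQKVDim;
  (void)k;
  (void)v;
  return 0;
}
```

Because k and v end up adjacent in the output buffer, the KV cache can store both halves in one allocation, which is what the diff below does.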
```diff
@@ -399,9 +399,9 @@ struct Activations {
   static constexpr size_t kQKVDim = TConfig::kQKVDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
-      TConfig::kGemmaLayers * kKVHeads * kQKVDim;
-  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
+      TConfig::kGemmaLayers * kCacheLayerSize;
 
   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
```
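With `kCacheLayerSize = kKVHeads * kQKVDim * 2`, each (position, layer) slot of the combined cache now holds k followed immediately by v. A hedged sketch of the addressing implied by the diff (the `KVSlotOffset` helper is hypothetical; the real code computes the offsets inline):

```cpp
#include <cstddef>

// Illustrative only: base offset of the (pos, layer) slot in the combined
// cache. k occupies the first kQKVDim floats of the slot and v the next
// kQKVDim floats (plus a per-head offset when kKVHeads > 1).
inline std::size_t KVSlotOffset(std::size_t pos, std::size_t layer,
                                std::size_t cache_pos_size,
                                std::size_t cache_layer_size) {
  return pos * cache_pos_size + layer * cache_layer_size;
}
// float* k = kv_cache.kv_cache.get() + KVSlotOffset(pos, layer, kCachePosSize, kCacheLayerSize);
// float* v = k + kQKVDim;
```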
```diff
@@ -714,8 +714,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   auto ProjKV = [&](size_t k_offset, size_t v_offset,
                     size_t kv_offset) HWY_ATTR {
-    float* HWY_RESTRICT k = kv_cache.key_cache.get() + kv_offset;
-    float* HWY_RESTRICT v = kv_cache.value_cache.get() + kv_offset;
+    float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+    float* HWY_RESTRICT v = k + kQKVDim;
 
     TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
                                          v_offset, x, k, v);
```
```diff
@@ -738,7 +738,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2] = score;
     }
```
```diff
@@ -751,7 +751,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
     // linear projection from kQKVDim back to kModelDim, sum projections
```
```diff
@@ -795,16 +795,19 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     });
   } else {
     // Multi-Query Attention
-    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
-    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
-    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
-    const size_t kv_offset =
-        cache_pos * kCachePosSize + layer * kCacheLayerSize;
+    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
+    MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x, q,
+                                        pool);
 
-    ProjKV(k_offset, v_offset, kv_offset);
+    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
+                             cache_pos * kCachePosSize +
+                             layer * kCacheLayerSize;
+    MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
+                                   kHeads * kQKVDim * kModelDim, x, kv, pool);
+
+    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 
     pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-      ProjQ(head, head * kQKVDim * kModelDim);
       Attn(head, 0);
     });
   }
```
```diff
@@ -1465,9 +1468,8 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
                       size_t conv1d_cache_size, size_t rglru_cache_size) {
   KVCache kv_cache = {};
   if (size_cache_pos != 0) {
-    kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-    kv_cache.value_cache =
-        hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+    kv_cache.kv_cache =
+        hwy::AllocateAligned<float>(seq_len * size_cache_pos * 2);
   }
   if (conv1d_cache_size != 0) {
     kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
```
```diff
@@ -44,9 +44,7 @@ constexpr bool kSystemPrompt = false;
 
 struct KVCache {
   hwy::AlignedFreeUniquePtr<float[]>
-      key_cache;    // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
-  hwy::AlignedFreeUniquePtr<float[]>
-      value_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
+      kv_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
   hwy::AlignedFreeUniquePtr<float[]>
       conv1d_cache;  // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]>
```