Use more parallelism in the QKV projections in MQA mode.

Instead of the single-threaded TwoOfsMatVecLoop, we use MatVec, which
distributes its output rows across the thread pool. We combine k and v
into one contiguous vector of length 2 * kQKVDim, so that the K and V
projections can be fused into a single MatVec operation.
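
A minimal standalone sketch of the layout trick (plain C++ with illustrative
sizes, not the Highway-based MatVec this repo actually uses): because the K
rows and V rows of qkv_einsum_w are stored back to back, one matrix-vector
product over 2 * kQKVDim rows writes k and v into a single contiguous buffer.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

constexpr size_t kQKVDim = 4;   // illustrative; real values come from TConfig
constexpr size_t kModelDim = 8;

// out[r] = dot(mat[row_ofs + r, :], x) for num_rows consecutive rows.
void MatVecRows(const std::vector<float>& mat, size_t row_ofs, size_t num_rows,
                const std::vector<float>& x, float* out) {
  for (size_t r = 0; r < num_rows; ++r) {
    float sum = 0.0f;
    for (size_t c = 0; c < kModelDim; ++c)
      sum += mat[(row_ofs + r) * kModelDim + c] * x[c];
    out[r] = sum;
  }
}

int main() {
  // Weight layout: kHeads * kQKVDim rows of Q, then kQKVDim rows of K,
  // then kQKVDim rows of V (kHeads = 1 here for brevity).
  std::vector<float> w(3 * kQKVDim * kModelDim, 0.5f);
  std::vector<float> x(kModelDim, 1.0f);
  std::vector<float> kv(2 * kQKVDim);
  // One call projects K and V together: k is kv[0..kQKVDim), v follows it.
  MatVecRows(w, /*row_ofs=*/kQKVDim * kModelDim, 2 * kQKVDim, x, kv.data());
  printf("k[0]=%g v[0]=%g\n", kv[0], kv[kQKVDim]);
  return 0;
}
```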

Benchmark results (summarization with a 1600-token prefill and essay
writing with 500 generated tokens):

```
                Prefill speed               Generation speed
Num threads    BEFORE       AFTER          BEFORE       AFTER
 4            9.81 t/s    9.96 t/s        8.39 t/s    8.46 t/s
18           31.50 t/s   36.67 t/s       23.10 t/s   25.83 t/s
32           45.36 t/s   58.91 t/s       27.60 t/s   31.25 t/s
64           57.72 t/s   80.64 t/s       35.40 t/s   39.76 t/s
```
Author: Zoltan Szabadka
Date:   2024-04-30 13:10:14 +00:00
Commit: afaca4efa8 (parent befe9fb07e)
2 changed files with 19 additions and 19 deletions


```diff
@@ -399,9 +399,9 @@ struct Activations {
   static constexpr size_t kQKVDim = TConfig::kQKVDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
-      TConfig::kGemmaLayers * kKVHeads * kQKVDim;
-  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
+      TConfig::kGemmaLayers * kCacheLayerSize;

   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
```
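For orientation, a hypothetical helper (not part of the commit) spelling out
the addressing that the doubled kCacheLayerSize implies; the hunks below use
exactly this k-then-v layout, which is why readers of v add kQKVDim.

```cpp
#include <cstddef>

// Each (pos, layer) slot of the merged kv_cache stores k immediately
// followed by v, so v is always k + qkv_dim.
struct KVSlot {
  float* k;
  float* v;
};

inline KVSlot SlotAt(float* kv_cache, size_t pos, size_t layer, size_t qkv_dim,
                     size_t cache_pos_size, size_t cache_layer_size) {
  float* base = kv_cache + pos * cache_pos_size + layer * cache_layer_size;
  return {base, base + qkv_dim};
}
```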
```diff
@@ -714,8 +714,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   auto ProjKV = [&](size_t k_offset, size_t v_offset,
                     size_t kv_offset) HWY_ATTR {
-    float* HWY_RESTRICT k = kv_cache.key_cache.get() + kv_offset;
-    float* HWY_RESTRICT v = kv_cache.value_cache.get() + kv_offset;
+    float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+    float* HWY_RESTRICT v = k + kQKVDim;
     TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
                                          v_offset, x, k, v);
@@ -738,7 +738,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2] = score;
     }
@@ -751,7 +751,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
       MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
     // linear projection from kQKVDim back to kModelDim, sum projections
```
```diff
@@ -795,16 +795,19 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     });
   } else {
     // Multi-Query Attention
-    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
-    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
-    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
-    const size_t kv_offset =
-        cache_pos * kCachePosSize + layer * kCacheLayerSize;
-    ProjKV(k_offset, v_offset, kv_offset);
+    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
+    MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x, q,
+                                        pool);
+    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
+                             cache_pos * kCachePosSize +
+                             layer * kCacheLayerSize;
+    MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
+                                   kHeads * kQKVDim * kModelDim, x, kv, pool);
+    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
     pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-      ProjQ(head, head * kQKVDim * kModelDim);
       Attn(head, 0);
     });
   }
```
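The parallelism gain comes from both MatVec calls taking the thread pool and
from the per-head ProjQ calls being replaced by one kHeads * kQKVDim MatVec,
leaving only Attn in the per-head lambda. A rough sketch of how a pool-style
MatVec splits output rows across threads (plain std::thread here; the repo
uses hwy::ThreadPool, and this is not the actual Highway implementation):

```cpp
#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Fusing K and V gives MatVec 2 * kQKVDim output rows to split across
// threads, where TwoOfsMatVecLoop computed them serially on one thread.
void ParallelMatVec(const float* mat, size_t rows, size_t cols, const float* x,
                    float* out, size_t num_threads) {
  std::vector<std::thread> workers;
  const size_t chunk = (rows + num_threads - 1) / num_threads;
  for (size_t t = 0; t < num_threads; ++t) {
    const size_t begin = std::min(t * chunk, rows);
    const size_t end = std::min(begin + chunk, rows);
    workers.emplace_back([=] {
      for (size_t r = begin; r < end; ++r) {
        float sum = 0.0f;
        for (size_t c = 0; c < cols; ++c) sum += mat[r * cols + c] * x[c];
        out[r] = sum;
      }
    });
  }
  for (auto& w : workers) w.join();
}
```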
```diff
@@ -1465,9 +1468,8 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
                       size_t conv1d_cache_size, size_t rglru_cache_size) {
   KVCache kv_cache = {};
   if (size_cache_pos != 0) {
-    kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-    kv_cache.value_cache =
-        hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+    kv_cache.kv_cache =
+        hwy::AllocateAligned<float>(seq_len * size_cache_pos * 2);
   }
   if (conv1d_cache_size != 0) {
     kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
```


```diff
@@ -44,9 +44,7 @@ constexpr bool kSystemPrompt = false;
 struct KVCache {
   hwy::AlignedFreeUniquePtr<float[]>
-      key_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
-  hwy::AlignedFreeUniquePtr<float[]>
-      value_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
+      kv_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
   hwy::AlignedFreeUniquePtr<float[]>
       conv1d_cache;  // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]>
```