1.15x 7b sfp prefill speedup: Matmul in attention

2b bf16:
prefill 114.456 -> 115.222
decode  16.8847 -> 16.9987

7b sfp:
prefill 18.8575 -> 21.7325
decode 5.68428 -> 5.79791

PiperOrigin-RevId: 644283676
Author: Jan Wassenberg, 2024-06-18 01:00:12 -07:00 (committed by Copybara-Service)
parent 355f7b4f80
commit a07f60c9a1
2 changed files with 16 additions and 14 deletions
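
Background for the numbers above (presumably tokens/sec, so higher is better): prefill runs attention over a whole batch of tokens, and the QKV projection was previously a per-token MatVec, which re-streams the qkv_einsum_w weights from memory for every token. Fusing the batch into one MatMul lets each weight tile be reused across tokens, raising arithmetic intensity; decode (a single token) barely changes, as the numbers show, and the bandwidth-bound 7b sfp weights likely explain why that model gains the most. A minimal scalar sketch of the equivalence, with hypothetical names and shapes (not the vectorized MatMul_4x4_Batch from this commit):

#include <cstddef>

// Old shape of the work: one matrix-vector product per token.
// W is row-major [kOut x kIn]; all of W is read num_tokens times.
void ProjectPerToken(size_t num_tokens, size_t kIn, size_t kOut,
                     const float* W, const float* X, float* Y) {
  for (size_t t = 0; t < num_tokens; ++t) {
    for (size_t r = 0; r < kOut; ++r) {
      float sum = 0.0f;
      for (size_t c = 0; c < kIn; ++c) sum += W[r * kIn + c] * X[t * kIn + c];
      Y[t * kOut + r] = sum;
    }
  }
}

// New shape of the work: one matrix-matrix product over all tokens.
// Same arithmetic and same results, but a tiled kernel can hold a block
// of W in registers and apply it to several tokens before reloading it.
void ProjectBatched(size_t num_tokens, size_t kIn, size_t kOut,
                    const float* W, const float* X, float* Y) {
  for (size_t r = 0; r < kOut; ++r) {
    for (size_t t = 0; t < num_tokens; ++t) {  // weight row W[r] reused here
      float sum = 0.0f;
      for (size_t c = 0; c < kIn; ++c) sum += W[r * kIn + c] * X[t * kIn + c];
      Y[t * kOut + r] = sum;
    }
  }
}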

@@ -419,28 +419,28 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
     }
   };
+  if constexpr (kHeads == kKVHeads) {
+    // Multi-Head Attention calculates qkv using q as scratch space.
+    static_assert(TConfig::kInterleaveQKV);
+    MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim * 3>(
+        num_tokens, activations.pre_att_rms_out.data(),
+        layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
+  } else {
+    MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim>(
+        num_tokens, activations.pre_att_rms_out.data(),
+        layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
+  }
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-    // QKV projections:
-    if constexpr (kHeads == kKVHeads) {
-      // Multi-Head Attention calculates qkv using q as scratch space.
-      static_assert(TConfig::kInterleaveQKV);
-      float* HWY_RESTRICT qkv =
-          activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
-      MatVec<kHeads * kQKVDim * 3, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
-                                              activations.even_odd.data(), qkv,
-                                              pool);
-    } else {
+    if constexpr (kHeads != kKVHeads) {
       const size_t pos = batch_start + batch_idx;
-      float* HWY_RESTRICT q =
-          activations.q.data() + batch_idx * kHeads * kQKVDim;
-      MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
-                                          activations.even_odd.data(), q, pool);
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset =
           cache_pos * kCachePosSize + layer * kCacheLayerSize;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      // TODO: requires MatMul support for offsets.
       MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
           layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
           activations.even_odd.data(), kv, pool);
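
Reading the hunk: in the MHA case (kHeads == kKVHeads), q, k and v for all heads come from one interleaved weight matrix (hence the static_assert on TConfig::kInterleaveQKV), so a single batched MatMul with kColsA_RowsB = kModelDim and kColsBC = kHeads * kQKVDim * 3 produces the full QKV tensor for all tokens, with activations.q doubling as scratch space. In the GQA/MQA case only the q projection is batched; k/v stays a per-token MatVec because its weights start at element offset kHeads * kQKVDim * kModelDim inside qkv_einsum_w and its outputs scatter to per-token kv_cache positions, neither of which the batched MatMul supports yet (the new TODO). A scalar sketch of what an offset-aware batched kernel would compute; the signature is hypothetical, and the real weights may be compressed (sfp), which is why the offset is a parameter to MatVec rather than pointer arithmetic:

#include <cstddef>

// Hypothetical offset-aware batched projection for the k/v weights: uses
// the sub-matrix of W starting at element w_offset. With
// w_offset = kHeads * kQKVDim * kModelDim, kRows = kKVHeads * kQKVDim * 2
// and kCols = kModelDim, this matches the per-token MatVec in the diff,
// except that here the outputs are assumed contiguous; the real code also
// scatters each token's k/v to its own kv_cache offset.
void BatchedProjectionWithOffset(size_t num_tokens, size_t kRows,
                                 size_t kCols, const float* W,
                                 size_t w_offset, const float* X, float* Y) {
  const float* Wsub = W + w_offset;  // first row of the k/v weights
  for (size_t t = 0; t < num_tokens; ++t) {
    for (size_t r = 0; r < kRows; ++r) {
      float sum = 0.0f;
      for (size_t c = 0; c < kCols; ++c) {
        sum += Wsub[r * kCols + c] * X[t * kCols + c];
      }
      Y[t * kRows + r] = sum;
    }
  }
}
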
@@ -494,6 +494,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
     for (size_t head = 1; head < kHeads; ++head) {
       float* HWY_RESTRICT head_out =
           activations.att_post1.data() + head * kBatchSize * kModelDim;
+      // TODO: requires MatMul support for offsets.
       MatVec<kModelDim, kQKVDim>(
           layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
           att_out + head * kQKVDim,
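
The same offset limitation shows up in the attention output projection: each head's kModelDim x kQKVDim slice of attn_vec_einsum_w begins at element offset head * kModelDim * kQKVDim, so the per-head products stay on MatVec for now (the loop starts at head 1; head 0 is presumably handled before this loop, which the hunk does not show). A scalar sketch of the whole projection under an assumed row-major layout, with hypothetical names:

#include <cstddef>

// Scalar sketch of the attention output projection: each head h owns a
// kModelDim x kQKVDim slice of attn_vec_einsum_w at element offset
// h * kModelDim * kQKVDim, and the per-head results are summed into the
// kModelDim-sized output.
void AttnOutProjection(size_t kHeads, size_t kModelDim, size_t kQKVDim,
                       const float* attn_vec_einsum_w,  // [kHeads*kModelDim*kQKVDim]
                       const float* att_out,            // [kHeads*kQKVDim]
                       float* out) {                    // [kModelDim]
  for (size_t r = 0; r < kModelDim; ++r) out[r] = 0.0f;
  for (size_t h = 0; h < kHeads; ++h) {
    const float* Wh = attn_vec_einsum_w + h * kModelDim * kQKVDim;
    const float* x = att_out + h * kQKVDim;
    for (size_t r = 0; r < kModelDim; ++r) {
      float sum = 0.0f;
      for (size_t c = 0; c < kQKVDim; ++c) sum += Wh[r * kQKVDim + c] * x[c];
      out[r] += sum;  // accumulate across heads
    }
  }
}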

@@ -881,6 +881,7 @@ template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
 HWY_NOINLINE void MatMul_4x4_Batch_Add(
     size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
     OutT* HWY_RESTRICT C, const AddT* add, hwy::ThreadPool& pool) {
+  PROFILER_ZONE("Matmul");
   // Process reg-sized tiles of C in parallel. We currently write C directly,
   // which touches more memory than fits in L3. TODO: add another level of loops
   // so that we finish one L3-sized piece of C at a time.
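
The TODO in that comment is standard cache blocking: on top of the register-sized 4x4 tiles, add an outer loop level over panels of C sized to fit in L3, finishing each panel before touching the next so C's working set stays cache-resident. A naive sketch of that loop structure, with illustrative tile sizes and row-major [M x K] by [K x N] operands (the real kernel would keep its vectorized 4x4 micro-kernel innermost and derive panel sizes from the cache hierarchy):

#include <algorithm>
#include <cstddef>

// Loop-structure sketch of the TODO: an extra blocking level so that one
// L3-sized panel of C is completed before the next is touched.
void MatMulBlocked(size_t M, size_t N, size_t K, const float* A,
                   const float* B, float* C) {
  constexpr size_t kMC = 128, kNC = 256;  // assumed L3-panel dimensions
  for (size_t m0 = 0; m0 < M; m0 += kMC) {
    for (size_t n0 = 0; n0 < N; n0 += kNC) {
      // Finish this kMC x kNC panel of C entirely before moving on.
      const size_t m1 = std::min(m0 + kMC, M);
      const size_t n1 = std::min(n0 + kNC, N);
      for (size_t m = m0; m < m1; ++m) {
        for (size_t n = n0; n < n1; ++n) {
          float sum = 0.0f;
          for (size_t k = 0; k < K; ++k) sum += A[m * K + k] * B[k * N + n];
          C[m * N + n] = sum;
        }
      }
    }
  }
}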