From 9a2682d54448763c81f664de985e8c31f8243b22 Mon Sep 17 00:00:00 2001
From: Zoltan Szabadka
Date: Thu, 2 May 2024 13:46:45 +0000
Subject: [PATCH] Use more parallelism in the QKV projections of the MHA block.

We compute all three projections with one MatVec and then copy the K/V
part into the KV cache.

Benchmark results for the 7b-it model, which uses MHA blocks
(summarization with 1600 prefill tokens and essay writing with 500
generated tokens):

```
                Prefill speed           Generation speed
Num threads     BEFORE      AFTER       BEFORE      AFTER
32              13.75 t/s   14.80 t/s    9.22 t/s    9.77 t/s
64              19.89 t/s   24.83 t/s   12.46 t/s   13.66 t/s
```
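
For illustration, a minimal standalone sketch of the idea (plain C++,
without the Highway SIMD paths, thread pool, batching, or weight
compression of the real code; the toy dimensions and the MatVecAllHeads
helper are inventions of the sketch, not gemma.cpp APIs). One fused
matrix-vector product writes the interleaved [q | k | v] rows of every
head into a single buffer, and the K/V slice of each head is then copied
into that position's slot of the KV cache:

```
#include <cstddef>
#include <cstring>
#include <vector>

// Toy dimensions standing in for the TConfig constants.
constexpr size_t kHeads = 4;
constexpr size_t kQKVDim = 8;
constexpr size_t kModelDim = 32;

// One mat-vec over the whole (kHeads * 3 * kQKVDim) x kModelDim weight
// matrix, whose rows are interleaved per head as [q | k | v]
// (i.e. kInterleaveQKV == true). The real code does this with a single
// MatVec call that splits the rows across the thread pool.
void MatVecAllHeads(const std::vector<float>& w, const float* x, float* qkv) {
  for (size_t row = 0; row < kHeads * 3 * kQKVDim; ++row) {
    float dot = 0.0f;
    for (size_t col = 0; col < kModelDim; ++col) {
      dot += w[row * kModelDim + col] * x[col];
    }
    qkv[row] = dot;
  }
}

int main() {
  std::vector<float> w(kHeads * 3 * kQKVDim * kModelDim, 0.01f);
  std::vector<float> x(kModelDim, 1.0f);

  // One fused projection instead of a per-head ProjQ + ProjKV.
  std::vector<float> qkv(kHeads * 3 * kQKVDim);
  MatVecAllHeads(w, x.data(), qkv.data());

  // KV cache slot for a single (layer, position); per head this holds
  // kQKVDim key values followed by kQKVDim value values.
  std::vector<float> kv_cache(kHeads * 2 * kQKVDim);
  for (size_t head = 0; head < kHeads; ++head) {
    const float* q = qkv.data() + head * 3 * kQKVDim;  // [q | k | v] of head
    float* kv = kv_cache.data() + head * 2 * kQKVDim;
    // Copy K and V out of the fused buffer; RoPE would then be applied to
    // the key part of kv (and to q before the attention dot products).
    std::memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
  }
  return 0;
}
```

Since the replaced MatVecLoop / TwoOfsMatVecLoop helpers are the "simple
version without tiling nor threading", the old path could only
parallelize across heads; the single large MatVec hands the whole QKV
projection to the thread pool at once, which is where the prefill and
generation speedups above come from.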
---
 gemma/gemma.cc    | 47 ++++++++++++++++++-----------------------------
 gemma/ops.h       | 51 ---------------------------------------------------
 gemma/ops_test.cc | 19 -------------------
 3 files changed, 18 insertions(+), 99 deletions(-)

diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 6d80741..28de0c8 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -402,10 +402,11 @@ struct Activations {
   static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
       TConfig::kGemmaLayers * kCacheLayerSize;
+  static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;
   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
-  std::array<float, kBatchSize * kHeads * kQKVDim> q;  // query vector
+  std::array<float, kBatchSize * kHeads * kQDim> q;  // query vector
   std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
       att;  // attention vector
   std::array<float, kBatchSize * kHeads * kQKVDim>
       att_out;  // attention output
@@ -710,10 +711,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
 
-  auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+                  size_t thread) HWY_ATTR {
     // Calculate scores
-    float* HWY_RESTRICT q =
-        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT head_att =
         activations.att.data() + head * TConfig::kSeqLen +
         batch_idx * kHeads * kQKVDim;
@@ -745,34 +745,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   if constexpr (kHeads == kKVHeads) {
     // Multi-Head Attention
+    static_assert(TConfig::kInterleaveQKV);
+
+    float* HWY_RESTRICT qkv =
+        activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
+    MatVec<kHeads * kQKVDim * 3, kModelDim>(
+        layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
+        pool);
+
     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      // linear projections to QKV
-      const size_t head_offset = TConfig::kInterleaveQKV
-                                     ? 3 * kQKVDim * kModelDim
-                                     : kQKVDim * kModelDim;
-      const size_t mat_offset =
-          TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
-      const size_t q_offset = head * head_offset + 0 * mat_offset;
-      const size_t k_offset = head * head_offset + 1 * mat_offset;
-      const size_t v_offset = head * head_offset + 2 * mat_offset;
-
-      // ProjQ
-      float* HWY_RESTRICT q =
-          activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-      MatVecLoop<kQKVDim, kModelDim>(
-          layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
-          activations.even_odd.data() + thread * kModelDim, q);
-
-      // ProjKV
+      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
-      float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
-      float* HWY_RESTRICT v = k + kQKVDim;
-      TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
-                                           k_offset, v_offset, x, k, v);
-      Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
 
-      Attn(head, head * kQKVDim * 2, thread);
+      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
+      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      Attn(q, head, head * kQKVDim * 2, thread);
     });
   } else {
     // Multi-Query Attention
@@ -790,7 +779,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 
     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(head, 0, thread);
+      Attn(q + head * kQKVDim, head, 0, thread);
     });
   }
 
diff --git a/gemma/ops.h b/gemma/ops.h
index bac98fa..e40e72a 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -92,45 +92,6 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
   return kRowsPerStrip;
 }
 
-// Simple version without tiling nor threading.
-// even_odd is precomputed for the current thread.
-template <size_t kOuter, size_t kInner, bool kAdd, class ArrayT, typename VecT,
-          typename AddT>
-HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
-                              const AddT* HWY_RESTRICT add,
-                              float* HWY_RESTRICT even_odd,
-                              float* HWY_RESTRICT out) {
-  PROFILER_ZONE("MatVecAddLoop");
-  const hn::ScalableTag<float> df;
-
-  // Sanity check: we can write without race conditions.
-  if (HWY_IS_TSAN) {
-    even_odd[0] = hwy::ConvertScalarTo<float>(vec_aligned[0]);
-    even_odd[kInner - 1] = -even_odd[0];
-  }
-
-  for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
-    const size_t row_ofs = mat_ofs + idx_row * kInner;
-    if constexpr (kAdd) {
-      out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
-                     Dot(df, mat, row_ofs, vec_aligned, kInner);
-    } else {
-      out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
-    }
-  }
-}
-
-// even_odd is precomputed for the current thread.
-template <size_t kOuter, size_t kInner, class ArrayT, typename VecT>
-HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
-                           const VecT* HWY_RESTRICT vec_aligned,
-                           float* HWY_RESTRICT even_odd,
-                           float* HWY_RESTRICT out) {
-  MatVecAddLoop<kOuter, kInner, /*kAdd=*/false>(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out);
-}
-
 // Simple version without tiling nor threading, but two offsets/outputs.
 template <size_t kOuter, size_t kInner, bool kAdd, class ArrayT, typename VecT,
           typename AddT>
@@ -159,18 +120,6 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
   }
 }
 
-// Simple version without tiling nor threading, but two offsets/outputs.
-template <size_t kOuter, size_t kInner, class ArrayT, typename VecT>
-HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
-                                 const size_t mat_ofs1,
-                                 const VecT* HWY_RESTRICT vec_aligned,
-                                 float* HWY_RESTRICT out0,
-                                 float* HWY_RESTRICT out1) {
-  TwoOfsMatVecAddLoop<kOuter, kInner, /*kAdd=*/false>(
-      mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
-      out0, out1);
-}
-
 namespace detail {
 
 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index 6a26cfd..973d598 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -436,24 +436,6 @@ void TestMatVecAdd() {
   AssertClose(actual_out, expected_out);
 }
 
-void TestMatVecAddLoop() {
-  constexpr size_t kOuter = 128 * 3;
-  constexpr size_t kInner = 128 * 5;
-  CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
-  hwy::AlignedFreeUniquePtr<float[]> even_odd =
-      hwy::AllocateAligned<float>(kInner);
-  hwy::AlignedFreeUniquePtr<float[]> expected_out =
-      SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
-  hwy::AlignedFreeUniquePtr<float[]> actual_out =
-      hwy::AllocateAligned<float>(kOuter);
-  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
-  MatVecAddLoop<kOuter, kInner, /*kAdd=*/true>(mat, 0, vec.get(), add.get(),
-                                               even_odd.get(), actual_out.get());
-  AssertClose(actual_out, expected_out);
-}
-
 void TestTwoMatVecAdd() {
   hwy::ThreadPool pool(0);
   constexpr size_t kOuter = 128 * 3;
@@ -537,7 +519,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
-HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);