Use more parallelism in the QKV projections of the MHA block.

We compute all three projections with one MatVec and then copy
the kv part to the cache.

Benchmark results for 7b-it model that uses MHA blocks (summarization with
1600 tokens for prefill and essay writing with 500 tokens for generation):

```
                   Prefill speed                Generation speed
Num threads      BEFORE       AFTER            BEFORE       AFTER
32               13.75 t/s    14.80 t/s       9.22 t/s     9.77 t/s
64               19.89 t/s    24.83 t/s      12.46 t/s    13.66 t/s
```
This commit is contained in:
Zoltan Szabadka 2024-05-02 13:46:45 +00:00
parent bafb8382f8
commit 9a2682d544
3 changed files with 18 additions and 99 deletions

View File

@ -402,10 +402,11 @@ struct Activations {
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
static constexpr size_t kCachePosSize =
TConfig::kGemmaLayers * kCacheLayerSize;
static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;
std::array<float, kBatchSize * kModelDim> x; // input
std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
std::array<float, kBatchSize * kHeads * kQKVDim> q; // query vector
std::array<float, kBatchSize * kHeads * kQDim> q; // query vector
std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
att; // attention vector
std::array<float, kBatchSize * kHeads * kQKVDim> att_out; // attention output
@ -710,10 +711,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
auto Attn = [&](float* q, uint64_t head, size_t head_offset,
size_t thread) HWY_ATTR {
// Calculate scores
float* HWY_RESTRICT q =
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
float* HWY_RESTRICT head_att = activations.att.data() +
head * TConfig::kSeqLen +
batch_idx * kHeads * kQKVDim;
@ -745,34 +745,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
if constexpr (kHeads == kKVHeads) {
// Multi-Head Attention
static_assert(TConfig::kInterleaveQKV);
float* HWY_RESTRICT qkv =
activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
MatVec<kHeads * kQKVDim * 3, kModelDim>(
layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
pool);
pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
// linear projections to QKV
const size_t head_offset = TConfig::kInterleaveQKV
? 3 * kQKVDim * kModelDim
: kQKVDim * kModelDim;
const size_t mat_offset =
TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
const size_t q_offset = head * head_offset + 0 * mat_offset;
const size_t k_offset = head * head_offset + 1 * mat_offset;
const size_t v_offset = head * head_offset + 2 * mat_offset;
// ProjQ
float* HWY_RESTRICT q =
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
MatVecLoop<kQKVDim, kModelDim>(
layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
activations.even_odd.data() + thread * kModelDim, q);
// ProjKV
float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head * kQKVDim * 2;
float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
float* HWY_RESTRICT v = k + kQKVDim;
TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
k_offset, v_offset, x, k, v);
Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
Attn(head, head * kQKVDim * 2, thread);
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
Attn(q, head, head * kQKVDim * 2, thread);
});
} else {
// Multi-Query Attention
@ -790,7 +779,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
Attn(head, 0, thread);
Attn(q + head * kQKVDim, head, 0, thread);
});
}

View File

@ -92,45 +92,6 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
return kRowsPerStrip;
}
// Simple version without tiling nor threading.
// even_odd is precomputed for the current thread.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename VecT, typename AddT>
HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add,
float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out) {
PROFILER_ZONE("MatVecAddLoop");
const hn::ScalableTag<float> df;
// Sanity check: we can write without race conditions.
if (HWY_IS_TSAN) {
even_odd[0] = hwy::ConvertScalarTo<float>(vec_aligned[0]);
even_odd[kInner - 1] = -even_odd[0];
}
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
const size_t row_ofs = mat_ofs + idx_row * kInner;
if constexpr (kAdd) {
out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
Dot(df, mat, row_ofs, vec_aligned, kInner);
} else {
out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
}
}
}
// even_odd is precomputed for the current thread.
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out) {
MatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out);
}
// Simple version without tiling nor threading, but two offsets/outputs.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename VecT, typename AddT>
@ -159,18 +120,6 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
}
}
// Simple version without tiling nor threading, but two offsets/outputs.
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
const size_t mat_ofs1,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1) {
TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
out0, out1);
}
namespace detail {
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product

View File

@ -436,24 +436,6 @@ void TestMatVecAdd() {
AssertClose<kOuter>(actual_out, expected_out);
}
void TestMatVecAddLoop() {
constexpr size_t kOuter = 128 * 3;
constexpr size_t kInner = 128 * 5;
CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
hwy::AlignedFreeUniquePtr<float[]> even_odd =
hwy::AllocateAligned<float>(kInner);
hwy::AlignedFreeUniquePtr<float[]> expected_out =
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
hwy::AlignedFreeUniquePtr<float[]> actual_out =
hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
MatVecAddLoop<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
even_odd.get(), actual_out.get());
AssertClose<kOuter>(actual_out, expected_out);
}
void TestTwoMatVecAdd() {
hwy::ThreadPool pool(0);
constexpr size_t kOuter = 128 * 3;
@ -537,7 +519,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);