mirror of https://github.com/google/gemma.cpp.git
Use more parallelism in the QKV projections of the MHA block.
We compute all three projections with one MatVec and then copy
the kv part to the cache.
Benchmark results for 7b-it model that uses MHA blocks (summarization with
1600 tokens for prefill and essay writing with 500 tokens for generation):
```
Prefill speed Generation speed
Num threads BEFORE AFTER BEFORE AFTER
32 13.75 t/s 14.80 t/s 9.22 t/s 9.77 t/s
64 19.89 t/s 24.83 t/s 12.46 t/s 13.66 t/s
```
This commit is contained in:
parent
bafb8382f8
commit
9a2682d544
|
|
@ -402,10 +402,11 @@ struct Activations {
|
|||
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
|
||||
static constexpr size_t kCachePosSize =
|
||||
TConfig::kGemmaLayers * kCacheLayerSize;
|
||||
static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;
|
||||
|
||||
std::array<float, kBatchSize * kModelDim> x; // input
|
||||
std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
|
||||
std::array<float, kBatchSize * kHeads * kQKVDim> q; // query vector
|
||||
std::array<float, kBatchSize * kHeads * kQDim> q; // query vector
|
||||
std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
|
||||
att; // attention vector
|
||||
std::array<float, kBatchSize * kHeads * kQKVDim> att_out; // attention output
|
||||
|
|
@ -710,10 +711,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
|
||||
float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
|
||||
|
||||
auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
|
||||
auto Attn = [&](float* q, uint64_t head, size_t head_offset,
|
||||
size_t thread) HWY_ATTR {
|
||||
// Calculate scores
|
||||
float* HWY_RESTRICT q =
|
||||
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
|
||||
float* HWY_RESTRICT head_att = activations.att.data() +
|
||||
head * TConfig::kSeqLen +
|
||||
batch_idx * kHeads * kQKVDim;
|
||||
|
|
@ -745,34 +745,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
|
||||
if constexpr (kHeads == kKVHeads) {
|
||||
// Multi-Head Attention
|
||||
static_assert(TConfig::kInterleaveQKV);
|
||||
|
||||
float* HWY_RESTRICT qkv =
|
||||
activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
|
||||
MatVec<kHeads * kQKVDim * 3, kModelDim>(
|
||||
layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
|
||||
pool);
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
|
||||
// linear projections to QKV
|
||||
const size_t head_offset = TConfig::kInterleaveQKV
|
||||
? 3 * kQKVDim * kModelDim
|
||||
: kQKVDim * kModelDim;
|
||||
const size_t mat_offset =
|
||||
TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
|
||||
const size_t q_offset = head * head_offset + 0 * mat_offset;
|
||||
const size_t k_offset = head * head_offset + 1 * mat_offset;
|
||||
const size_t v_offset = head * head_offset + 2 * mat_offset;
|
||||
|
||||
// ProjQ
|
||||
float* HWY_RESTRICT q =
|
||||
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
|
||||
MatVecLoop<kQKVDim, kModelDim>(
|
||||
layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
|
||||
activations.even_odd.data() + thread * kModelDim, q);
|
||||
|
||||
// ProjKV
|
||||
float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
|
||||
const size_t kv_offset = cache_pos * kCachePosSize +
|
||||
layer * kCacheLayerSize + head * kQKVDim * 2;
|
||||
float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
|
||||
float* HWY_RESTRICT v = k + kQKVDim;
|
||||
TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
|
||||
k_offset, v_offset, x, k, v);
|
||||
Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
|
||||
Attn(head, head * kQKVDim * 2, thread);
|
||||
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
|
||||
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
Attn(q, head, head * kQKVDim * 2, thread);
|
||||
});
|
||||
} else {
|
||||
// Multi-Query Attention
|
||||
|
|
@ -790,7 +779,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
|
||||
Attn(head, 0, thread);
|
||||
Attn(q + head * kQKVDim, head, 0, thread);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
51
gemma/ops.h
51
gemma/ops.h
|
|
@ -92,45 +92,6 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
|
|||
return kRowsPerStrip;
|
||||
}
|
||||
|
||||
// Simple version without tiling nor threading.
|
||||
// even_odd is precomputed for the current thread.
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||
typename VecT, typename AddT>
|
||||
HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
const AddT* HWY_RESTRICT add,
|
||||
float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out) {
|
||||
PROFILER_ZONE("MatVecAddLoop");
|
||||
const hn::ScalableTag<float> df;
|
||||
|
||||
// Sanity check: we can write without race conditions.
|
||||
if (HWY_IS_TSAN) {
|
||||
even_odd[0] = hwy::ConvertScalarTo<float>(vec_aligned[0]);
|
||||
even_odd[kInner - 1] = -even_odd[0];
|
||||
}
|
||||
|
||||
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + idx_row * kInner;
|
||||
if constexpr (kAdd) {
|
||||
out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
|
||||
Dot(df, mat, row_ofs, vec_aligned, kInner);
|
||||
} else {
|
||||
out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// even_odd is precomputed for the current thread.
|
||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out) {
|
||||
MatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||
mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out);
|
||||
}
|
||||
|
||||
// Simple version without tiling nor threading, but two offsets/outputs.
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||
typename VecT, typename AddT>
|
||||
|
|
@ -159,18 +120,6 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
|
|||
}
|
||||
}
|
||||
|
||||
// Simple version without tiling nor threading, but two offsets/outputs.
|
||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
|
||||
const size_t mat_ofs1,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT out0,
|
||||
float* HWY_RESTRICT out1) {
|
||||
TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||
mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
|
||||
out0, out1);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
|
||||
|
|
|
|||
|
|
@ -436,24 +436,6 @@ void TestMatVecAdd() {
|
|||
AssertClose<kOuter>(actual_out, expected_out);
|
||||
}
|
||||
|
||||
void TestMatVecAddLoop() {
|
||||
constexpr size_t kOuter = 128 * 3;
|
||||
constexpr size_t kInner = 128 * 5;
|
||||
CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
|
||||
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
|
||||
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
|
||||
hwy::AlignedFreeUniquePtr<float[]> even_odd =
|
||||
hwy::AllocateAligned<float>(kInner);
|
||||
hwy::AlignedFreeUniquePtr<float[]> expected_out =
|
||||
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
|
||||
hwy::AlignedFreeUniquePtr<float[]> actual_out =
|
||||
hwy::AllocateAligned<float>(kOuter);
|
||||
HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
|
||||
MatVecAddLoop<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
|
||||
even_odd.get(), actual_out.get());
|
||||
AssertClose<kOuter>(actual_out, expected_out);
|
||||
}
|
||||
|
||||
void TestTwoMatVecAdd() {
|
||||
hwy::ThreadPool pool(0);
|
||||
constexpr size_t kOuter = 128 * 3;
|
||||
|
|
@ -537,7 +519,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
|||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
|
||||
|
|
|
|||
Loading…
Reference in New Issue