mirror of https://github.com/google/gemma.cpp.git
1.15x 7b sfp prefill speedup: Matmul in attention
2b bf16: prefill 114.456 -> 115.222, decode 16.8847 -> 16.9987
7b sfp:  prefill 18.8575 -> 21.7325, decode 5.68428 -> 5.79791

PiperOrigin-RevId: 644283676
parent 355f7b4f80
commit a07f60c9a1
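Background for the numbers above: during prefill, the old code issued one MatVec per token, so the full qkv_einsum_w matrix was re-streamed from memory for every token; a batched matmul loads each weight row once and reuses it across the whole batch. A standalone sketch of that contrast in plain scalar loops (illustrative only; names and layouts are assumptions, not gemma.cpp's vectorized kernels):

#include <cstddef>

// Old shape of the work: one matvec per token. The weight matrix w
// (rows x cols, row-major) is streamed from memory num_tokens times.
void MatVecPerToken(const float* w, const float* x, float* y,
                    size_t num_tokens, size_t rows, size_t cols) {
  for (size_t t = 0; t < num_tokens; ++t)
    for (size_t r = 0; r < rows; ++r) {
      float sum = 0.0f;
      for (size_t c = 0; c < cols; ++c) sum += w[r * cols + c] * x[t * cols + c];
      y[t * rows + r] = sum;
    }
}

// New shape of the work: one batched matmul. With the token loop inside, a
// weight row is loaded once and reused for every token while it is hot in
// cache, so weight traffic drops by roughly the batch size.
void MatMulBatched(const float* w, const float* x, float* y,
                   size_t num_tokens, size_t rows, size_t cols) {
  for (size_t r = 0; r < rows; ++r)
    for (size_t t = 0; t < num_tokens; ++t) {
      float sum = 0.0f;
      for (size_t c = 0; c < cols; ++c) sum += w[r * cols + c] * x[t * cols + c];
      y[t * rows + r] = sum;
    }
}

Decode generates one token at a time and so has nothing to batch, which is consistent with the measurements above: the large gain is in prefill, decode moves only slightly.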
@@ -419,28 +419,28 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
     }
   };
 
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-    // QKV projections:
   if constexpr (kHeads == kKVHeads) {
-    // Multi-Head Attention calculates qkv using q as scratch space.
     static_assert(TConfig::kInterleaveQKV);
-    float* HWY_RESTRICT qkv =
-        activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
-    MatVec<kHeads * kQKVDim * 3, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
-                                            activations.even_odd.data(), qkv,
-                                            pool);
+    MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim * 3>(
+        num_tokens, activations.pre_att_rms_out.data(),
+        layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
   } else {
-    const size_t pos = batch_start + batch_idx;
-    float* HWY_RESTRICT q =
-        activations.q.data() + batch_idx * kHeads * kQKVDim;
-    MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
-                                        activations.even_odd.data(), q, pool);
+    MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim>(
+        num_tokens, activations.pre_att_rms_out.data(),
+        layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
+  }
+
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+    // QKV projections:
+    if constexpr (kHeads != kKVHeads) {
+      const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset =
           cache_pos * kCachePosSize + layer * kCacheLayerSize;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      // TODO: requires MatMul support for offsets.
       MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
           layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
           activations.even_odd.data(), kv, pool);
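The restructuring above is output-preserving: computing q (or interleaved qkv) for all tokens in one MatMul_4x4_Batch yields exactly what the per-token MatVec loop produced, because each token's activation row is independent. A toy check with plain loops standing in for both kernels (sizes and layouts assumed for illustration):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  const size_t kModelDim = 8, kOutDim = 6, kNumTokens = 3;
  std::vector<float> w(kOutDim * kModelDim), x(kNumTokens * kModelDim);
  for (size_t i = 0; i < w.size(); ++i) w[i] = 0.01f * static_cast<float>(i);
  for (size_t i = 0; i < x.size(); ++i) x[i] = 0.10f * static_cast<float>(i);

  // Old path: one matvec per token, each writing its own slice of q.
  std::vector<float> q_old(kNumTokens * kOutDim, 0.0f);
  for (size_t t = 0; t < kNumTokens; ++t)
    for (size_t r = 0; r < kOutDim; ++r)
      for (size_t c = 0; c < kModelDim; ++c)
        q_old[t * kOutDim + r] += w[r * kModelDim + c] * x[t * kModelDim + c];

  // New path: one batched matmul over all tokens at once.
  std::vector<float> q_new(kNumTokens * kOutDim, 0.0f);
  for (size_t r = 0; r < kOutDim; ++r)
    for (size_t t = 0; t < kNumTokens; ++t)
      for (size_t c = 0; c < kModelDim; ++c)
        q_new[t * kOutDim + r] += w[r * kModelDim + c] * x[t * kModelDim + c];

  for (size_t i = 0; i < q_old.size(); ++i)
    assert(std::fabs(q_old[i] - q_new[i]) < 1e-5f);
  return 0;
}

The KV projection stays on the per-token MatVec path because each token's result must land at its own kv_offset inside the ring-buffer cache, which the batched MatMul entry point cannot yet address; hence the TODO in the hunk.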
@@ -494,6 +494,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
   for (size_t head = 1; head < kHeads; ++head) {
     float* HWY_RESTRICT head_out =
         activations.att_post1.data() + head * kBatchSize * kModelDim;
+    // TODO: requires MatMul support for offsets.
     MatVec<kModelDim, kQKVDim>(
         layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
         att_out + head * kQKVDim,
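Same TODO here: head > 0 reads its weights at offset head * kModelDim * kQKVDim and its input at head * kQKVDim inside larger arrays, while the existing batched MatMul assumes everything starts at element 0. A hypothetical sketch of the offset-aware signature such support might take (not an existing gemma.cpp API; plain loops for clarity):

#include <cstddef>

// Hypothetical entry point: like MatMul_4x4_Batch, but taking a base offset
// into the weights plus input/output strides, so each per-head block above
// could be addressed directly instead of falling back to MatVec.
template <size_t kRows, size_t kCols>
void MatMulBatchWithOffset(const float* w, size_t w_ofs,    // head's weight block
                           const float* in, size_t in_ofs,  // head's input slice
                           size_t in_stride,                // elements per token
                           float* out, size_t out_stride,   // elements per token
                           size_t num_tokens) {
  for (size_t r = 0; r < kRows; ++r)
    for (size_t t = 0; t < num_tokens; ++t) {
      float sum = 0.0f;
      for (size_t c = 0; c < kCols; ++c)
        sum += w[w_ofs + r * kCols + c] * in[in_ofs + t * in_stride + c];
      out[t * out_stride + r] = sum;  // this head's slice of the output
    }
}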
@@ -881,6 +881,7 @@ template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
 HWY_NOINLINE void MatMul_4x4_Batch_Add(
     size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
     OutT* HWY_RESTRICT C, const AddT* add, hwy::ThreadPool& pool) {
+  PROFILER_ZONE("Matmul");
   // Process reg-sized tiles of C in parallel. We currently write C directly,
   // which touches more memory than fits in L3. TODO: add another level of loops
   // so that we finish one L3-sized piece of C at a time.
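PROFILER_ZONE is Highway's scoped profiling macro; adding it makes the batched matmul visible in profiles now that attention runs through this path. The TODO in the comment is cache blocking: C is written directly and exceeds L3, so each pass evicts useful data. A rough sketch of the extra loop level over L3-sized pieces of C that it describes (scalar loops in place of the 4x4 register tiling; the strip width is an assumed tuning constant, not a value from the source):

#include <algorithm>
#include <cstddef>

// Assumed tuning constant: a strip of C (batch_size x kStripCols floats)
// should fit in L3 alongside the A and B tiles that produce it.
constexpr size_t kStripCols = 512;

// Finish one L3-sized strip of C before starting the next, so each strip is
// written once and stays resident while being produced. A is
// batch_size x cols_a, B is cols_a x cols_c, both assumed row-major.
void MatMulL3Strips(const float* A, const float* B, float* C,
                    size_t batch_size, size_t cols_a, size_t cols_c) {
  for (size_t col0 = 0; col0 < cols_c; col0 += kStripCols) {
    const size_t col1 = std::min(col0 + kStripCols, cols_c);
    for (size_t r = 0; r < batch_size; ++r)
      for (size_t c = col0; c < col1; ++c) {
        float sum = 0.0f;
        for (size_t k = 0; k < cols_a; ++k)
          sum += A[r * cols_a + k] * B[k * cols_c + c];
        C[r * cols_c + c] = sum;
      }
  }
}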