mirror of https://github.com/google/gemma.cpp.git
1.15x 7b sfp prefill speedup: Matmul in attention
2b bf16: prefill 114.456 -> 115.222 decode 16.8847 -> 16.9987 7b sfp: prefill 18.8575 -> 21.7325 decode 5.68428 -> 5.79791 PiperOrigin-RevId: 644283676
This commit is contained in:
parent
355f7b4f80
commit
a07f60c9a1
|
|
@ -419,28 +419,28 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
|
||||||
const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
|
|
||||||
// QKV projections:
|
|
||||||
if constexpr (kHeads == kKVHeads) {
|
if constexpr (kHeads == kKVHeads) {
|
||||||
// Multi-Head Attention calculates qkv using q as scratch space.
|
// Multi-Head Attention calculates qkv using q as scratch space.
|
||||||
static_assert(TConfig::kInterleaveQKV);
|
static_assert(TConfig::kInterleaveQKV);
|
||||||
float* HWY_RESTRICT qkv =
|
MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim * 3>(
|
||||||
activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
|
num_tokens, activations.pre_att_rms_out.data(),
|
||||||
MatVec<kHeads * kQKVDim * 3, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
|
layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
|
||||||
activations.even_odd.data(), qkv,
|
|
||||||
pool);
|
|
||||||
} else {
|
} else {
|
||||||
const size_t pos = batch_start + batch_idx;
|
MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim>(
|
||||||
float* HWY_RESTRICT q =
|
num_tokens, activations.pre_att_rms_out.data(),
|
||||||
activations.q.data() + batch_idx * kHeads * kQKVDim;
|
layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
|
||||||
MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
|
}
|
||||||
activations.even_odd.data(), q, pool);
|
|
||||||
|
|
||||||
|
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||||
|
const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
|
||||||
|
// QKV projections:
|
||||||
|
if constexpr (kHeads != kKVHeads) {
|
||||||
|
const size_t pos = batch_start + batch_idx;
|
||||||
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
|
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
|
||||||
const size_t kv_offset =
|
const size_t kv_offset =
|
||||||
cache_pos * kCachePosSize + layer * kCacheLayerSize;
|
cache_pos * kCachePosSize + layer * kCacheLayerSize;
|
||||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||||
|
// TODO: requires MatMul support for offsets.
|
||||||
MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
|
MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
|
||||||
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
|
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
|
||||||
activations.even_odd.data(), kv, pool);
|
activations.even_odd.data(), kv, pool);
|
||||||
|
|
@ -494,6 +494,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
|
||||||
for (size_t head = 1; head < kHeads; ++head) {
|
for (size_t head = 1; head < kHeads; ++head) {
|
||||||
float* HWY_RESTRICT head_out =
|
float* HWY_RESTRICT head_out =
|
||||||
activations.att_post1.data() + head * kBatchSize * kModelDim;
|
activations.att_post1.data() + head * kBatchSize * kModelDim;
|
||||||
|
// TODO: requires MatMul support for offsets.
|
||||||
MatVec<kModelDim, kQKVDim>(
|
MatVec<kModelDim, kQKVDim>(
|
||||||
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
|
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
|
||||||
att_out + head * kQKVDim,
|
att_out + head * kQKVDim,
|
||||||
|
|
|
||||||
|
|
@ -881,6 +881,7 @@ template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
|
||||||
HWY_NOINLINE void MatMul_4x4_Batch_Add(
|
HWY_NOINLINE void MatMul_4x4_Batch_Add(
|
||||||
size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
|
size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
|
||||||
OutT* HWY_RESTRICT C, const AddT* add, hwy::ThreadPool& pool) {
|
OutT* HWY_RESTRICT C, const AddT* add, hwy::ThreadPool& pool) {
|
||||||
|
PROFILER_ZONE("Matmul");
|
||||||
// Process reg-sized tiles of C in parallel. We currently write C directly,
|
// Process reg-sized tiles of C in parallel. We currently write C directly,
|
||||||
// which touches more memory than fits in L3. TODO: add another level of loops
|
// which touches more memory than fits in L3. TODO: add another level of loops
|
||||||
// so that we finish one L3-sized piece of C at a time.
|
// so that we finish one L3-sized piece of C at a time.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue