mirror of https://github.com/google/gemma.cpp.git
1.05x prefill speedup: matvec -> matmul for !MHA
Also add C_stride and make shape normal non-template arguments. PiperOrigin-RevId: 657285945
This commit is contained in:
parent
2721f54446
commit
f27683152c
|
|
@ -237,30 +237,45 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
|
|||
//
|
||||
// Compute Q only or QKV (if MHA).
|
||||
// If MHA, this also computes KV, which we copy to the KV cache below.
|
||||
const float scale = layer_weights->qkv_einsum_w.scale();
|
||||
MatMul_4x4<kModelDim, kHeads * kQStride, /*kAdd=*/false>(
|
||||
num_interleaved, activations.pre_att_rms_out.All(), 0,
|
||||
layer_weights->qkv_einsum_w.data(), 0, scale, activations.q.All(),
|
||||
/*add=*/nullptr, pool);
|
||||
MatMul_4x4</*kAdd=*/false>(num_interleaved, activations.pre_att_rms_out.All(),
|
||||
0, kModelDim, layer_weights->qkv_einsum_w.data(),
|
||||
0, kHeads * kQStride,
|
||||
layer_weights->qkv_einsum_w.scale(),
|
||||
activations.q.All(), kHeads * kQStride,
|
||||
/*add=*/nullptr, pool);
|
||||
|
||||
// Compute KV if not MHA.
|
||||
if constexpr (!kIsMHA) {
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
++interleaved_idx) {
|
||||
const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
|
||||
const size_t query_idx = interleaved_idx % num_queries;
|
||||
const size_t batch_idx = interleaved_idx / num_queries;
|
||||
KVCache& kv_cache = kv_caches[query_idx];
|
||||
const size_t pos = batch_start + batch_idx;
|
||||
const size_t cache_pos = div_seq_len.Remainder(pos);
|
||||
const size_t kv_offset =
|
||||
cache_pos * kCachePosSize + layer * kCacheLayerSize;
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
// Single query and no wraparound means we can use a matmul and write
|
||||
// directly into the KV cache with a stride of kCachePosSize.
|
||||
if (num_queries == 1 &&
|
||||
batch_start + num_tokens <= div_seq_len.GetDivisor()) {
|
||||
const size_t colsBC = kKVHeads * 2 * kQKVDim;
|
||||
const size_t kv_ofs =
|
||||
batch_start * kCachePosSize + layer * kCacheLayerSize;
|
||||
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
||||
// TODO: requires batched KVCache support.
|
||||
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
|
||||
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
|
||||
activations.even_odd.All(), kv, pool);
|
||||
float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs;
|
||||
MatMul_4x4</*kAdd=*/false>(
|
||||
num_tokens, activations.pre_att_rms_out.All(), 0, kModelDim,
|
||||
layer_weights->qkv_einsum_w.data(), kHeads * kQKVDim * kModelDim,
|
||||
colsBC, layer_weights->qkv_einsum_w.scale(), kv, kCachePosSize,
|
||||
/*add=*/nullptr, pool);
|
||||
} else {
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
++interleaved_idx) {
|
||||
const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
|
||||
const size_t query_idx = interleaved_idx % num_queries;
|
||||
const size_t batch_idx = interleaved_idx / num_queries;
|
||||
KVCache& kv_cache = kv_caches[query_idx];
|
||||
const size_t cache_pos = div_seq_len.Remainder(batch_start + batch_idx);
|
||||
const size_t kv_offset =
|
||||
cache_pos * kCachePosSize + layer * kCacheLayerSize;
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
||||
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
|
||||
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
|
||||
activations.even_odd.All(), kv, pool);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -427,7 +442,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
|
|||
// MatMul expects col-major B, which is what we have: kModelDim consecutive
|
||||
// elements in memory, repeated kFFHiddenDim times.
|
||||
constexpr size_t kColsA = kModelDim;
|
||||
constexpr size_t kColsB = kFFHiddenDim;
|
||||
constexpr size_t kColsBC = kFFHiddenDim;
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
const auto A = activations.bf_pre_ffw_rms_out.All();
|
||||
const float scale = layer_weights->gating_einsum_w.scale();
|
||||
|
|
@ -446,21 +461,21 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
|
|||
|
||||
const size_t A_ofs = 0; // no offset, using the same activations for both.
|
||||
// Will go through GELU.
|
||||
MatMul_4x4<kColsA, kColsB, kAddBias>(num_interleaved, A, A_ofs, B1,
|
||||
/*B_ofs=*/0, scale, C1, bias1, pool);
|
||||
MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
|
||||
/*B_ofs=*/0, kColsBC, scale, C1, kColsBC, bias1, pool);
|
||||
// What to multiply by.
|
||||
MatMul_4x4<kColsA, kColsB, kAddBias>(num_interleaved, A, A_ofs, B1,
|
||||
/*B_ofs=*/kColsA * kColsB, scale, C2,
|
||||
bias2, pool);
|
||||
MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
|
||||
/*B_ofs=*/kColsA * kColsBC, kColsBC, scale, C2, kColsBC,
|
||||
bias2, pool);
|
||||
|
||||
// Activation (Gelu) and multiply by gate. Store activations in C1.
|
||||
Activation<TConfig>(C1, C2, kFFHiddenDim * num_interleaved);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
MatMul_4x4<kFFHiddenDim, kModelDim, kAddBias>(
|
||||
num_interleaved, C1, 0, layer_weights->linear_w.data(), 0,
|
||||
layer_weights->linear_w.scale(), activations.ffw_out.All(), output_bias,
|
||||
pool);
|
||||
MatMul_4x4<kAddBias>(num_interleaved, C1, 0, kFFHiddenDim,
|
||||
layer_weights->linear_w.data(), 0, kModelDim,
|
||||
layer_weights->linear_w.scale(),
|
||||
activations.ffw_out.All(), kModelDim, output_bias, pool);
|
||||
}
|
||||
|
||||
// `batch_idx` indicates which row of `x` to write to.
|
||||
|
|
@ -932,6 +947,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
const MultiplePromptsTokens& prompts, const size_t pos,
|
||||
const size_t query_idx_start, const KVCaches& kv_caches,
|
||||
hwy::ThreadPool& pool, TimingInfo& timing_info) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
const CompressedWeights<TConfig>& weights =
|
||||
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
|
||||
|
|
@ -1006,11 +1022,12 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
bool all_queries_eos = true;
|
||||
PROFILER_ZONE("Gen.Embedding");
|
||||
// Compute logits from last layer activations.
|
||||
MatMul_4x4<TConfig::kModelDim, kVocabSize, /*kAdd=*/false>(
|
||||
num_queries, activations.x.All(), 0,
|
||||
weights.embedder_input_embedding.data(), 0,
|
||||
weights.embedder_input_embedding.scale(), activations.logits.All(),
|
||||
/*add=*/nullptr, pool);
|
||||
MatMul_4x4</*kAdd=*/false>(num_queries, activations.x.All(), 0, kModelDim,
|
||||
weights.embedder_input_embedding.data(), 0,
|
||||
kVocabSize,
|
||||
weights.embedder_input_embedding.scale(),
|
||||
activations.logits.All(), kVocabSize,
|
||||
/*add=*/nullptr, pool);
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
|
||||
if constexpr (TConfig::kFinalCap > 0.0f) {
|
||||
|
|
|
|||
|
|
@ -378,69 +378,74 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
|
|||
//
|
||||
// If kAdd is true, the row-vector `add` is added to each row of C, otherwise
|
||||
// `add` is ignored and can be nullptr.
|
||||
// A is a row-major matrix of size (batch_size, kColsA_RowsB).
|
||||
// A is a row-major matrix of size (batch_size, colsA_rowsB).
|
||||
// B is passed transposed (column-major), so a matrix of size
|
||||
// (kColsBC, kColsA_RowsB), representing a B of size (kColsA_RowsB, kColsBC).
|
||||
// (colsBC, colsA_rowsB), representing a B of size (colsA_rowsB, colsBC).
|
||||
// A_ofs and B_ofs are offsets into A and B, respectively; they remain separate
|
||||
// from the pointers because some MatTA/B such as NuqStream do not support
|
||||
// pointer arithmetic.
|
||||
// C is a matrix of size (batch_size, kColsBC).
|
||||
// C is a row-major matrix of size (batch_size, colsBC), with `C_stride`
|
||||
// elements between rows, which is typically the same as `colsBC`. There is no
|
||||
// `C_ofs` because callers can simply add it to `C`.
|
||||
// The product is scaled by `scale` to support CompressedArray with scale != 1,
|
||||
// the caller can pass the product of the scales of A and B.
|
||||
// A scale for `add` is not supported, so make sure its scale is 1.
|
||||
// Typically batch_size is 1..512, kColsA_RowsB and kColsBC are 3k or 24k.
|
||||
template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
|
||||
typename MatTB, typename OutT>
|
||||
// Typically batch_size is 1..512, colsA_rowsB and colsBC are 3k or 24k.
|
||||
template <bool kAdd, typename MatTA, typename MatTB>
|
||||
HWY_NOINLINE void MatMul_4x4(const size_t batch_size,
|
||||
const MatTA* HWY_RESTRICT A, const size_t A_ofs,
|
||||
const size_t colsA_rowsB,
|
||||
const MatTB* HWY_RESTRICT B, const size_t B_ofs,
|
||||
const float scale, OutT* HWY_RESTRICT C,
|
||||
const size_t colsBC, const float scale,
|
||||
float* HWY_RESTRICT C, const size_t C_stride,
|
||||
const float* HWY_RESTRICT add,
|
||||
hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Matmul");
|
||||
// We currently write C directly, which touches more memory than fits in L3.
|
||||
// TODO: add another level of loops to finish L3-sized pieces of C at a time.
|
||||
const hn::ScalableTag<MatTA> d;
|
||||
const size_t N = Lanes(d);
|
||||
constexpr size_t kRegRows = 4;
|
||||
// Use float instead of MatTA/MatTB because we decompress to float here.
|
||||
const size_t Nf = hn::Lanes(hn::ScalableTag<float>());
|
||||
(void)Nf; // For HWY_DASSERT
|
||||
constexpr size_t kRegRows = 4; // if changing, also update the switch below.
|
||||
constexpr size_t kRegCols = 4; // in vectors
|
||||
|
||||
static_assert(kColsBC % kRegCols == 0);
|
||||
HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
|
||||
const size_t kTilesY = (batch_size + kRegRows - 1) / kRegRows;
|
||||
const size_t kTilesX = kColsBC / kRegCols;
|
||||
const size_t kTiles = kTilesX * kTilesY;
|
||||
HWY_DASSERT(colsA_rowsB % (Nf * 2) == 0); // For Decompress2.
|
||||
HWY_DASSERT(colsBC % kRegCols == 0);
|
||||
const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
|
||||
const size_t tilesX = colsBC / kRegCols;
|
||||
|
||||
constexpr size_t kStrideA = kColsA_RowsB;
|
||||
constexpr size_t kStrideB = kColsA_RowsB;
|
||||
constexpr size_t kStrideC = kColsBC;
|
||||
const size_t strideA = colsA_rowsB;
|
||||
const size_t strideB = colsA_rowsB;
|
||||
|
||||
pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
|
||||
// Computes the finished product of one 4x4N tile and writes to C.
|
||||
const size_t num_rows = batch_size - idx_tile / kTilesX * kRegRows;
|
||||
HWY_ASSERT(num_rows > 0);
|
||||
switch (num_rows) {
|
||||
case 1:
|
||||
GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
|
||||
kTilesX, kColsA_RowsB, kStrideA, kStrideB,
|
||||
kStrideC);
|
||||
break;
|
||||
case 2:
|
||||
GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
|
||||
kTilesX, kColsA_RowsB, kStrideA, kStrideB,
|
||||
kStrideC);
|
||||
break;
|
||||
case 3:
|
||||
GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
|
||||
kTilesX, kColsA_RowsB, kStrideA, kStrideB,
|
||||
kStrideC);
|
||||
break;
|
||||
default:
|
||||
GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
|
||||
kTilesX, kColsA_RowsB, kStrideA, kStrideB,
|
||||
kStrideC);
|
||||
}
|
||||
});
|
||||
pool.Run(0, tilesX * tilesY,
|
||||
[&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
|
||||
// How many rows of C are left to compute. If more than 4, this
|
||||
// tile still only computes 4 rows.
|
||||
const size_t num_rows = batch_size - idx_tile / tilesX * kRegRows;
|
||||
HWY_ASSERT(num_rows > 0);
|
||||
switch (num_rows) {
|
||||
case 1:
|
||||
GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
|
||||
idx_tile, tilesX, colsA_rowsB, strideA,
|
||||
strideB, C_stride);
|
||||
break;
|
||||
case 2:
|
||||
GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
|
||||
idx_tile, tilesX, colsA_rowsB, strideA,
|
||||
strideB, C_stride);
|
||||
break;
|
||||
case 3:
|
||||
GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
|
||||
idx_tile, tilesX, colsA_rowsB, strideA,
|
||||
strideB, C_stride);
|
||||
break;
|
||||
default:
|
||||
GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
|
||||
idx_tile, tilesX, colsA_rowsB, strideA,
|
||||
strideB, C_stride);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -301,8 +301,8 @@ void TestTiledBatchMatMul() {
|
|||
|
||||
const double start_tiled = hwy::platform::Now();
|
||||
EXPECT_EQ(scale, a->scale() * b_trans->scale());
|
||||
MatMul_4x4<kN, kK, kAdd>(kM, a->data(), 0, b_trans->data(), 0, scale, c.get(),
|
||||
add->data(), pool);
|
||||
MatMul_4x4<kAdd>(kM, a->data(), 0, kN, b_trans->data(), 0, kK, scale, c.get(),
|
||||
kK, add->data(), pool);
|
||||
const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
|
||||
fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue