mirror of https://github.com/google/gemma.cpp.git
1.05x prefill speedup: matvec -> matmul for !MHA
Also add C_stride and make the shape parameters normal (non-template) arguments.
PiperOrigin-RevId: 657285945
parent 2721f54446
commit f27683152c
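The gist of the change, as an illustrative standalone sketch (plain C++, not gemma.cpp code): during prefill there are many tokens, so the per-token matrix*vector products for the non-MHA KV projection can be fused into one matrix*matrix call whose rows are the tokens, and a tiled kernel can then reuse each loaded slice of the weight matrix across several rows. Both paths below compute the same result:

// Illustrative sketch only; names and shapes here are assumptions, not the
// gemma.cpp API. B is stored transposed (col-major): B[j * cols_a + k] = B(k, j).
#include <cassert>
#include <cstddef>
#include <vector>

void MatVecT(const float* x, const float* B, size_t cols_a, size_t cols_b,
             float* y) {
  for (size_t j = 0; j < cols_b; ++j) {
    float sum = 0.0f;
    for (size_t k = 0; k < cols_a; ++k) sum += x[k] * B[j * cols_a + k];
    y[j] = sum;
  }
}

// Batched form: C[t, j] = sum_k A[t, k] * B(k, j), one call for all tokens.
void MatMulT(const float* A, const float* B, size_t num_tokens, size_t cols_a,
             size_t cols_b, float* C) {
  for (size_t t = 0; t < num_tokens; ++t)
    MatVecT(A + t * cols_a, B, cols_a, cols_b, C + t * cols_b);
}

int main() {
  const size_t num_tokens = 4, cols_a = 8, cols_b = 6;
  std::vector<float> A(num_tokens * cols_a), B(cols_b * cols_a);
  for (size_t i = 0; i < A.size(); ++i) A[i] = 0.01f * i;
  for (size_t i = 0; i < B.size(); ++i) B[i] = 0.02f * i;

  std::vector<float> per_token(num_tokens * cols_b), batched(num_tokens * cols_b);
  for (size_t t = 0; t < num_tokens; ++t)  // old path: one MatVec per token
    MatVecT(A.data() + t * cols_a, B.data(), cols_a, cols_b,
            per_token.data() + t * cols_b);
  MatMulT(A.data(), B.data(), num_tokens, cols_a, cols_b, batched.data());
  for (size_t i = 0; i < batched.size(); ++i) assert(batched[i] == per_token[i]);
  return 0;
}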
@@ -237,32 +237,47 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   //
   // Compute Q only or QKV (if MHA).
   // If MHA, this also computes KV, which we copy to the KV cache below.
-  const float scale = layer_weights->qkv_einsum_w.scale();
-  MatMul_4x4<kModelDim, kHeads * kQStride, /*kAdd=*/false>(
-      num_interleaved, activations.pre_att_rms_out.All(), 0,
-      layer_weights->qkv_einsum_w.data(), 0, scale, activations.q.All(),
-      /*add=*/nullptr, pool);
+  MatMul_4x4</*kAdd=*/false>(num_interleaved, activations.pre_att_rms_out.All(),
+                             0, kModelDim, layer_weights->qkv_einsum_w.data(),
+                             0, kHeads * kQStride,
+                             layer_weights->qkv_einsum_w.scale(),
+                             activations.q.All(), kHeads * kQStride,
+                             /*add=*/nullptr, pool);

   // Compute KV if not MHA.
   if constexpr (!kIsMHA) {
+    // Single query and no wraparound means we can use a matmul and write
+    // directly into the KV cache with a stride of kCachePosSize.
+    if (num_queries == 1 &&
+        batch_start + num_tokens <= div_seq_len.GetDivisor()) {
+      const size_t colsBC = kKVHeads * 2 * kQKVDim;
+      const size_t kv_ofs =
+          batch_start * kCachePosSize + layer * kCacheLayerSize;
+      // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
+      float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs;
+      MatMul_4x4</*kAdd=*/false>(
+          num_tokens, activations.pre_att_rms_out.All(), 0, kModelDim,
+          layer_weights->qkv_einsum_w.data(), kHeads * kQKVDim * kModelDim,
+          colsBC, layer_weights->qkv_einsum_w.scale(), kv, kCachePosSize,
+          /*add=*/nullptr, pool);
+    } else {
       for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
            ++interleaved_idx) {
         const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
         KVCache& kv_cache = kv_caches[query_idx];
-        const size_t pos = batch_start + batch_idx;
-        const size_t cache_pos = div_seq_len.Remainder(pos);
+        const size_t cache_pos = div_seq_len.Remainder(batch_start + batch_idx);
         const size_t kv_offset =
             cache_pos * kCachePosSize + layer * kCacheLayerSize;
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
         // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
-        // TODO: requires batched KVCache support.
         MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
             layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
             activations.even_odd.All(), kv, pool);
       }
+    }
   }

   // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
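A sketch of the KV-cache offset arithmetic that the new fast path above relies on (sizes here are made-up assumptions, not the real model constants): consecutive positions are kCachePosSize floats apart, so with a single query and no wraparound, MatMul_4x4 can write rows straight into the cache using C_stride = kCachePosSize; the per-token MatVec fallback computes the same offsets one position at a time.

// Sketch only: constants are assumed for illustration.
#include <cstddef>
#include <cstdio>

constexpr size_t kLayers = 2, kKVHeads = 1, kQKVDim = 4;     // assumed sizes
constexpr size_t kCacheLayerSize = kKVHeads * 2 * kQKVDim;   // one (k, v) block
constexpr size_t kCachePosSize = kLayers * kCacheLayerSize;  // one position

// Offset of the (k, v) block for a cached position and layer.
size_t KVOffset(size_t cache_pos, size_t layer) {
  return cache_pos * kCachePosSize + layer * kCacheLayerSize;
}

int main() {
  const size_t layer = 1, batch_start = 3, num_tokens = 4;
  // Fast path: one base offset, rows advance by C_stride = kCachePosSize.
  const size_t base = KVOffset(batch_start, layer);
  for (size_t i = 0; i < num_tokens; ++i) {
    std::printf("matmul row %zu -> offset %zu\n", i, base + i * kCachePosSize);
  }
  // Fallback path: the same offsets, computed per token as in the else branch.
  for (size_t i = 0; i < num_tokens; ++i) {
    std::printf("matvec token %zu -> offset %zu\n", i,
                KVOffset(batch_start + i, layer));
  }
  return 0;
}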
@@ -427,7 +442,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   // MatMul expects col-major B, which is what we have: kModelDim consecutive
   // elements in memory, repeated kFFHiddenDim times.
   constexpr size_t kColsA = kModelDim;
-  constexpr size_t kColsB = kFFHiddenDim;
+  constexpr size_t kColsBC = kFFHiddenDim;
   HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
   const auto A = activations.bf_pre_ffw_rms_out.All();
   const float scale = layer_weights->gating_einsum_w.scale();
@@ -446,21 +461,21 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,

   const size_t A_ofs = 0;  // no offset, using the same activations for both.
   // Will go through GELU.
-  MatMul_4x4<kColsA, kColsB, kAddBias>(num_interleaved, A, A_ofs, B1,
-                                       /*B_ofs=*/0, scale, C1, bias1, pool);
+  MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
+                       /*B_ofs=*/0, kColsBC, scale, C1, kColsBC, bias1, pool);
   // What to multiply by.
-  MatMul_4x4<kColsA, kColsB, kAddBias>(num_interleaved, A, A_ofs, B1,
-                                       /*B_ofs=*/kColsA * kColsB, scale, C2,
-                                       bias2, pool);
+  MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
+                       /*B_ofs=*/kColsA * kColsBC, kColsBC, scale, C2, kColsBC,
+                       bias2, pool);

   // Activation (Gelu) and multiply by gate. Store activations in C1.
   Activation<TConfig>(C1, C2, kFFHiddenDim * num_interleaved);

   // Hidden layer -> output layer.
-  MatMul_4x4<kFFHiddenDim, kModelDim, kAddBias>(
-      num_interleaved, C1, 0, layer_weights->linear_w.data(), 0,
-      layer_weights->linear_w.scale(), activations.ffw_out.All(), output_bias,
-      pool);
+  MatMul_4x4<kAddBias>(num_interleaved, C1, 0, kFFHiddenDim,
+                       layer_weights->linear_w.data(), 0, kModelDim,
+                       layer_weights->linear_w.scale(),
+                       activations.ffw_out.All(), kModelDim, output_bias, pool);
 }

 // `batch_idx` indicates which row of `x` to write to.
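For reference, a minimal scalar sketch of the FFW data flow in the two hunks above (plain loops and an assumed tanh-approximation GELU; the real code uses MatMul_4x4 and Activation<TConfig>): C1 = A*B1 and C2 = A*B2 use the two halves of gating_einsum_w, C1 is gated as gelu(C1)*C2, and the result is projected back to kModelDim by linear_w.

// Sketch only: shapes, values, and the GELU formula are assumptions.
#include <cmath>
#include <cstddef>
#include <vector>

float Gelu(float x) {  // tanh approximation, assumed
  const float kSqrt2OverPi = 0.7978845608f;
  return 0.5f * x *
         (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
}

// B stored transposed (col-major), matching MatMul_4x4's expectation:
// B[j * cols_a + k] is element (k, j) of the logical B.
void MatMulT(const float* A, const float* B, size_t rows, size_t cols_a,
             size_t cols_b, float* C) {
  for (size_t r = 0; r < rows; ++r)
    for (size_t j = 0; j < cols_b; ++j) {
      float sum = 0.0f;
      for (size_t k = 0; k < cols_a; ++k)
        sum += A[r * cols_a + k] * B[j * cols_a + k];
      C[r * cols_b + j] = sum;
    }
}

int main() {
  const size_t batch = 2, model_dim = 4, ff_hidden = 8;
  std::vector<float> A(batch * model_dim, 0.1f);
  // Gating weights: first ff_hidden columns feed GELU, the next ff_hidden are
  // the gate (selected via B_ofs = kColsA * kColsBC above).
  std::vector<float> gating(2 * ff_hidden * model_dim, 0.05f);
  std::vector<float> linear(model_dim * ff_hidden, 0.02f);  // transposed
  std::vector<float> C1(batch * ff_hidden), C2(batch * ff_hidden);
  std::vector<float> out(batch * model_dim);

  MatMulT(A.data(), gating.data(), batch, model_dim, ff_hidden, C1.data());
  MatMulT(A.data(), gating.data() + ff_hidden * model_dim, batch, model_dim,
          ff_hidden, C2.data());
  for (size_t i = 0; i < C1.size(); ++i) C1[i] = Gelu(C1[i]) * C2[i];  // gate
  MatMulT(C1.data(), linear.data(), batch, ff_hidden, model_dim, out.data());
  return 0;
}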
@@ -932,6 +947,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
                const MultiplePromptsTokens& prompts, const size_t pos,
                const size_t query_idx_start, const KVCaches& kv_caches,
                hwy::ThreadPool& pool, TimingInfo& timing_info) {
+  constexpr size_t kModelDim = TConfig::kModelDim;
   constexpr size_t kVocabSize = TConfig::kVocabSize;
   const CompressedWeights<TConfig>& weights =
       *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
@@ -1006,10 +1022,11 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     bool all_queries_eos = true;
     PROFILER_ZONE("Gen.Embedding");
     // Compute logits from last layer activations.
-    MatMul_4x4<TConfig::kModelDim, kVocabSize, /*kAdd=*/false>(
-        num_queries, activations.x.All(), 0,
-        weights.embedder_input_embedding.data(), 0,
-        weights.embedder_input_embedding.scale(), activations.logits.All(),
-        /*add=*/nullptr, pool);
+    MatMul_4x4</*kAdd=*/false>(num_queries, activations.x.All(), 0, kModelDim,
+                               weights.embedder_input_embedding.data(), 0,
+                               kVocabSize,
+                               weights.embedder_input_embedding.scale(),
+                               activations.logits.All(), kVocabSize,
+                               /*add=*/nullptr, pool);
     for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
       float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
@@ -378,67 +378,72 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
 //
 // If kAdd is true, the row-vector `add` is added to each row of C, otherwise
 // `add` is ignored and can be nullptr.
-// A is a row-major matrix of size (batch_size, kColsA_RowsB).
+// A is a row-major matrix of size (batch_size, colsA_rowsB).
 // B is passed transposed (column-major), so a matrix of size
-// (kColsBC, kColsA_RowsB), representing a B of size (kColsA_RowsB, kColsBC).
+// (colsBC, colsA_rowsB), representing a B of size (colsA_rowsB, colsBC).
 // A_ofs and B_ofs are offsets into A and B, respectively; they remain separate
 // from the pointers because some MatTA/B such as NuqStream do not support
 // pointer arithmetic.
-// C is a matrix of size (batch_size, kColsBC).
+// C is a row-major matrix of size (batch_size, colsBC), with `C_stride`
+// elements between rows, which is typically the same as `colsBC`. There is no
+// `C_ofs` because callers can simply add it to `C`.
 // The product is scaled by `scale` to support CompressedArray with scale != 1,
 // the caller can pass the product of the scales of A and B.
 // A scale for `add` is not supported, so make sure its scale is 1.
-// Typically batch_size is 1..512, kColsA_RowsB and kColsBC are 3k or 24k.
-template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
-          typename MatTB, typename OutT>
+// Typically batch_size is 1..512, colsA_rowsB and colsBC are 3k or 24k.
+template <bool kAdd, typename MatTA, typename MatTB>
 HWY_NOINLINE void MatMul_4x4(const size_t batch_size,
                              const MatTA* HWY_RESTRICT A, const size_t A_ofs,
+                             const size_t colsA_rowsB,
                              const MatTB* HWY_RESTRICT B, const size_t B_ofs,
-                             const float scale, OutT* HWY_RESTRICT C,
+                             const size_t colsBC, const float scale,
+                             float* HWY_RESTRICT C, const size_t C_stride,
                              const float* HWY_RESTRICT add,
                              hwy::ThreadPool& pool) {
   PROFILER_ZONE("Matmul");
   // We currently write C directly, which touches more memory than fits in L3.
   // TODO: add another level of loops to finish L3-sized pieces of C at a time.
   const hn::ScalableTag<MatTA> d;
-  const size_t N = Lanes(d);
-  constexpr size_t kRegRows = 4;
+  // Use float instead of MatTA/MatTB because we decompress to float here.
+  const size_t Nf = hn::Lanes(hn::ScalableTag<float>());
+  (void)Nf;  // For HWY_DASSERT
+  constexpr size_t kRegRows = 4;  // if changing, also update the switch below.
   constexpr size_t kRegCols = 4;  // in vectors

-  static_assert(kColsBC % kRegCols == 0);
-  HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
-  const size_t kTilesY = (batch_size + kRegRows - 1) / kRegRows;
-  const size_t kTilesX = kColsBC / kRegCols;
-  const size_t kTiles = kTilesX * kTilesY;
+  HWY_DASSERT(colsA_rowsB % (Nf * 2) == 0);  // For Decompress2.
+  HWY_DASSERT(colsBC % kRegCols == 0);
+  const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
+  const size_t tilesX = colsBC / kRegCols;

-  constexpr size_t kStrideA = kColsA_RowsB;
-  constexpr size_t kStrideB = kColsA_RowsB;
-  constexpr size_t kStrideC = kColsBC;
+  const size_t strideA = colsA_rowsB;
+  const size_t strideB = colsA_rowsB;

-  pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
-    // Computes the finished product of one 4x4N tile and writes to C.
-    const size_t num_rows = batch_size - idx_tile / kTilesX * kRegRows;
-    HWY_ASSERT(num_rows > 0);
-    switch (num_rows) {
-      case 1:
-        GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
-                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
-                               kStrideC);
-        break;
-      case 2:
-        GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
-                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
-                               kStrideC);
-        break;
-      case 3:
-        GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
-                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
-                               kStrideC);
-        break;
-      default:
-        GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
-                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
-                               kStrideC);
-    }
-  });
+  pool.Run(0, tilesX * tilesY,
+           [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
+             // How many rows of C are left to compute. If more than 4, this
+             // tile still only computes 4 rows.
+             const size_t num_rows = batch_size - idx_tile / tilesX * kRegRows;
+             HWY_ASSERT(num_rows > 0);
+             switch (num_rows) {
+               case 1:
+                 GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
+                                        idx_tile, tilesX, colsA_rowsB, strideA,
+                                        strideB, C_stride);
+                 break;
+               case 2:
+                 GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
+                                        idx_tile, tilesX, colsA_rowsB, strideA,
+                                        strideB, C_stride);
+                 break;
+               case 3:
+                 GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
+                                        idx_tile, tilesX, colsA_rowsB, strideA,
+                                        strideB, C_stride);
+                 break;
+               default:
+                 GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
+                                        idx_tile, tilesX, colsA_rowsB, strideA,
+                                        strideB, C_stride);
+             }
+           });
 }
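A naive reference implementation of the contract documented above, useful for checking the new non-template shape arguments (a sketch only, not the tiled kernel): row-major A starting at A_ofs, transposed (col-major) B starting at B_ofs, row-major C with C_stride elements between rows, the product scaled by `scale`, and an optional row vector `add` added to each row when kAdd is true.

// Sketch only: float-only, single-threaded, no relation to GEMM_4x4_Tile.
#include <cstddef>
#include <vector>

template <bool kAdd>
void ReferenceMatMul(size_t batch_size, const float* A, size_t A_ofs,
                     size_t colsA_rowsB, const float* B, size_t B_ofs,
                     size_t colsBC, float scale, float* C, size_t C_stride,
                     const float* add) {
  for (size_t r = 0; r < batch_size; ++r) {
    for (size_t c = 0; c < colsBC; ++c) {
      float sum = 0.0f;
      for (size_t k = 0; k < colsA_rowsB; ++k) {
        // B is stored transposed: element (k, c) of the logical B lives at
        // B_ofs + c * colsA_rowsB + k.
        sum += A[A_ofs + r * colsA_rowsB + k] * B[B_ofs + c * colsA_rowsB + k];
      }
      C[r * C_stride + c] = scale * sum + (kAdd ? add[c] : 0.0f);
    }
  }
}

int main() {
  const size_t batch = 3, colsA_rowsB = 4, colsBC = 2, C_stride = 5;  // > colsBC
  std::vector<float> A(batch * colsA_rowsB, 1.0f), B(colsBC * colsA_rowsB, 2.0f);
  std::vector<float> C(batch * C_stride, 0.0f), add(colsBC, 0.5f);
  ReferenceMatMul<true>(batch, A.data(), 0, colsA_rowsB, B.data(), 0, colsBC,
                        /*scale=*/1.0f, C.data(), C_stride, add.data());
  // Each computed element is 4 * 1 * 2 + 0.5 = 8.5; the C_stride - colsBC
  // padding floats per row are left untouched.
  return (C[0] == 8.5f && C[C_stride] == 8.5f) ? 0 : 1;
}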
@@ -301,8 +301,8 @@ void TestTiledBatchMatMul() {

   const double start_tiled = hwy::platform::Now();
   EXPECT_EQ(scale, a->scale() * b_trans->scale());
-  MatMul_4x4<kN, kK, kAdd>(kM, a->data(), 0, b_trans->data(), 0, scale, c.get(),
-                           add->data(), pool);
+  MatMul_4x4<kAdd>(kM, a->data(), 0, kN, b_trans->data(), 0, kK, scale, c.get(),
+                   kK, add->data(), pool);
   const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
   fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds);
