1.05x prefill speedup: matvec -> matmul for !MHA

Also add C_stride and make the shapes normal (non-template) arguments.

PiperOrigin-RevId: 657285945
Jan Wassenberg 2024-07-29 12:17:33 -07:00 committed by Copybara-Service
parent 2721f54446
commit f27683152c
3 changed files with 102 additions and 80 deletions
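
A rough sketch of the interface change, following the old and new declarations in the matmul diff below (kAdd stays a template parameter; the shapes become runtime arguments and C gains a row stride):

  // Before: shapes as template parameters, C rows always contiguous.
  MatMul_4x4<kColsA_RowsB, kColsBC, kAdd>(batch_size, A, A_ofs, B, B_ofs,
                                          scale, C, add, pool);

  // After: shapes as runtime values plus an explicit C_stride, so callers can
  // write each output row directly into a wider buffer such as the KV cache.
  MatMul_4x4<kAdd>(batch_size, A, A_ofs, colsA_rowsB, B, B_ofs, colsBC,
                   scale, C, C_stride, add, pool);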

@@ -237,30 +237,45 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
//
// Compute Q only or QKV (if MHA).
// If MHA, this also computes KV, which we copy to the KV cache below.
const float scale = layer_weights->qkv_einsum_w.scale();
MatMul_4x4<kModelDim, kHeads * kQStride, /*kAdd=*/false>(
num_interleaved, activations.pre_att_rms_out.All(), 0,
layer_weights->qkv_einsum_w.data(), 0, scale, activations.q.All(),
/*add=*/nullptr, pool);
MatMul_4x4</*kAdd=*/false>(num_interleaved, activations.pre_att_rms_out.All(),
0, kModelDim, layer_weights->qkv_einsum_w.data(),
0, kHeads * kQStride,
layer_weights->qkv_einsum_w.scale(),
activations.q.All(), kHeads * kQStride,
/*add=*/nullptr, pool);
// Compute KV if not MHA.
if constexpr (!kIsMHA) {
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
KVCache& kv_cache = kv_caches[query_idx];
const size_t pos = batch_start + batch_idx;
const size_t cache_pos = div_seq_len.Remainder(pos);
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
// Single query and no wraparound means we can use a matmul and write
// directly into the KV cache with a stride of kCachePosSize.
if (num_queries == 1 &&
batch_start + num_tokens <= div_seq_len.GetDivisor()) {
const size_t colsBC = kKVHeads * 2 * kQKVDim;
const size_t kv_ofs =
batch_start * kCachePosSize + layer * kCacheLayerSize;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
// TODO: requires batched KVCache support.
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
activations.even_odd.All(), kv, pool);
float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs;
MatMul_4x4</*kAdd=*/false>(
num_tokens, activations.pre_att_rms_out.All(), 0, kModelDim,
layer_weights->qkv_einsum_w.data(), kHeads * kQKVDim * kModelDim,
colsBC, layer_weights->qkv_einsum_w.scale(), kv, kCachePosSize,
/*add=*/nullptr, pool);
} else {
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
KVCache& kv_cache = kv_caches[query_idx];
const size_t cache_pos = div_seq_len.Remainder(batch_start + batch_idx);
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
activations.even_odd.All(), kv, pool);
}
}
}
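
Why the single-query fast path can write straight into the KV cache, as a short sketch using the offsets from the hunk above: with num_queries == 1 and no sequence wraparound, token batch_start + r lands at cache position batch_start + r, so its destination is

  kv_cache.get() + (batch_start + r) * kCachePosSize + layer * kCacheLayerSize
    == (kv_cache.get() + kv_ofs) + r * kCachePosSize

which is exactly row r of C when C = kv_cache.get() + kv_ofs and C_stride = kCachePosSize. Multiple queries or a wraparound break this linear row mapping, hence the per-token MatVec fallback in the else branch.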
@@ -427,7 +442,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
// MatMul expects col-major B, which is what we have: kModelDim consecutive
// elements in memory, repeated kFFHiddenDim times.
constexpr size_t kColsA = kModelDim;
constexpr size_t kColsB = kFFHiddenDim;
constexpr size_t kColsBC = kFFHiddenDim;
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
const auto A = activations.bf_pre_ffw_rms_out.All();
const float scale = layer_weights->gating_einsum_w.scale();
@@ -446,21 +461,21 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
const size_t A_ofs = 0; // no offset, using the same activations for both.
// Will go through GELU.
MatMul_4x4<kColsA, kColsB, kAddBias>(num_interleaved, A, A_ofs, B1,
/*B_ofs=*/0, scale, C1, bias1, pool);
MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
/*B_ofs=*/0, kColsBC, scale, C1, kColsBC, bias1, pool);
// What to multiply by.
MatMul_4x4<kColsA, kColsB, kAddBias>(num_interleaved, A, A_ofs, B1,
/*B_ofs=*/kColsA * kColsB, scale, C2,
bias2, pool);
MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
/*B_ofs=*/kColsA * kColsBC, kColsBC, scale, C2, kColsBC,
bias2, pool);
// Activation (Gelu) and multiply by gate. Store activations in C1.
Activation<TConfig>(C1, C2, kFFHiddenDim * num_interleaved);
// Hidden layer -> output layer.
MatMul_4x4<kFFHiddenDim, kModelDim, kAddBias>(
num_interleaved, C1, 0, layer_weights->linear_w.data(), 0,
layer_weights->linear_w.scale(), activations.ffw_out.All(), output_bias,
pool);
MatMul_4x4<kAddBias>(num_interleaved, C1, 0, kFFHiddenDim,
layer_weights->linear_w.data(), 0, kModelDim,
layer_weights->linear_w.scale(),
activations.ffw_out.All(), kModelDim, output_bias, pool);
}
// `batch_idx` indicates which row of `x` to write to.
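
The gated FFW above, summarized as pseudocode (a sketch only; W_gate1 and W_gate2 are informal names for the two halves of gating_einsum_w that the B_ofs of kColsA * kColsBC selects between):

  C1 = A * W_gate1              // first half, "will go through GELU"
  C2 = A * W_gate2              // second half, "what to multiply by"
  C1 = Gelu(C1) * C2            // Activation<TConfig>, element-wise
  ffw_out = C1 * linear_w       // plus output_bias when kAddBias

All three MatMul_4x4 calls here pass C_stride == colsBC (kColsBC or kModelDim), i.e. densely packed output rows.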
@@ -932,6 +947,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const MultiplePromptsTokens& prompts, const size_t pos,
const size_t query_idx_start, const KVCaches& kv_caches,
hwy::ThreadPool& pool, TimingInfo& timing_info) {
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kVocabSize = TConfig::kVocabSize;
const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
@@ -1006,11 +1022,12 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
bool all_queries_eos = true;
PROFILER_ZONE("Gen.Embedding");
// Compute logits from last layer activations.
MatMul_4x4<TConfig::kModelDim, kVocabSize, /*kAdd=*/false>(
num_queries, activations.x.All(), 0,
weights.embedder_input_embedding.data(), 0,
weights.embedder_input_embedding.scale(), activations.logits.All(),
/*add=*/nullptr, pool);
MatMul_4x4</*kAdd=*/false>(num_queries, activations.x.All(), 0, kModelDim,
weights.embedder_input_embedding.data(), 0,
kVocabSize,
weights.embedder_input_embedding.scale(),
activations.logits.All(), kVocabSize,
/*add=*/nullptr, pool);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
if constexpr (TConfig::kFinalCap > 0.0f) {

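For the logits matmul just above, colsBC and C_stride are both kVocabSize, so C holds one contiguous row of logits per query. Assuming dense batch storage, row q starts at

  float* row_q = activations.logits.All() + q * kVocabSize;  // what logits.Batch(query_idx) reads back
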
@@ -378,69 +378,74 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
//
// If kAdd is true, the row-vector `add` is added to each row of C, otherwise
// `add` is ignored and can be nullptr.
// A is a row-major matrix of size (batch_size, kColsA_RowsB).
// A is a row-major matrix of size (batch_size, colsA_rowsB).
// B is passed transposed (column-major), so a matrix of size
// (kColsBC, kColsA_RowsB), representing a B of size (kColsA_RowsB, kColsBC).
// (colsBC, colsA_rowsB), representing a B of size (colsA_rowsB, colsBC).
// A_ofs and B_ofs are offsets into A and B, respectively; they remain separate
// from the pointers because some MatTA/B such as NuqStream do not support
// pointer arithmetic.
// C is a matrix of size (batch_size, kColsBC).
// C is a row-major matrix of size (batch_size, colsBC), with `C_stride`
// elements between rows, which is typically the same as `colsBC`. There is no
// `C_ofs` because callers can simply add it to `C`.
// The product is scaled by `scale` to support CompressedArray with scale != 1,
// the caller can pass the product of the scales of A and B.
// A scale for `add` is not supported, so make sure its scale is 1.
// Typically batch_size is 1..512, kColsA_RowsB and kColsBC are 3k or 24k.
template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
typename MatTB, typename OutT>
// Typically batch_size is 1..512, colsA_rowsB and colsBC are 3k or 24k.
template <bool kAdd, typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul_4x4(const size_t batch_size,
const MatTA* HWY_RESTRICT A, const size_t A_ofs,
const size_t colsA_rowsB,
const MatTB* HWY_RESTRICT B, const size_t B_ofs,
const float scale, OutT* HWY_RESTRICT C,
const size_t colsBC, const float scale,
float* HWY_RESTRICT C, const size_t C_stride,
const float* HWY_RESTRICT add,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Matmul");
// We currently write C directly, which touches more memory than fits in L3.
// TODO: add another level of loops to finish L3-sized pieces of C at a time.
const hn::ScalableTag<MatTA> d;
const size_t N = Lanes(d);
constexpr size_t kRegRows = 4;
// Use float instead of MatTA/MatTB because we decompress to float here.
const size_t Nf = hn::Lanes(hn::ScalableTag<float>());
(void)Nf; // For HWY_DASSERT
constexpr size_t kRegRows = 4; // if changing, also update the switch below.
constexpr size_t kRegCols = 4; // in vectors
static_assert(kColsBC % kRegCols == 0);
HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
const size_t kTilesY = (batch_size + kRegRows - 1) / kRegRows;
const size_t kTilesX = kColsBC / kRegCols;
const size_t kTiles = kTilesX * kTilesY;
HWY_DASSERT(colsA_rowsB % (Nf * 2) == 0); // For Decompress2.
HWY_DASSERT(colsBC % kRegCols == 0);
const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
const size_t tilesX = colsBC / kRegCols;
constexpr size_t kStrideA = kColsA_RowsB;
constexpr size_t kStrideB = kColsA_RowsB;
constexpr size_t kStrideC = kColsBC;
const size_t strideA = colsA_rowsB;
const size_t strideB = colsA_rowsB;
pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
// Computes the finished product of one 4x4N tile and writes to C.
const size_t num_rows = batch_size - idx_tile / kTilesX * kRegRows;
HWY_ASSERT(num_rows > 0);
switch (num_rows) {
case 1:
GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
kTilesX, kColsA_RowsB, kStrideA, kStrideB,
kStrideC);
break;
case 2:
GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
kTilesX, kColsA_RowsB, kStrideA, kStrideB,
kStrideC);
break;
case 3:
GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
kTilesX, kColsA_RowsB, kStrideA, kStrideB,
kStrideC);
break;
default:
GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
kTilesX, kColsA_RowsB, kStrideA, kStrideB,
kStrideC);
}
});
pool.Run(0, tilesX * tilesY,
[&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
// How many rows of C are left to compute. If more than 4, this
// tile still only computes 4 rows.
const size_t num_rows = batch_size - idx_tile / tilesX * kRegRows;
HWY_ASSERT(num_rows > 0);
switch (num_rows) {
case 1:
GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
break;
case 2:
GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
break;
case 3:
GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
break;
default:
GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
}
});
}
//------------------------------------------------------------------------------
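
A usage sketch for the new C_stride argument (hypothetical buffer names, for illustration only; shapes must still satisfy the asserts above): to write a (batch x cols) product into the left-most cols columns of a wider row-major float buffer whose rows are out_stride elements apart, pass the stride instead of cols:

  // Row r of the product is written to out[r * out_stride + 0 .. cols - 1].
  MatMul_4x4</*kAdd=*/false>(batch, A, /*A_ofs=*/0, colsA_rowsB,
                             B, /*B_ofs=*/0, /*colsBC=*/cols, scale,
                             out, /*C_stride=*/out_stride,
                             /*add=*/nullptr, pool);

This is how the attention hunk streams KV rows into the cache (C_stride = kCachePosSize), while the FFW and logits call sites keep C_stride == colsBC for dense outputs.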

@@ -301,8 +301,8 @@ void TestTiledBatchMatMul() {
const double start_tiled = hwy::platform::Now();
EXPECT_EQ(scale, a->scale() * b_trans->scale());
MatMul_4x4<kN, kK, kAdd>(kM, a->data(), 0, b_trans->data(), 0, scale, c.get(),
add->data(), pool);
MatMul_4x4<kAdd>(kM, a->data(), 0, kN, b_trans->data(), 0, kK, scale, c.get(),
kK, add->data(), pool);
const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds);
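Note that the updated test passes kK for both colsBC and C_stride, so rows of c remain contiguous and the call is equivalent to the previous template-parameter form.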