From f27683152cab3247f8ba563ec5b213a684e8d110 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Mon, 29 Jul 2024 12:17:33 -0700
Subject: [PATCH] 1.05x prefill speedup: matvec -> matmul for !MHA

Also add C_stride and make shape normal non-template arguments.

PiperOrigin-RevId: 657285945
---
 gemma/gemma-inl.h  | 87 ++++++++++++++++++++++++++------------------
 ops/matmul-inl.h   | 91 ++++++++++++++++++++++++----------------------
 ops/matmul_test.cc |  4 +-
 3 files changed, 102 insertions(+), 80 deletions(-)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index aa28bbc..e5cc768 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -237,30 +237,45 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
     //
     // Compute Q only or QKV (if MHA).
     // If MHA, this also computes KV, which we copy to the KV cache below.
-    const float scale = layer_weights->qkv_einsum_w.scale();
-    MatMul_4x4(
-        num_interleaved, activations.pre_att_rms_out.All(), 0,
-        layer_weights->qkv_einsum_w.data(), 0, scale, activations.q.All(),
-        /*add=*/nullptr, pool);
+    MatMul_4x4(num_interleaved, activations.pre_att_rms_out.All(),
+               0, kModelDim, layer_weights->qkv_einsum_w.data(),
+               0, kHeads * kQStride,
+               layer_weights->qkv_einsum_w.scale(),
+               activations.q.All(), kHeads * kQStride,
+               /*add=*/nullptr, pool);

     // Compute KV if not MHA.
     if constexpr (!kIsMHA) {
-      for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
-           ++interleaved_idx) {
-        const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
-        const size_t query_idx = interleaved_idx % num_queries;
-        const size_t batch_idx = interleaved_idx / num_queries;
-        KVCache& kv_cache = kv_caches[query_idx];
-        const size_t pos = batch_start + batch_idx;
-        const size_t cache_pos = div_seq_len.Remainder(pos);
-        const size_t kv_offset =
-            cache_pos * kCachePosSize + layer * kCacheLayerSize;
-        float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      // Single query and no wraparound means we can use a matmul and write
+      // directly into the KV cache with a stride of kCachePosSize.
+      if (num_queries == 1 &&
+          batch_start + num_tokens <= div_seq_len.GetDivisor()) {
+        const size_t colsBC = kKVHeads * 2 * kQKVDim;
+        const size_t kv_ofs =
+            batch_start * kCachePosSize + layer * kCacheLayerSize;
         // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
-        // TODO: requires batched KVCache support.
-        MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
-            layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
-            activations.even_odd.All(), kv, pool);
+        float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs;
+        MatMul_4x4(
+            num_tokens, activations.pre_att_rms_out.All(), 0, kModelDim,
+            layer_weights->qkv_einsum_w.data(), kHeads * kQKVDim * kModelDim,
+            colsBC, layer_weights->qkv_einsum_w.scale(), kv, kCachePosSize,
+            /*add=*/nullptr, pool);
+      } else {
+        for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+             ++interleaved_idx) {
+          const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
+          const size_t query_idx = interleaved_idx % num_queries;
+          const size_t batch_idx = interleaved_idx / num_queries;
+          KVCache& kv_cache = kv_caches[query_idx];
+          const size_t cache_pos = div_seq_len.Remainder(batch_start + batch_idx);
+          const size_t kv_offset =
+              cache_pos * kCachePosSize + layer * kCacheLayerSize;
+          float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+          // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
+          MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
+              layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
+              activations.even_odd.All(), kv, pool);
+        }
       }
     }

@@ -427,7 +442,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   // MatMul expects col-major B, which is what we have: kModelDim consecutive
   // elements in memory, repeated kFFHiddenDim times.
   constexpr size_t kColsA = kModelDim;
-  constexpr size_t kColsB = kFFHiddenDim;
+  constexpr size_t kColsBC = kFFHiddenDim;
   HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
   const auto A = activations.bf_pre_ffw_rms_out.All();
   const float scale = layer_weights->gating_einsum_w.scale();
@@ -446,21 +461,21 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   const size_t A_ofs = 0;  // no offset, using the same activations for both.

   // Will go through GELU.
-  MatMul_4x4(num_interleaved, A, A_ofs, B1,
-             /*B_ofs=*/0, scale, C1, bias1, pool);
+  MatMul_4x4(num_interleaved, A, A_ofs, kColsA, B1,
+             /*B_ofs=*/0, kColsBC, scale, C1, kColsBC, bias1, pool);
   // What to multiply by.
-  MatMul_4x4(num_interleaved, A, A_ofs, B1,
-             /*B_ofs=*/kColsA * kColsB, scale, C2,
-             bias2, pool);
+  MatMul_4x4(num_interleaved, A, A_ofs, kColsA, B1,
+             /*B_ofs=*/kColsA * kColsBC, kColsBC, scale, C2, kColsBC,
+             bias2, pool);

   // Activation (Gelu) and multiply by gate. Store activations in C1.
   Activation(C1, C2, kFFHiddenDim * num_interleaved);

   // Hidden layer -> output layer.
-  MatMul_4x4(
-      num_interleaved, C1, 0, layer_weights->linear_w.data(), 0,
-      layer_weights->linear_w.scale(), activations.ffw_out.All(), output_bias,
-      pool);
+  MatMul_4x4(num_interleaved, C1, 0, kFFHiddenDim,
+             layer_weights->linear_w.data(), 0, kModelDim,
+             layer_weights->linear_w.scale(),
+             activations.ffw_out.All(), kModelDim, output_bias, pool);
 }

 // `batch_idx` indicates which row of `x` to write to.
@@ -932,6 +947,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
                const MultiplePromptsTokens& prompts, const size_t pos,
                const size_t query_idx_start, const KVCaches& kv_caches,
                hwy::ThreadPool& pool, TimingInfo& timing_info) {
+  constexpr size_t kModelDim = TConfig::kModelDim;
   constexpr size_t kVocabSize = TConfig::kVocabSize;
   const CompressedWeights<TConfig>& weights =
       *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
@@ -1006,11 +1022,12 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     bool all_queries_eos = true;
     PROFILER_ZONE("Gen.Embedding");
     // Compute logits from last layer activations.
-    MatMul_4x4(
-        num_queries, activations.x.All(), 0,
-        weights.embedder_input_embedding.data(), 0,
-        weights.embedder_input_embedding.scale(), activations.logits.All(),
-        /*add=*/nullptr, pool);
+    MatMul_4x4(num_queries, activations.x.All(), 0, kModelDim,
+               weights.embedder_input_embedding.data(), 0,
+               kVocabSize,
+               weights.embedder_input_embedding.scale(),
+               activations.logits.All(), kVocabSize,
+               /*add=*/nullptr, pool);
     for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
       float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
       if constexpr (TConfig::kFinalCap > 0.0f) {
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 0e368e6..ea7f03e 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -378,69 +378,74 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
 //
 // If kAdd is true, the row-vector `add` is added to each row of C, otherwise
 // `add` is ignored and can be nullptr.
-// A is a row-major matrix of size (batch_size, kColsA_RowsB).
+// A is a row-major matrix of size (batch_size, colsA_rowsB).
 // B is passed transposed (column-major), so a matrix of size
-// (kColsBC, kColsA_RowsB), representing a B of size (kColsA_RowsB, kColsBC).
+// (colsBC, colsA_rowsB), representing a B of size (colsA_rowsB, colsBC).
 // A_ofs and B_ofs are offsets into A and B, respectively; they remain separate
 // from the pointers because some MatTA/B such as NuqStream do not support
 // pointer arithmetic.
-// C is a matrix of size (batch_size, kColsBC).
+// C is a row-major matrix of size (batch_size, colsBC), with `C_stride`
+// elements between rows, which is typically the same as `colsBC`. There is no
+// `C_ofs` because callers can simply add it to `C`.
 // The product is scaled by `scale` to support CompressedArray with scale != 1,
 // the caller can pass the product of the scales of A and B.
 // A scale for `add` is not supported, so make sure its scale is 1.
-// Typically batch_size is 1..512, kColsA_RowsB and kColsBC are 3k or 24k.
-template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
-          typename MatTB, typename OutT>
+// Typically batch_size is 1..512, colsA_rowsB and colsBC are 3k or 24k.
+template <bool kAdd, typename MatTA, typename MatTB>
 HWY_NOINLINE void MatMul_4x4(const size_t batch_size,
                              const MatTA* HWY_RESTRICT A, const size_t A_ofs,
+                             const size_t colsA_rowsB,
                              const MatTB* HWY_RESTRICT B, const size_t B_ofs,
-                             const float scale, OutT* HWY_RESTRICT C,
+                             const size_t colsBC, const float scale,
+                             float* HWY_RESTRICT C, const size_t C_stride,
                              const float* HWY_RESTRICT add,
                              hwy::ThreadPool& pool) {
   PROFILER_ZONE("Matmul");
   // We currently write C directly, which touches more memory than fits in L3.
   // TODO: add another level of loops to finish L3-sized pieces of C at a time.
-  const hn::ScalableTag d;
-  const size_t N = Lanes(d);
-  constexpr size_t kRegRows = 4;
+  // Use float instead of MatTA/MatTB because we decompress to float here.
+  const size_t Nf = hn::Lanes(hn::ScalableTag<float>());
+  (void)Nf;  // For HWY_DASSERT
+  constexpr size_t kRegRows = 4;  // if changing, also update the switch below.
   constexpr size_t kRegCols = 4;  // in vectors
-  static_assert(kColsBC % kRegCols == 0);
-  HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
-  const size_t kTilesY = (batch_size + kRegRows - 1) / kRegRows;
-  const size_t kTilesX = kColsBC / kRegCols;
-  const size_t kTiles = kTilesX * kTilesY;
+  HWY_DASSERT(colsA_rowsB % (Nf * 2) == 0);  // For Decompress2.
+  HWY_DASSERT(colsBC % kRegCols == 0);
+  const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
+  const size_t tilesX = colsBC / kRegCols;

-  constexpr size_t kStrideA = kColsA_RowsB;
-  constexpr size_t kStrideB = kColsA_RowsB;
-  constexpr size_t kStrideC = kColsBC;
+  const size_t strideA = colsA_rowsB;
+  const size_t strideB = colsA_rowsB;

-  pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
-    // Computes the finished product of one 4x4N tile and writes to C.
-    const size_t num_rows = batch_size - idx_tile / kTilesX * kRegRows;
-    HWY_ASSERT(num_rows > 0);
-    switch (num_rows) {
-      case 1:
-        GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
-                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
-                               kStrideC);
-        break;
-      case 2:
-        GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
-                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
-                               kStrideC);
-        break;
-      case 3:
-        GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
-                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
-                               kStrideC);
-        break;
-      default:
-        GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
-                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
-                               kStrideC);
-    }
-  });
+  pool.Run(0, tilesX * tilesY,
+           [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
+             // How many rows of C are left to compute. If more than 4, this
+             // tile still only computes 4 rows.
+             const size_t num_rows = batch_size - idx_tile / tilesX * kRegRows;
+             HWY_ASSERT(num_rows > 0);
+             switch (num_rows) {
+               case 1:
+                 GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
+                                        idx_tile, tilesX, colsA_rowsB, strideA,
+                                        strideB, C_stride);
+                 break;
+               case 2:
+                 GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
+                                        idx_tile, tilesX, colsA_rowsB, strideA,
+                                        strideB, C_stride);
+                 break;
+               case 3:
+                 GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
+                                        idx_tile, tilesX, colsA_rowsB, strideA,
+                                        strideB, C_stride);
+                 break;
+               default:
+                 GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
+                                        idx_tile, tilesX, colsA_rowsB, strideA,
+                                        strideB, C_stride);
+             }
+           });
 }

 //------------------------------------------------------------------------------
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index ec9d390..4885347 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -301,8 +301,8 @@ void TestTiledBatchMatMul() {

   const double start_tiled = hwy::platform::Now();
   EXPECT_EQ(scale, a->scale() * b_trans->scale());
-  MatMul_4x4(kM, a->data(), 0, b_trans->data(), 0, scale, c.get(),
-             add->data(), pool);
+  MatMul_4x4(kM, a->data(), 0, kN, b_trans->data(), 0, kK, scale, c.get(),
+             kK, add->data(), pool);
   const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
   fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds);
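
Note (not part of the patch): a minimal sketch of how the new runtime-shape signature is intended
to be called, useful for reviewers of the C_stride change. The argument order (batch_size, A,
A_ofs, colsA_rowsB, B, B_ofs, colsBC, scale, C, C_stride, add, pool) is taken from the declaration
above; the buffer names, sizes, wrapper function, and the /*kAdd=*/false spelling are illustrative
assumptions, and the call must be made from a SIMD-target translation unit that already includes
ops/matmul-inl.h.

    #include <cstddef>
    #include <vector>
    #include "hwy/contrib/thread_pool/thread_pool.h"

    // Hypothetical caller: C_stride >= colsBC elements separate rows of C, so the
    // matmul can write directly into rows of a larger buffer, as the patch does
    // for the KV cache (whose row stride is kCachePosSize).
    void ExampleStridedMatMul(hwy::ThreadPool& pool) {
      const size_t batch_size = 4;      // rows of A and C
      const size_t colsA_rowsB = 512;   // shared inner dimension
      const size_t colsBC = 64;         // columns of transposed B and of C
      const size_t C_stride = 256;      // row stride of the destination buffer
      std::vector<float> A(batch_size * colsA_rowsB, 1.0f);  // row-major A
      std::vector<float> B(colsBC * colsA_rowsB, 1.0f);      // B passed transposed
      std::vector<float> C(batch_size * C_stride, 0.0f);     // strided output
      MatMul_4x4</*kAdd=*/false>(batch_size, A.data(), /*A_ofs=*/0, colsA_rowsB,
                                 B.data(), /*B_ofs=*/0, colsBC, /*scale=*/1.0f,
                                 C.data(), C_stride, /*add=*/nullptr, pool);
    }

Only the first colsBC elements of each C row are written; the remaining C_stride - colsBC elements
per row are left untouched, which is what lets the !MHA prefill path fill KV rows in place.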