diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index d4b6745..aa28bbc 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -238,9 +238,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   // Compute Q only or QKV (if MHA).
   // If MHA, this also computes KV, which we copy to the KV cache below.
   const float scale = layer_weights->qkv_einsum_w.scale();
-  MatMul_4x4_Batch(
-      num_interleaved, activations.pre_att_rms_out.All(),
-      layer_weights->qkv_einsum_w.data(), scale, activations.q.All(), pool);
+  MatMul_4x4(
+      num_interleaved, activations.pre_att_rms_out.All(), 0,
+      layer_weights->qkv_einsum_w.data(), 0, scale, activations.q.All(),
+      /*add=*/nullptr, pool);

   // Compute KV if not MHA.
   if constexpr (!kIsMHA) {
@@ -256,7 +257,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
           cache_pos * kCachePosSize + layer * kCacheLayerSize;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
       // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
-      // TODO: requires MatMul support for offsets.
+      // TODO: requires batched KVCache support.
       MatVec(
           layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
           activations.even_odd.All(), kv, pool);
@@ -431,7 +432,6 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   const auto A = activations.bf_pre_ffw_rms_out.All();
   const float scale = layer_weights->gating_einsum_w.scale();
   const auto B1 = layer_weights->gating_einsum_w.data();
-  const auto B2 = B1 + kColsA * kColsB;
   auto C1 = activations.C1.All();
   auto C2 = activations.C2.All();
   constexpr bool kAddBias = TConfig::kFFBiases;
@@ -444,21 +444,23 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
     output_bias = layer_weights->ffw_output_biases.data_scale1();
   }

+  const size_t A_ofs = 0;  // no offset, using the same activations for both.
   // Will go through GELU.
-  MatMul_4x4_Batch_Add(num_interleaved, A, B1, scale,
-                       C1, bias1, pool);
+  MatMul_4x4(num_interleaved, A, A_ofs, B1,
+             /*B_ofs=*/0, scale, C1, bias1, pool);
   // What to multiply by.
-  MatMul_4x4_Batch_Add(num_interleaved, A, B2, scale,
-                       C2, bias2, pool);
+  MatMul_4x4(num_interleaved, A, A_ofs, B1,
+             /*B_ofs=*/kColsA * kColsB, scale, C2,
+             bias2, pool);

   // Activation (Gelu) and multiply by gate. Store activations in C1.
   Activation(C1, C2, kFFHiddenDim * num_interleaved);

   // Hidden layer -> output layer.
-  MatMul_4x4_Batch_Add(
-      num_interleaved, C1, layer_weights->linear_w.data(),
-      layer_weights->linear_w.scale(), activations.ffw_out.All(),
-      output_bias, pool);
+  MatMul_4x4(
+      num_interleaved, C1, 0, layer_weights->linear_w.data(), 0,
+      layer_weights->linear_w.scale(), activations.ffw_out.All(), output_bias,
+      pool);
 }

 // `batch_idx` indicates which row of `x` to write to.
@@ -1003,12 +1005,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     bool all_queries_eos = true;
     PROFILER_ZONE("Gen.Embedding");
+    // Compute logits from last layer activations.
+    MatMul_4x4(
+        num_queries, activations.x.All(), 0,
+        weights.embedder_input_embedding.data(), 0,
+        weights.embedder_input_embedding.scale(), activations.logits.All(),
+        /*add=*/nullptr, pool);
     for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
       float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
-      // Compute logits from last layer activations. TODO: MatMul
-      MatVec(
-          weights.embedder_input_embedding, 0, activations.x.Batch(query_idx),
-          activations.even_odd.All(), logits, pool);
       if constexpr (TConfig::kFinalCap > 0.0f) {
         LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
       }
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 3a912f0..0e368e6 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -17,20 +17,14 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
 #define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_

-#include 
 #include 
 #include 
 #include 

-#include 
-#include 
-#include <type_traits>  // std::enable_if_t
-
 #include "compression/compress.h"
 #include "compression/sfp.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/detect_targets.h"
 #include "hwy/profiler.h"

 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
@@ -53,23 +47,8 @@ namespace gcpp {
 namespace HWY_NAMESPACE {

 namespace hn = hwy::HWY_NAMESPACE;

-HWY_INLINE constexpr size_t MaxCols() {
-  // Vec + mat rows should fit into 32 KiB L1.
-  return 2048;
-}
-
-template <size_t kOuter>
-HWY_INLINE constexpr size_t RowsPerStrip() {
-  // Aim for 128 work items to reduce pool overhead. Must be at least one
-  // vector; prefer a power of two for faster division.
-  constexpr size_t kLanes = hn::ScalableTag().MaxLanes();
-  constexpr size_t kRowsPerStrip =
-      kOuter < 128 ? kLanes
-                   : HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(kOuter / 128));
-  return kRowsPerStrip;
-}
-
-// Shared between f32 and bf16, which also accumulates into f32 vectors.
+// c## are partial sums of the products of A and B; their horizontal sums are
+// the final matmul result, stored in C, which is always f32.
 template <class DF, class VF = hn::Vec<DF>>
 HWY_INLINE void StoreHorizontalSums(DF df,  //
                                     VF c00, VF c01, VF c02, VF c03,  //
@@ -106,7 +85,8 @@ HWY_INLINE void StoreHorizontalSums(DF df,  //
   tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33);
 }

-// Completes the tile by summing across the vectors, and adds the biases.
+// As above, but also adds `add[0..3]` to columns 0..3 of `tile_c`. `add` has no
+// scale, and points to a 1D slice of the row vector.
 template <class DF, class VF = hn::Vec<DF>>
 HWY_INLINE void StoreHorizontalSumsAdd(DF df,  //
                                        VF c00, VF c01, VF c02, VF c03,  //
@@ -121,32 +101,33 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df,  //
   // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
   // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
-  float addon0 = hwy::ConvertScalarTo<float>(add[0]);
-  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + addon0;
-  float addon1 = hwy::ConvertScalarTo<float>(add[1]);
-  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01) + addon1;
-  float addon2 = hwy::ConvertScalarTo<float>(add[2]);
-  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02) + addon2;
-  float addon3 = hwy::ConvertScalarTo<float>(add[3]);
-  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03) + addon3;
+  const float add0 = add[0];
+  // TODO: 4x4 transpose, then 128-bit vector FMA?
+  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + add0;
+  const float add1 = add[1];
+  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01) + add1;
+  const float add2 = add[2];
+  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02) + add2;
+  const float add3 = add[3];
+  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03) + add3;
   if (kNumRows == 1) return;
-  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10) + addon0;
-  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11) + addon1;
-  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12) + addon2;
-  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13) + addon3;
+  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10) + add0;
+  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11) + add1;
+  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12) + add2;
+  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13) + add3;
   if (kNumRows == 2) return;
-  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20) + addon0;
-  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21) + addon1;
-  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22) + addon2;
-  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23) + addon3;
+  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20) + add0;
+  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21) + add1;
+  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22) + add2;
+  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23) + add3;
   if (kNumRows == 3) return;
-  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30) + addon0;
-  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31) + addon1;
-  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32) + addon2;
-  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33) + addon3;
+  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30) + add0;
+  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31) + add1;
+  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32) + add2;
+  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33) + add3;
 }

 // Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call
@@ -180,15 +161,15 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
 #if GEMMA_NATIVE_BF16

 // Specialization for f32 += bf16 * bf16 that avoids promoting to f32.
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd>
+template <size_t kNumRows, bool kAdd>
 HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
+                              const size_t A_ofs,
                               const hwy::bfloat16_t* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C,
-                              const float scale,
-                              const float* HWY_RESTRICT add,
+                              const size_t B_ofs, float* HWY_RESTRICT C,
+                              const float scale, const float* HWY_RESTRICT add,
                               const size_t idx_tile, const size_t xtiles,
-                              const size_t stride_a, const size_t stride_b,
-                              const size_t stride_c) {
+                              const size_t cols_a, const size_t stride_a,
+                              const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
   static_assert(kNumRows <= kRegRows);
@@ -226,41 +207,42 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
   VF c32 = hn::Zero(df);
   VF c33 = hn::Zero(df);

-  const hwy::bfloat16_t* HWY_RESTRICT tile_a = A + stride_a * row_a;
-  const hwy::bfloat16_t* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
+  const hwy::bfloat16_t* HWY_RESTRICT A_tile = A + A_ofs + stride_a * row_a;
+  const hwy::bfloat16_t* HWY_RESTRICT B_tile =
+      B + B_ofs + stride_b * row_b_col_c;

   // Loop over columns of A and columns of the transposed B, in steps of N.
   // Accumulates into the c## vectors.
   HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
+  for (size_t col_ab = 0; col_ab < cols_a; col_ab += N) {
     using V = hn::Vec;
-    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
-    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
-    const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
-    const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);
+    const V b0 = hn::LoadU(d, B_tile + stride_b * 0 + col_ab);
+    const V b1 = hn::LoadU(d, B_tile + stride_b * 1 + col_ab);
+    const V b2 = hn::LoadU(d, B_tile + stride_b * 2 + col_ab);
+    const V b3 = hn::LoadU(d, B_tile + stride_b * 3 + col_ab);

-    const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
+    const V a0 = hn::LoadU(d, A_tile + stride_a * 0 + col_ab);
     c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1);
     c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
     c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
     c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);

     if constexpr (kNumRows == 1) continue;
-    const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
+    const V a1 = hn::LoadU(d, A_tile + stride_a * 1 + col_ab);
     c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
     c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
     c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
     c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);

     if constexpr (kNumRows == 2) continue;
-    const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
+    const V a2 = hn::LoadU(d, A_tile + stride_a * 2 + col_ab);
     c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
     c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
     c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
     c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);

     if constexpr (kNumRows == 3) continue;
-    const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
+    const V a3 = hn::LoadU(d, A_tile + stride_a * 3 + col_ab);
     c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
     c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
     c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
@@ -270,10 +252,10 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
   // Ensure sum1 was indeed unused.
   HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));

-  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
+  float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c;
   StoreHorizontalSumsMaybeAdd(
       df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
+      c32, c33, add, row_b_col_c, scale, C_tile, stride_c);
 }
 #endif  // GEMMA_NATIVE_BF16
@@ -295,19 +277,17 @@ HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
   c3 = hn::MulAdd(a1, b31, c3);
 }

-// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
-// can iterate over both A and B with consecutive vector loads. kNumRows<=4.
+// Accumulates a single kNumRows (<= 4) x 4 tile of A x B into C. B is
+// transposed, so we iterate over both A and B with consecutive vector loads.
 // General case: uses CompressTraits to load from A and B.
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
-          typename MatTB>
-HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
-                              const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C,
-                              const float scale,
+template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
+HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
+                              const MatTB* HWY_RESTRICT B, const size_t B_ofs,
+                              float* HWY_RESTRICT C, const float scale,
                               const float* HWY_RESTRICT add,
                               const size_t idx_tile, const size_t xtiles,
-                              const size_t stride_a, const size_t stride_b,
-                              const size_t stride_c) {
+                              const size_t cols_a, const size_t stride_a,
+                              const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
   static_assert(kNumRows <= kRegRows);
@@ -343,8 +323,8 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   V c32 = hn::Zero(d32);
   V c33 = hn::Zero(d32);

-  const size_t tile_a_ofs = stride_a * row_a;
-  const size_t tile_b_ofs = stride_b * row_b_col_c;
+  const size_t A_tile_ofs = A_ofs + stride_a * row_a;
+  const size_t B_tile_ofs = B_ofs + stride_b * row_b_col_c;

   // Loop over columns of A and columns of the transposed B, in steps of 2*N
   // (since we are decoding consecutive bytes at each iteration).
@@ -352,69 +332,74 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   size_t col_ab = 0;

   HWY_UNROLL(1)
-  for (; col_ab <= kColsA_RowsB - 2 * N; col_ab += 2 * N) {
+  for (; col_ab <= cols_a - 2 * N; col_ab += 2 * N) {
     V b00, b01;
-    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01);
+    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 0 + col_ab, b00, b01);
     V b10, b11;
-    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11);
+    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 1 + col_ab, b10, b11);
     V b20, b21;
-    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21);
+    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 2 + col_ab, b20, b21);
     V b30, b31;
-    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31);
+    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 3 + col_ab, b30, b31);

     V a00, a01;
-    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01);
+    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 0 + col_ab, a00, a01);
     UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
                   c02, c03);
     if constexpr (kNumRows == 1) continue;

     V a10, a11;
-    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11);
+    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 1 + col_ab, a10, a11);
     UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
                   c12, c13);
     if constexpr (kNumRows == 2) continue;

     V a20, a21;
-    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21);
+    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 2 + col_ab, a20, a21);
     UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
                   c22, c23);
     if constexpr (kNumRows == 3) continue;

     V a30, a31;
-    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31);
+    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 3 + col_ab, a30, a31);
     UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
                   c32, c33);
   }

-  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
+  float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c;
   StoreHorizontalSumsMaybeAdd(
       d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
+      c32, c33, add, row_b_col_c, scale, C_tile, stride_c);
 }

-// C = A * B * scale [+ add].
-// Computes the matrix product of A and B and stores this in C.
-// If kAdd is true, the row-vector `add` is added to each row of C.
-// A is a matrix of size (batch_size, kColsA_RowsB).
+// Tiled 4x4 GEMM: C = A * B * scale [+ add].
+// Computes the matrix product of A and B and stores this in C. Processes tiles
+// of 4x4 vectors in parallel with a work-stealing thread pool.
+//
+// If kAdd is true, the row-vector `add` is added to each row of C, otherwise
+// `add` is ignored and can be nullptr.
+// A is a row-major matrix of size (batch_size, kColsA_RowsB).
 // B is passed transposed (column-major), so a matrix of size
 // (kColsBC, kColsA_RowsB), representing a B of size (kColsA_RowsB, kColsBC).
+// A_ofs and B_ofs are offsets into A and B, respectively; they remain separate
+// from the pointers because some MatTA/B such as NuqStream do not support
+// pointer arithmetic.
 // C is a matrix of size (batch_size, kColsBC).
 // The product is scaled by `scale` to support CompressedArray with scale != 1,
 // the caller can pass the product of the scales of A and B.
 // A scale for `add` is not supported, so make sure its scale is 1.
-// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
-// and kColsBC is 24k or 3k.
-// This function processes tiles in parallel with a work-stealing thread pool.
+// Typically batch_size is 1..512, kColsA_RowsB and kColsBC are 3k or 24k.
 template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
-          typename MatTB, typename OutT, typename AddT>
-HWY_NOINLINE void MatMul_4x4_Batch_Add(
-    size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    float scale, OutT* HWY_RESTRICT C, const AddT* HWY_RESTRICT add,
-    hwy::ThreadPool& pool) {
+          typename MatTB, typename OutT>
+HWY_NOINLINE void MatMul_4x4(const size_t batch_size,
+                             const MatTA* HWY_RESTRICT A, const size_t A_ofs,
+                             const MatTB* HWY_RESTRICT B, const size_t B_ofs,
+                             const float scale, OutT* HWY_RESTRICT C,
+                             const float* HWY_RESTRICT add,
+                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Matmul");
-  // Process reg-sized tiles of C in parallel. We currently write C directly,
-  // which touches more memory than fits in L3. TODO: add another level of loops
-  // so that we finish one L3-sized piece of C at a time.
+  // We currently write C directly, which touches more memory than fits in L3.
+  // TODO: add another level of loops to finish L3-sized pieces of C at a time.
   const hn::ScalableTag d;
   const size_t N = Lanes(d);
   constexpr size_t kRegRows = 4;
@@ -436,38 +421,28 @@ HWY_NOINLINE void MatMul_4x4_Batch_Add(
     HWY_ASSERT(num_rows > 0);
     switch (num_rows) {
       case 1:
-        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
-                                             kTilesX, kStrideA, kStrideB,
-                                             kStrideC);
+        GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
+                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
+                               kStrideC);
         break;
       case 2:
-        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
-                                             kTilesX, kStrideA, kStrideB,
-                                             kStrideC);
+        GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
+                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
+                               kStrideC);
         break;
       case 3:
-        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
-                                             kTilesX, kStrideA, kStrideB,
-                                             kStrideC);
+        GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
+                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
+                               kStrideC);
        break;
      default:
-        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
-                                             kTilesX, kStrideA, kStrideB,
-                                             kStrideC);
+        GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
+                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
+                               kStrideC);
     }
   });
 }

-// As above, without the add.
-template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA, typename MatTB,
-          typename OutT>
-HWY_NOINLINE void MatMul_4x4_Batch(
-    size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    float scale, OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
-  MatMul_4x4_Batch_Add(
-      batch_size, A, B, scale, C, /*add=*/static_cast(nullptr), pool);
-}
-
 //------------------------------------------------------------------------------

 HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                              const size_t size, float* HWY_RESTRICT out) {
@@ -525,6 +500,22 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
   }
 }

+HWY_INLINE constexpr size_t MaxCols() {
+  // Vec + mat rows should fit into 32 KiB L1.
+  return 2048;
+}
+
+template <size_t kOuter>
+HWY_INLINE constexpr size_t RowsPerStrip() {
+  // Aim for 128 work items to reduce pool overhead. Must be at least one
+  // vector; prefer a power of two for faster division.
+  constexpr size_t kLanes = hn::ScalableTag().MaxLanes();
+  constexpr size_t kRowsPerStrip =
+      kOuter < 128 ? kLanes
+                   : HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(kOuter / 128));
+  return kRowsPerStrip;
+}
+
 namespace detail {

 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index 322cb89..ec9d390 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -301,15 +301,10 @@ void TestTiledBatchMatMul() {
   const double start_tiled = hwy::platform::Now();
   EXPECT_EQ(scale, a->scale() * b_trans->scale());
-  if (kAdd) {
-    MatMul_4x4_Batch_Add(kM, a->data(), b_trans->data(), scale,
-                         c.get(), add->data(), pool);
-  } else {
-    MatMul_4x4_Batch(kM, a->data(), b_trans->data(), scale, c.get(),
-                     pool);
-  }
+  MatMul_4x4(kM, a->data(), 0, b_trans->data(), 0, scale, c.get(),
+             add->data(), pool);
   const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
-  fprintf(stderr, "MatMul_4x4_Batch took %f seconds.\n", tiled_matmul_seconds);
+  fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds);

   AssertClose(c_slow->data(), c.get(), kM * kK);
 }
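
For reference, a scalar sketch of the contract that the MatMul_4x4 comment block above documents: C = scale * (A x B) [+ add], with B passed transposed (column-major) and A_ofs/B_ofs as element offsets into A and B. The helper name MatMulSlowSketch and its float-only signature are illustrative assumptions, not part of this patch or of ops/matmul-inl.h:

#include <stddef.h>

// Illustrative reference only; mirrors the documented semantics of MatMul_4x4,
// not its tiled SIMD implementation. A is (batch_size x cols_a_rows_b)
// row-major, B is passed transposed as (cols_bc x cols_a_rows_b), and C is
// (batch_size x cols_bc). A_ofs/B_ofs are element offsets into A and B; `add`,
// if non-null, is a length-cols_bc row vector added to every row of C.
void MatMulSlowSketch(size_t batch_size, size_t cols_a_rows_b, size_t cols_bc,
                      const float* A, size_t A_ofs, const float* B,
                      size_t B_ofs, float scale, float* C, const float* add) {
  for (size_t r = 0; r < batch_size; ++r) {
    for (size_t c = 0; c < cols_bc; ++c) {
      float sum = 0.0f;
      for (size_t k = 0; k < cols_a_rows_b; ++k) {
        // Both A and B are traversed along their contiguous dimension because
        // B is stored transposed.
        sum += A[A_ofs + r * cols_a_rows_b + k] *
               B[B_ofs + c * cols_a_rows_b + k];
      }
      C[r * cols_bc + c] = scale * sum + (add ? add[c] : 0.0f);
    }
  }
}

The nonzero B_ofs is what lets FFW above issue both gating matmuls against the single gating_einsum_w tensor (B_ofs = 0 for the first half, B_ofs = kColsA * kColsB for the second), replacing the old B2 pointer that required pointer arithmetic on the weight type.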