From 6ea4232b2e2444f490645b30fbb6026f141ee8dd Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 30 Jul 2024 01:55:09 -0700 Subject: [PATCH] MatMul cleanup: Mat struct, simplify args. Add large benchmark to test, use 4 threads, skip some targets. Also use Traits::Name instead of typeid. PiperOrigin-RevId: 657496185 --- compression/compress-inl.h | 4 + gemma/gemma-inl.h | 62 ++++--- ops/matmul-inl.h | 243 ++++++++++++++-------------- ops/matmul_test.cc | 319 +++++++++++++++++++------------------ 4 files changed, 314 insertions(+), 314 deletions(-) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 1693cf0..3da53ce 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -60,6 +60,7 @@ struct CompressTraits {}; template <> struct CompressTraits { using MatT = float; + static const char* Name() { return "f32"; } static constexpr bool kSupportsEvenOdd = false; // unnecessary template @@ -123,6 +124,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = hwy::bfloat16_t; + static const char* Name() { return "bf16"; } static constexpr bool kSupportsEvenOdd = true; template @@ -292,6 +294,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = SfpStream; + static const char* Name() { return "sfp"; } static constexpr bool kSupportsEvenOdd = true; // Callers are responsible for scaling `in` such that its magnitudes do not @@ -389,6 +392,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = NuqStream; + static const char* Name() { return "nuq"; } static constexpr bool kSupportsEvenOdd = false; template diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index e5cc768..02c0c44 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -237,12 +237,11 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, // // Compute Q only or QKV (if MHA). // If MHA, this also computes KV, which we copy to the KV cache below. - MatMul_4x4(num_interleaved, activations.pre_att_rms_out.All(), - 0, kModelDim, layer_weights->qkv_einsum_w.data(), - 0, kHeads * kQStride, - layer_weights->qkv_einsum_w.scale(), - activations.q.All(), kHeads * kQStride, - /*add=*/nullptr, pool); + MatMul_4x4( + num_interleaved, MakeMat(activations.pre_att_rms_out.All(), kModelDim), + MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim), + layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr, + MakeMat(activations.q.All(), kHeads * kQStride), pool); // Compute KV if not MHA. if constexpr (!kIsMHA) { @@ -250,16 +249,16 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, // directly into the KV cache with a stride of kCachePosSize. if (num_queries == 1 && batch_start + num_tokens <= div_seq_len.GetDivisor()) { - const size_t colsBC = kKVHeads * 2 * kQKVDim; const size_t kv_ofs = batch_start * kCachePosSize + layer * kCacheLayerSize; // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). 
float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs; MatMul_4x4( - num_tokens, activations.pre_att_rms_out.All(), 0, kModelDim, - layer_weights->qkv_einsum_w.data(), kHeads * kQKVDim * kModelDim, - colsBC, layer_weights->qkv_einsum_w.scale(), kv, kCachePosSize, - /*add=*/nullptr, pool); + num_tokens, MakeMat(activations.pre_att_rms_out.All(), kModelDim), + MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim, kModelDim, + kHeads * kQKVDim * kModelDim), + layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr, + MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool); } else { for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; ++interleaved_idx) { @@ -441,14 +440,12 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved, // MatMul expects col-major B, which is what we have: kModelDim consecutive // elements in memory, repeated kFFHiddenDim times. - constexpr size_t kColsA = kModelDim; - constexpr size_t kColsBC = kFFHiddenDim; HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize()); - const auto A = activations.bf_pre_ffw_rms_out.All(); + const auto A = MakeMat(activations.bf_pre_ffw_rms_out.All(), kModelDim); + const auto B1 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim); + const auto B2 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim, + kModelDim, kModelDim * kFFHiddenDim); const float scale = layer_weights->gating_einsum_w.scale(); - const auto B1 = layer_weights->gating_einsum_w.data(); - auto C1 = activations.C1.All(); - auto C2 = activations.C2.All(); constexpr bool kAddBias = TConfig::kFFBiases; const float* bias1 = nullptr; const float* bias2 = nullptr; @@ -458,24 +455,22 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved, bias2 = bias1 + kFFHiddenDim; output_bias = layer_weights->ffw_output_biases.data_scale1(); } + auto C1 = MakeMat(activations.C1.All(), kFFHiddenDim); + auto C2 = MakeMat(activations.C2.All(), kFFHiddenDim); - const size_t A_ofs = 0; // no offset, using the same activations for both. // Will go through GELU. - MatMul_4x4(num_interleaved, A, A_ofs, kColsA, B1, - /*B_ofs=*/0, kColsBC, scale, C1, kColsBC, bias1, pool); + MatMul_4x4(num_interleaved, A, B1, scale, bias1, C1, pool); // What to multiply by. - MatMul_4x4(num_interleaved, A, A_ofs, kColsA, B1, - /*B_ofs=*/kColsA * kColsBC, kColsBC, scale, C2, kColsBC, - bias2, pool); + MatMul_4x4(num_interleaved, A, B2, scale, bias2, C2, pool); // Activation (Gelu) and multiply by gate. Store activations in C1. - Activation(C1, C2, kFFHiddenDim * num_interleaved); + Activation(C1.ptr, C2.ptr, kFFHiddenDim * num_interleaved); // Hidden layer -> output layer. - MatMul_4x4(num_interleaved, C1, 0, kFFHiddenDim, - layer_weights->linear_w.data(), 0, kModelDim, - layer_weights->linear_w.scale(), - activations.ffw_out.All(), kModelDim, output_bias, pool); + MatMul_4x4(num_interleaved, C1, + MakeMat(layer_weights->linear_w.data(), kFFHiddenDim), + layer_weights->linear_w.scale(), output_bias, + MakeMat(activations.ffw_out.All(), kModelDim), pool); } // `batch_idx` indicates which row of `x` to write to. @@ -1022,12 +1017,11 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, bool all_queries_eos = true; PROFILER_ZONE("Gen.Embedding"); // Compute logits from last layer activations. 
- MatMul_4x4(num_queries, activations.x.All(), 0, kModelDim, - weights.embedder_input_embedding.data(), 0, - kVocabSize, - weights.embedder_input_embedding.scale(), - activations.logits.All(), kVocabSize, - /*add=*/nullptr, pool); + MatMul_4x4( + num_queries, MakeMat(activations.x.All(), kModelDim), + MakeMat(weights.embedder_input_embedding.data(), kModelDim), + weights.embedder_input_embedding.scale(), /*add=*/nullptr, + MakeMat(activations.logits.All(), kVocabSize), pool); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); if constexpr (TConfig::kFinalCap > 0.0f) { diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index ea7f03e..0f45881 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -60,7 +60,7 @@ HWY_INLINE void StoreHorizontalSums(DF df, // // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles. // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is - // expensive, but only a fraction of the kColsA_RowsB/N FMAs. + // expensive, but only a fraction of the A.cols/N FMAs. tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00); tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01); tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02); @@ -93,14 +93,14 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df, // VF c10, VF c11, VF c12, VF c13, // VF c20, VF c21, VF c22, VF c23, // VF c30, VF c31, VF c32, VF c33, - const float* HWY_RESTRICT add, const float scale, + const float* HWY_RESTRICT add, float* HWY_RESTRICT tile_c, size_t stride_c) { // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles. // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is - // expensive, but only a fraction of the kColsA_RowsB/N FMAs. + // expensive, but only a fraction of the A.cols/N FMAs. const float add0 = add[0]; // TODO: 4x4 transpose, then 128-bit vector FMA? tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + add0; @@ -137,12 +137,12 @@ template > HWY_INLINE void StoreHorizontalSumsMaybeAdd( DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13, VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33, - const float* HWY_RESTRICT add, size_t add_offset, const float scale, + const float scale, const float* HWY_RESTRICT add, size_t add_offset, float* HWY_RESTRICT tile_c, size_t stride_c) { if constexpr (kAdd) { StoreHorizontalSumsAdd(df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, c32, c33, - add + add_offset, scale, tile_c, stride_c); + scale, add + add_offset, tile_c, stride_c); } else { StoreHorizontalSums(df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, c32, c33, @@ -150,6 +150,36 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd( } } +// Wrapper to simplify call sites. T can be const or non-const. +template +struct Mat { + bool NotEmpty() const { + return ptr != nullptr && cols != 0 && stride >= cols; + } + size_t Row(size_t r) const { return ofs + stride * r; } + + T* HWY_RESTRICT ptr; + size_t cols; + + // elements between rows, which is typically the same as `cols`. + size_t stride; + + // Offset to add to `ptr`; separate because T=NuqStream does not support + // pointer arithmetic. 
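+  // For example, GemmaAttention selects the KV rows of qkv_einsum_w by
+  // passing ofs = kHeads * kQKVDim * kModelDim to MakeMat.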
+ size_t ofs; +}; + +template +Mat MakeMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride, + size_t ofs = 0) { + return Mat{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs}; +} + +template +Mat MakeMat(T* HWY_RESTRICT ptr, size_t cols) { + return MakeMat(ptr, cols, cols); +} + #undef GEMMA_NATIVE_BF16 #if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \ defined(HWY_TARGET_TOGGLE)) @@ -162,31 +192,18 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd( // Specialization for f32 += bf16 * bf16 that avoids promoting to f32. template -HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A, - const size_t A_ofs, - const hwy::bfloat16_t* HWY_RESTRICT B, - const size_t B_ofs, float* HWY_RESTRICT C, - const float scale, const float* HWY_RESTRICT add, - const size_t idx_tile, const size_t xtiles, - const size_t cols_a, const size_t stride_a, - const size_t stride_b, const size_t stride_c) { - constexpr size_t kRegRows = 4; - constexpr size_t kRegCols = 4; - static_assert(kNumRows <= kRegRows); - - // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B. - const size_t row_a = idx_tile / xtiles * kRegRows; - const size_t row_b_col_c = idx_tile % xtiles * kRegCols; - +HWY_INLINE void MatMulTile(const Mat& A, + const Mat& B, + const size_t row_a, const size_t row_b_col_c, + const float scale, const float* HWY_RESTRICT add, + const Mat& C) { const hn::ScalableTag df; using VF = hn::Vec; // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full // bf16 vectors. const hn::Repartition d; - VF unused_sum1 = hn::Zero(df); - const size_t N = Lanes(d); - + VF unused_sum1 = hn::Zero(df); VF c00 = hn::Zero(df); VF c01 = hn::Zero(df); VF c02 = hn::Zero(df); @@ -207,42 +224,41 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A, VF c32 = hn::Zero(df); VF c33 = hn::Zero(df); - const hwy::bfloat16_t* HWY_RESTRICT A_tile = A + A_ofs + stride_a * row_a; - const hwy::bfloat16_t* HWY_RESTRICT B_tile = - B + B_ofs + stride_b * row_b_col_c; + const hwy::bfloat16_t* HWY_RESTRICT A_tile = A.ptr + A.Row(row_a); + const hwy::bfloat16_t* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c); // Loop over columns of A and columns of the transposed B, in steps of N. // Accumulates into the c## vectors. 
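+  // Each iteration loads one bf16 vector per tile row of A and of the
+  // transposed B and accumulates their products into the f32 c## vectors via
+  // ReorderWidenMulAccumulate. E.g. with 256-bit vectors a bf16 vector holds
+  // 16 lanes, so A.cols = 3072 takes 192 iterations.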
HWY_UNROLL(1) - for (size_t col_ab = 0; col_ab < cols_a; col_ab += N) { + for (size_t col_ab = 0; col_ab < A.cols; col_ab += N) { using V = hn::Vec; - const V b0 = hn::LoadU(d, B_tile + stride_b * 0 + col_ab); - const V b1 = hn::LoadU(d, B_tile + stride_b * 1 + col_ab); - const V b2 = hn::LoadU(d, B_tile + stride_b * 2 + col_ab); - const V b3 = hn::LoadU(d, B_tile + stride_b * 3 + col_ab); + const V b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab); + const V b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab); + const V b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab); + const V b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab); - const V a0 = hn::LoadU(d, A_tile + stride_a * 0 + col_ab); + const V a0 = hn::LoadU(d, A_tile + A.stride * 0 + col_ab); c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1); c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1); c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1); c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1); if constexpr (kNumRows == 1) continue; - const V a1 = hn::LoadU(d, A_tile + stride_a * 1 + col_ab); + const V a1 = hn::LoadU(d, A_tile + A.stride * 1 + col_ab); c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1); c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1); c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1); c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1); if constexpr (kNumRows == 2) continue; - const V a2 = hn::LoadU(d, A_tile + stride_a * 2 + col_ab); + const V a2 = hn::LoadU(d, A_tile + A.stride * 2 + col_ab); c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1); c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1); c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1); c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1); if constexpr (kNumRows == 3) continue; - const V a3 = hn::LoadU(d, A_tile + stride_a * 3 + col_ab); + const V a3 = hn::LoadU(d, A_tile + A.stride * 3 + col_ab); c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1); c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1); c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1); @@ -252,10 +268,10 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A, // Ensure sum1 was indeed unused. HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); - float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c; + float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c; StoreHorizontalSumsMaybeAdd( df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, - c32, c33, add, row_b_col_c, scale, C_tile, stride_c); + c32, c33, scale, add, row_b_col_c, C_tile, C.stride); } #endif // GEMMA_NATIVE_BF16 @@ -277,32 +293,20 @@ HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00, c3 = hn::MulAdd(a1, b31, c3); } -// Accumulates a single kNumRows (<= 4) x 4 tile of A x B into C. B is -// transposed, so we iterate over both A and B with consecutive vector loads. +// Streams a `(kNumRows, 4)` strip of `A` and the transposed `B`, then writes a +// finished tile of `C`. // General case: uses CompressTraits to load from A and B. 
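+// Decompress2 produces two f32 vectors per call, so the column loop below
+// advances by 2 * N; this is why MatMul_4x4 requires A.cols to be a multiple
+// of twice the f32 vector length.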
template -HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs, - const MatTB* HWY_RESTRICT B, const size_t B_ofs, - float* HWY_RESTRICT C, const float scale, - const float* HWY_RESTRICT add, - const size_t idx_tile, const size_t xtiles, - const size_t cols_a, const size_t stride_a, - const size_t stride_b, const size_t stride_c) { - constexpr size_t kRegRows = 4; - constexpr size_t kRegCols = 4; - static_assert(kNumRows <= kRegRows); - - using TraitsA = CompressTraits; - using TraitsB = CompressTraits; - - // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B. - const size_t row_a = idx_tile / xtiles * kRegRows; - const size_t row_b_col_c = idx_tile % xtiles * kRegCols; +HWY_INLINE void MatMulTile(const Mat& A, const Mat& B, + const size_t row_a, const size_t row_b_col_c, + const float scale, const float* HWY_RESTRICT add, + const Mat& C) { + using TraitsA = CompressTraits>; + using TraitsB = CompressTraits>; const hn::ScalableTag d32; const size_t N = hn::Lanes(d32); using V = hn::Vec; - V c00 = hn::Zero(d32); V c01 = hn::Zero(d32); V c02 = hn::Zero(d32); @@ -323,127 +327,118 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs, V c32 = hn::Zero(d32); V c33 = hn::Zero(d32); - const size_t A_tile_ofs = A_ofs + stride_a * row_a; - const size_t B_tile_ofs = B_ofs + stride_b * row_b_col_c; + const size_t A_ofs = A.Row(row_a); + const size_t B_ofs = B.Row(row_b_col_c); // Loop over columns of A and columns of the transposed B, in steps of 2*N // (since we are decoding consecutive bytes at each iteration). - // Accumulates into the c## vectors. + // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, + // col_ab) for B. Accumulates into the c## vectors. size_t col_ab = 0; HWY_UNROLL(1) - for (; col_ab <= cols_a - 2 * N; col_ab += 2 * N) { + for (; col_ab <= A.cols - 2 * N; col_ab += 2 * N) { V b00, b01; - TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 0 + col_ab, b00, b01); + TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01); V b10, b11; - TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 1 + col_ab, b10, b11); + TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11); V b20, b21; - TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 2 + col_ab, b20, b21); + TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21); V b30, b31; - TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 3 + col_ab, b30, b31); + TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31); V a00, a01; - TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 0 + col_ab, a00, a01); + TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a00, a01); UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01, c02, c03); if constexpr (kNumRows == 1) continue; V a10, a11; - TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 1 + col_ab, a10, a11); + TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a10, a11); UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11, c12, c13); if constexpr (kNumRows == 2) continue; V a20, a21; - TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 2 + col_ab, a20, a21); + TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a20, a21); UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21, c22, c23); if constexpr (kNumRows == 3) continue; V a30, a31; - TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 3 + col_ab, a30, a31); + 
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a30, a31); UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31, c32, c33); } - float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c; + float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c; StoreHorizontalSumsMaybeAdd( d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, - c32, c33, add, row_b_col_c, scale, C_tile, stride_c); + c32, c33, scale, add, row_b_col_c, C_tile, C.stride); } -// Tiled 4x4 GEMM: C = A * B * scale [+ add]. -// Computes the matrix product of A and B and stores this in C. Processes tiles -// of 4x4 vectors in parallel with a work-stealing thread pool. +// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`. // -// If kAdd is true, the row-vector `add` is added to each row of C, otherwise -// `add` is ignored and can be nullptr. -// A is a row-major matrix of size (batch_size, colsA_rowsB). -// B is passed transposed (column-major), so a matrix of size -// (colsBC, colsA_rowsB), representing a B of size (colsA_rowsB, colsBC). -// A_ofs and B_ofs are offsets into A and B, respectively; they remain separate -// from the pointers because some MatTA/B such as NuqStream do not support -// pointer arithmetic. -// C is a row-major matrix of size (batch_size, colsBC), with `C_stride` -// elements between rows, which is typically the same as `colsBC`. There is no -// `C_ofs` because callers can simply add it to `C`. -// The product is scaled by `scale` to support CompressedArray with scale != 1, -// the caller can pass the product of the scales of A and B. -// A scale for `add` is not supported, so make sure its scale is 1. -// Typically batch_size is 1..512, colsA_rowsB and colsBC are 3k or 24k. +// `A` is a row-major matrix of shape `(batch_size, A.cols)`. +// `B` is transposed; `B.cols`, which must match `A.cols`, denotes the number of +// rows in the original B, and `C.cols` the number of columns in the original B. +// +// `scale` allows expanding the smaller range of `SfpStream` to the original +// values. When `A` and/or `B` are from CompressedArray, `scale` should be the +// product of their `.scale()` values. +// +// If `kAdd` is true, the row-vector `add` is added to each row of `C`, +// otherwise `add` is ignored and can be nullptr. A scale for `add` is not +// supported, so make sure its scale is 1. +// +// `C` is a row-major matrix of size `(batch_size, C.cols)`. +// Writes 4x4 tiles of C in parallel using a work-stealing thread pool. +// Typically batch_size is 1..512, A.cols and C.cols are 3k or 24k. template -HWY_NOINLINE void MatMul_4x4(const size_t batch_size, - const MatTA* HWY_RESTRICT A, const size_t A_ofs, - const size_t colsA_rowsB, - const MatTB* HWY_RESTRICT B, const size_t B_ofs, - const size_t colsBC, const float scale, - float* HWY_RESTRICT C, const size_t C_stride, - const float* HWY_RESTRICT add, +HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat& A, + const Mat& B, const float scale, + const float* HWY_RESTRICT add, const Mat& C, hwy::ThreadPool& pool) { PROFILER_ZONE("Matmul"); + constexpr size_t kRegRows = 4; // if changing, also update the switch below. + constexpr size_t kRegCols = 4; + + HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty()); + HWY_DASSERT(A.cols == B.cols); + + // Use float instead of MatTA/MatTB because we decompress to float here. + const size_t N = hn::Lanes(hn::ScalableTag()); + (void)N; + HWY_DASSERT(A.cols % (N * 2) == 0); // For Decompress2. 
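+  // e.g. with 512-bit vectors, A.cols must be a multiple of 32 (2 * 16 f32
+  // lanes); the kN values in matmul_test.cc satisfy this.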
+ HWY_DASSERT(C.cols % kRegCols == 0); + // We currently write C directly, which touches more memory than fits in L3. // TODO: add another level of loops to finish L3-sized pieces of C at a time. - const hn::ScalableTag d; - // Use float instead of MatTA/MatTB because we decompress to float here. - const size_t Nf = hn::Lanes(hn::ScalableTag()); - (void)Nf; // For HWY_DASSERT - constexpr size_t kRegRows = 4; // if changing, also update the switch below. - constexpr size_t kRegCols = 4; // in vectors - - HWY_DASSERT(colsA_rowsB % (Nf * 2) == 0); // For Decompress2. - HWY_DASSERT(colsBC % kRegCols == 0); const size_t tilesY = hwy::DivCeil(batch_size, kRegRows); - const size_t tilesX = colsBC / kRegCols; - - const size_t strideA = colsA_rowsB; - const size_t strideB = colsA_rowsB; + const size_t tilesX = C.cols / kRegCols; pool.Run(0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR { + const size_t tx = idx_tile % tilesX; + const size_t ty = idx_tile / tilesX; + const size_t row_a = ty * kRegRows; + const size_t row_b_col_c = tx * kRegCols; // How many rows of C are left to compute. If more than 4, this // tile still only computes 4 rows. - const size_t num_rows = batch_size - idx_tile / tilesX * kRegRows; - HWY_ASSERT(num_rows > 0); + const size_t num_rows = batch_size - row_a; + HWY_DASSERT(num_rows != 0); switch (num_rows) { case 1: - GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, - idx_tile, tilesX, colsA_rowsB, strideA, - strideB, C_stride); + MatMulTile<1, kAdd>(A, B, row_a, row_b_col_c, scale, add, C); break; case 2: - GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, - idx_tile, tilesX, colsA_rowsB, strideA, - strideB, C_stride); + MatMulTile<2, kAdd>(A, B, row_a, row_b_col_c, scale, add, C); break; case 3: - GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, - idx_tile, tilesX, colsA_rowsB, strideA, - strideB, C_stride); + MatMulTile<3, kAdd>(A, B, row_a, row_b_col_c, scale, add, C); break; default: - GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, - idx_tile, tilesX, colsA_rowsB, strideA, - strideB, C_stride); + MatMulTile<4, kAdd>(A, B, row_a, row_b_col_c, scale, add, C); } }); } diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 4885347..61204af 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -14,6 +14,7 @@ // limitations under the License. #ifndef HWY_DISABLED_TARGETS +// Exclude HWY_SCALAR due to 2x bf16 -> f32. #define HWY_DISABLED_TARGETS HWY_SCALAR #endif @@ -48,47 +49,10 @@ namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -template -CompressedArray GenerateMat(size_t offset, - hwy::ThreadPool& pool) { - gcpp::CompressWorkingSet ws; - CompressedArray mat; - std::array content; - const float scale = 1.0f / kInner; - pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) { - for (size_t j = 0; j < kInner; j++) { - content[i * kInner + j] = - static_cast((i * kInner + j + offset) * scale); - } - }); - - Compress(content, ws, mat, pool); - mat.set_scale(1.9f); // Arbitrary value, different from 1. - return mat; -} - -template -CompressedArray GenerateZeroMat(hwy::ThreadPool& pool) { - gcpp::CompressWorkingSet ws; - CompressedArray mat; - std::array content; - - pool.Run(0, kOuter, [&](const size_t i, size_t thread) { - hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0])); - }); - - Compress(content, ws, mat, pool); - mat.set_scale(1.2f); // Arbitrary value, different from 1. 
- return mat; -} - template std::unique_ptr> GenerateMatHeap( size_t offset, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - std::unique_ptr> mat = - std::unique_ptr>( - new CompressedArray); hwy::AlignedFreeUniquePtr content = hwy::AllocateAligned(kOuter * kInner); const float scale = 1.875f / (kInner * kOuter + offset); @@ -99,6 +63,8 @@ std::unique_ptr> GenerateMatHeap( } }); + std::unique_ptr> mat = + std::make_unique>(); Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0, pool); mat->set_scale(0.6f); // Arbitrary value, different from 1. @@ -109,9 +75,6 @@ template std::unique_ptr> GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - std::unique_ptr> mat = - std::unique_ptr>( - new CompressedArray); hwy::AlignedFreeUniquePtr content = hwy::AllocateAligned(kOuter * kInner); const float scale = 1.875f / (kInner * kOuter + offset); @@ -122,6 +85,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) { } }); + std::unique_ptr> mat = + std::make_unique>(); Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0, pool); // Arbitrary value, different from 1, must match GenerateMatHeap. @@ -133,9 +98,6 @@ template std::unique_ptr> GenerateZeroMatHeap( hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - std::unique_ptr> mat = - std::unique_ptr>( - new CompressedArray); hwy::AlignedFreeUniquePtr content = hwy::AllocateAligned(kOuter * kInner); @@ -143,22 +105,14 @@ std::unique_ptr> GenerateZeroMatHeap( hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0])); }); + std::unique_ptr> mat = + std::make_unique>(); Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0, pool); mat->set_scale(1.2f); // Arbitrary value, different from 1. return mat; } -template -hwy::AlignedFreeUniquePtr GenerateVec(size_t offset) { - hwy::AlignedFreeUniquePtr vec = hwy::AllocateAligned(length); - HWY_ASSERT(vec); - for (size_t idx = 0; idx < length; idx++) { - vec[idx] = static_cast(idx + offset); - } - return vec; -} - // A simple matrix multiplication. No optimization / tiling. template hwy::AlignedFreeUniquePtr SimpleMatMul( @@ -179,27 +133,6 @@ hwy::AlignedFreeUniquePtr SimpleMatMul( return out; } -template -hwy::AlignedFreeUniquePtr SimpleMatVecAdd( - const CompressedArray& mat, - const hwy::AlignedFreeUniquePtr& vec, - const hwy::AlignedFreeUniquePtr& add) { - hwy::AlignedFreeUniquePtr uncompressed_mat = - hwy::AllocateAligned(kOuter * kInner); - hwy::AlignedFreeUniquePtr out = hwy::AllocateAligned(kOuter); - HWY_ASSERT(uncompressed_mat && out); - Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner); - MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner); - for (size_t idx_row = 0; idx_row < kOuter; idx_row++) { - out[idx_row] = add[idx_row]; - for (size_t idx_col = 0; idx_col < kInner; idx_col++) { - out[idx_row] += - uncompressed_mat[kInner * idx_row + idx_col] * vec[idx_col]; - } - } - return out; -} - template void AssertClose(const MatT* HWY_RESTRICT expected, const MatT* HWY_RESTRICT actual, size_t num) { @@ -233,8 +166,7 @@ void AssertClose(const hwy::AlignedFreeUniquePtr& a, } } -// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on -// ops_test across instruction sets. +// Largely unoptimized; reordered innermost loops nets ~5-10X speedup. 
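+// Computes the same C = A * B * scale [+ add] as MatMul_4x4, but without
+// tiling or threading; TestMatMul uses it as the reference for AssertClose.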
template HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, @@ -271,92 +203,167 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, MatMulSlowBatch(batch_size, a, b.get(), scale, add, out); } -template -void TestTiledBatchMatMul() { - fprintf(stderr, - "TestTiledBatchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", - kM, kN, kK, kAdd, typeid(MatTA).name(), typeid(MatTB).name()); - hwy::ThreadPool pool(3); - std::unique_ptr> a = - GenerateMatHeap(0, pool); - std::unique_ptr> b = - GenerateMatHeap(0, pool); - std::unique_ptr> add = - GenerateMatHeap(0, pool); - add->set_scale(1.0f); - std::unique_ptr> c_slow = - GenerateZeroMatHeap(pool); - const float scale = a->scale() * b->scale(); - - const double start_slow = hwy::platform::Now(); - MatMulSlowBatch(kM, a->data(), b->data(), scale, - kAdd ? add->data() : nullptr, c_slow->data()); - const double slow_matmul_seconds = hwy::platform::Now() - start_slow; - fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds); - - hwy::AlignedFreeUniquePtr c = hwy::AllocateAligned(kM * kK); - std::unique_ptr> b_trans = - GenerateTransposeMatHeap(0, pool); - - const double start_tiled = hwy::platform::Now(); - EXPECT_EQ(scale, a->scale() * b_trans->scale()); - MatMul_4x4(kM, a->data(), 0, kN, b_trans->data(), 0, kK, scale, c.get(), - kK, add->data(), pool); - const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled; - fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds); - - AssertClose(c_slow->data(), c.get(), kM * kK); +void PrintSpeed(const char* algo, size_t M, size_t N, size_t K, + double elapsed) { + // * 2 because of FMA. + fprintf(stderr, "%s: %f seconds, %f GFLOPS.\n", algo, elapsed, + 2E-9 * M * N * K / elapsed); } -void TestAllTiledBatchMatMul() { +template +void TestMatMul(hwy::ThreadPool& pool) { + using TraitsA = CompressTraits; + using TraitsB = CompressTraits; + fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", kM, + kN, kK, kAdd, TraitsA::Name(), TraitsB::Name()); + + std::unique_ptr> a = + GenerateMatHeap(0, pool); + std::unique_ptr> b_trans = + GenerateTransposeMatHeap(0, pool); + hwy::AlignedFreeUniquePtr c = hwy::AllocateAligned(kM * kK); + + const float scale = a->scale() * b_trans->scale(); + std::unique_ptr> add; + if (kAdd) { + add = GenerateMatHeap(0, pool); + add->set_scale(1.0f); + } + + std::unique_ptr> c_slow; + const bool compare_slow = kN < 2048; + if (compare_slow) { + std::unique_ptr> b = + GenerateMatHeap(0, pool); + HWY_ASSERT_EQ(scale, a->scale() * b->scale()); + c_slow = GenerateZeroMatHeap(pool); + const double start_slow = hwy::platform::Now(); + MatMulSlowBatch(kM, a->data(), b->data(), scale, + kAdd ? add->data() : nullptr, c_slow->data()); + PrintSpeed("MatMulSlowBatch", kM, kN, kK, + hwy::platform::Now() - start_slow); + } + + double min_elapsed = hwy::HighestValue(); + for (int rep = 0; rep < (compare_slow ? 1 : 3); ++rep) { + const double start_tiled = hwy::platform::Now(); + MatMul_4x4(kM, MakeMat(a->data(), kN), MakeMat(b_trans->data(), kN), + scale, kAdd ? add->data_scale1() : nullptr, + MakeMat(c.get(), kK), pool); + min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled); + } + PrintSpeed("MatMul_4x4", kM, kN, kK, min_elapsed); + + if (compare_slow) { + AssertClose(c_slow->data(), c.get(), kM * kK); + } +} + +void TestAllMatMul() { + // Skip EMU128 (10x slower than SSE4 for SFP) and older x86. 
+ if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 || + HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) { + return; + } + + hwy::ThreadPool pool(4); using BF16 = hwy::bfloat16_t; using F32 = float; using SFP = SfpStream; - // medium-sized square test - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(); - - // minimal non-square test. kK must be at least 2 vectors. - TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, F32>(); - TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, BF16>(); - TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(); - TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(); - TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(); - TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(); - TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32>(); - TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16>(); - TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(); - TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(); - TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(); - TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(); - TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32>(); - TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16>(); - TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(); - TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(); - TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(); - TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(); // large-scale test - // TODO(philculliton): investigate rounding issues with large matrices. - // Causes test timeout. - // TestTiledBatchMatMul<512, 24576, 3072, float>(); + TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(pool); + + // medium-sized square test + TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(pool); + TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(pool); + TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(pool); + TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(pool); + TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(pool); + TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(pool); + + // minimal non-square test. kK must be at least 2 vectors. 
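+  // The kM values below also exercise the 1..4-row remainder tiles handled
+  // by the switch in MatMul_4x4.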
+ TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(pool); + TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(pool); + TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(pool); + TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(pool); + TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(pool); + TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(pool); + TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(pool); + TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(pool); + TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(pool); + TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(pool); + TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(pool); + TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(pool); + TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(pool); + TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(pool); + TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(pool); + TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(pool); + TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(pool); + TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(pool); + TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(pool); + TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(pool); + TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(pool); + TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(pool); + TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(pool); + TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(pool); + TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(pool); + TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(pool); + TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(pool); + TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(pool); + TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(pool); + TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(pool); +} + +template +hwy::AlignedFreeUniquePtr SimpleMatVecAdd( + const CompressedArray& mat, + const hwy::AlignedFreeUniquePtr& vec, + const hwy::AlignedFreeUniquePtr& add) { + hwy::AlignedFreeUniquePtr uncompressed_mat = + hwy::AllocateAligned(kOuter * kInner); + hwy::AlignedFreeUniquePtr out = hwy::AllocateAligned(kOuter); + HWY_ASSERT(uncompressed_mat && out); + Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner); + MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner); + for (size_t idx_row = 0; idx_row < kOuter; idx_row++) { + out[idx_row] = add[idx_row]; + for (size_t idx_col = 0; idx_col < kInner; idx_col++) { + out[idx_row] += + uncompressed_mat[kInner * idx_row + idx_col] * vec[idx_col]; + } + } + return out; +} + +template +CompressedArray GenerateMat(size_t offset, + hwy::ThreadPool& pool) { + gcpp::CompressWorkingSet ws; + CompressedArray mat; + std::array content; + const float scale = 1.0f / kInner; + pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) { + for (size_t j = 0; j < kInner; j++) { + content[i * kInner + j] = + static_cast((i * kInner + j + offset) * scale); + } + }); + + Compress(content, ws, mat, pool); + mat.set_scale(1.9f); // Arbitrary value, different from 1. 
+  return mat;
+}
+
+template <size_t length>
+hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
+  hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
+  HWY_ASSERT(vec);
+  for (size_t idx = 0; idx < length; idx++) {
+    vec[idx] = static_cast<float>(idx + offset);
+  }
+  return vec;
+}
 
 void TestMatVecAdd() {
@@ -441,7 +448,7 @@ HWY_AFTER_NAMESPACE();
 namespace gcpp {
 
 HWY_BEFORE_TEST(MatmulTest);
-HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllTiledBatchMatMul);
+HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllMatMul);
 HWY_EXPORT_AND_TEST_P(MatmulTest, TestMatVecAdd);
 HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoMatVecAdd);
 HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoOfsMatVecAddLoop);