diff --git a/gemma/ops.h b/gemma/ops.h
index 7b29b9f..b29d8e9 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -82,7 +82,7 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
 }
 
 // Shared between f32 and bf16, which also accumulates into f32 vectors.
-template <class DF, class VF = hn::Vec<DF>>
+template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
 HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
                                     VF c10, VF c11, VF c12, VF c13,  //
                                     VF c20, VF c21, VF c22, VF c23,  //
@@ -97,16 +97,19 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
   tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01);
   tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02);
   tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03);
+  if (kNumRows == 1) return;
 
   tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10);
   tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11);
   tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12);
   tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13);
+  if (kNumRows == 2) return;
 
   tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20);
   tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21);
   tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22);
   tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23);
+  if (kNumRows == 3) return;
 
   tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30);
   tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31);
@@ -114,10 +117,10 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
   tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
 }
 
-// Accumulates a single 4x4 tile of A x B into C. B is transposed, so we can
-// iterate over both A and B with consecutive vector loads.
+// Accumulates a single kNumRows x 4 tile of A x B into C. B is transposed, so
+// we can iterate over both A and B with consecutive vector loads. kNumRows<=4.
 // Shared between parallelized and sequential (loop) callers.
-template <size_t kColsA_RowsB, typename MatT>
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatT>
 HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
                               const size_t idx_tile, const size_t xtiles,
@@ -129,6 +132,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
   // threads write in units of 4*N elements, which is at least one cache line.
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -175,18 +179,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -196,8 +203,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-                      c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+      c30, c31, c32, c33, tile_c, stride_c);
 }
 
 #undef GEMMA_NATIVE_BF16
@@ -209,7 +217,8 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
 #endif
 
 // As above, for MatT=bf16
-template <size_t kColsA_RowsB, typename MatT, HWY_IF_BF16(MatT)>
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatT,
+          HWY_IF_BF16(MatT)>
 HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const MatT* HWY_RESTRICT B, float* HWY_RESTRICT C,
                               const size_t idx_tile, const size_t xtiles,
@@ -217,6 +226,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -277,18 +287,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
     c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
     c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
+    if (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
     c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
     c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
     c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
+    if (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
     c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
     c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
     c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
+    if (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
@@ -311,6 +324,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const VF a1 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab));
@@ -318,6 +332,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const VF a2 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab));
@@ -325,6 +340,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const VF a3 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab));
@@ -341,12 +357,14 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
 #endif
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-                      c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+      c30, c31, c32, c33, tile_c, stride_c);
 }
 
 // Same as above, but with mixed Mat types: (f32, sfp).
-template <size_t kColsA_RowsB, typename MatTA, HWY_IF_F32(MatTA)>
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
+          HWY_IF_F32(MatTA)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const SfpStream* HWY_RESTRICT B,
                               float* HWY_RESTRICT C, const size_t idx_tile,
@@ -354,6 +372,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -407,18 +426,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -428,12 +450,14 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-                      c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+      c30, c31, c32, c33, tile_c, stride_c);
 }
 
 // Same as above, but with mixed Mat types: (bf16, sfp).
-template <size_t kColsA_RowsB, typename MatTA, HWY_IF_BF16(MatTA)>
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
+          HWY_IF_BF16(MatTA)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const SfpStream* HWY_RESTRICT B,
                               float* HWY_RESTRICT C, const size_t idx_tile,
@@ -441,6 +465,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -500,6 +525,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const V a1 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
@@ -507,6 +533,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const V a2 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
@@ -514,6 +541,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const V a3 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
@@ -524,12 +552,14 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
-                      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
+      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
 // Same as above, but with mixed Mat types: (f32, bf16).
-template <size_t kColsA_RowsB, typename MatTA, typename MatTB,
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
+          typename MatTB,
           HWY_IF_F32(MatTA), HWY_IF_BF16(MatTB)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
@@ -538,6 +568,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -596,18 +627,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -617,8 +651,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
-                      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
+      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
 // Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
@@ -646,8 +681,8 @@ void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
 
   HWY_UNROLL(1)
   for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
-    GEMM_4x4_Tile<kColsA_RowsB>(A, B, C, idx_tile, kTilesX, kStrideA, kStrideB,
-                                kStrideC);
+    GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
+        A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
   }
 }
 
@@ -680,8 +715,59 @@ HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
 
   pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
     // Computes the finished product of one 4x4N tile and writes to C.
-    GEMM_4x4_Tile<kColsA_RowsB>(A, B, C, idx_tile, kTilesX, kStrideA, kStrideB,
-                                kStrideC);
+    GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
+        A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+  });
+}
+
+// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
+// and kColsBC is 24k or 3k. Note: B is transposed (column-major) and
+// batch_size is the number of rows of A and C.
+// This function processes tiles in parallel with a work-stealing thread pool.
+template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA, typename MatTB,
+          typename OutT>
+HWY_NOINLINE void MatMul_4x4_Batch(
+    size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
+    OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+  // Process reg-sized tiles of C in parallel. We currently write C directly,
+  // which touches more memory than fits in L3. TODO: add another level of loops
+  // so that we finish one L3-sized piece of C at a time.
+  const hn::ScalableTag<MatTA> d;
+  const size_t N = Lanes(d);
+  constexpr size_t kRegRows = 4;
+  constexpr size_t kRegCols = 4;  // in vectors
+
+  static_assert(kColsBC % kRegCols == 0);
+  HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
+  const size_t kTilesY = (batch_size + kRegRows - 1) / kRegRows;
+  const size_t kTilesX = kColsBC / kRegCols;
+  const size_t kTiles = kTilesX * kTilesY;
+
+  constexpr size_t kStrideA = kColsA_RowsB;
+  constexpr size_t kStrideB = kColsA_RowsB;
+  constexpr size_t kStrideC = kColsBC;
+
+  pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
+    // Computes the finished product of one 4x4N tile and writes to C.
+    const size_t num_rows = batch_size - idx_tile / kTilesX * kRegRows;
+    HWY_ASSERT(num_rows > 0);
+    switch (num_rows) {
+      case 1:
+        GEMM_4x4_Tile<1, kColsA_RowsB>(
+            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        break;
+      case 2:
+        GEMM_4x4_Tile<2, kColsA_RowsB>(
+            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        break;
+      case 3:
+        GEMM_4x4_Tile<3, kColsA_RowsB>(
+            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        break;
+      default:
+        GEMM_4x4_Tile<4, kColsA_RowsB>(
+            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+    }
   });
 }
 
@@ -714,6 +800,35 @@ HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
   MatMulSlow<kM, kN, kK>(a, b.get(), out);
 }
 
+// Largely unoptimized; reordering the innermost loops nets a ~5-10x speedup in
+// ops_test across instruction sets.
+template <size_t kN, size_t kK, typename MatTA, typename MatTB>
+HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
+                                const MatTB* HWY_RESTRICT b,
+                                float* HWY_RESTRICT out) {
+  for (size_t i = 0; i < batch_size; ++i) {
+    for (size_t k = 0; k < kN; ++k) {
+      for (size_t j = 0; j < kK; ++j) {
+        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
+        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
+        out[i * kK + j] += a1 * b1;
+      }
+    }
+  }
+}
+
+template <size_t kN, size_t kK, typename MatTA>
+HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
+                                const SfpStream* HWY_RESTRICT b_sfp_stream,
+                                float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> d;
+  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
+  CompressTraits<SfpStream>::Decompress(d,
+                                        /*in_capacity=*/0, b_sfp_stream, 0,
+                                        b.get(), kK * kN);
+  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), out);
+}
+
 HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                              const size_t size, float* HWY_RESTRICT out) {
   const hn::ScalableTag<float> df;
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index 9e0d4ec..f027364 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -538,8 +538,12 @@ void TestTiledMatMul() {
       GenerateMatHeap<MatTB, kN, kK>(0, pool);
   std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
       GenerateZeroMatHeap<float, kM, kK>(pool);
+  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow_batch =
+      GenerateZeroMatHeap<float, kM, kK>(pool);
 
   MatMulSlow<kM, kN, kK>(a->data(), b->data(), c_slow->data());
+  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow_batch->data());
+  AssertClose(c_slow->data(), c_slow_batch->data(), kM * kK);
 
   hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
   std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
@@ -565,8 +569,66 @@ void TestAllTiledMatMul() {
   TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
 
   // large-scale test
-  // TODO(philculliton): investigate rounding issues with large matrices
-  TestTiledMatMul<512, 24576, 3072, float>();
+  // TODO(philculliton): investigate rounding issues with large matrices.
+  // Causes a test timeout.
+  // TestTiledMatMul<512, 24576, 3072, float>();
+}
+
+template <size_t kM, size_t kN, size_t kK, typename MatTA,
+          typename MatTB = MatTA>
+void TestTiledBatchMatMul() {
+  hwy::ThreadPool pool(3);
+  std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
+      GenerateMatHeap<MatTA, kM, kN>(0, pool);
+  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
+      GenerateMatHeap<MatTB, kN, kK>(0, pool);
+  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
+      GenerateZeroMatHeap<float, kM, kK>(pool);
+
+  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow->data());
+
+  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
+  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
+      GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
+  MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
+
+  AssertClose(c_slow->data(), c.get(), kM * kK);
+}
+
+void TestAllTiledBatchMatMul() {
+  // medium-sized square test
+  TestTiledBatchMatMul<512, 512, 512, float>();
+  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<512, 512, 512, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<512, 512, 512, float, SfpStream>();
+  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
+
+  // minimal non-square test
+  TestTiledBatchMatMul<35, 128, 4, float>();
+  TestTiledBatchMatMul<34, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<33, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<31, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<29, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<4, 128, 4, float>();
+  TestTiledBatchMatMul<4, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<4, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<3, 128, 4, float>();
+  TestTiledBatchMatMul<3, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<2, 128, 4, float>();
+  TestTiledBatchMatMul<2, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<2, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<1, 128, 4, float>();
+  TestTiledBatchMatMul<1, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
 }
 
 void TestMatVecAdd() {
@@ -675,6 +737,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
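Note (illustration, not part of the patch): the standalone sketch below reproduces the tile partitioning that MatMul_4x4_Batch performs. It assumes the row mapping shown in GEMM_4x4_Tile (row_a = idx_tile / xtiles * kRegRows) and a matching column mapping via idx_tile % kTilesX, which the hunks above do not show; batch_size = 6 and kColsBC = 8 are made-up toy values. It prints which rows and columns of C each tile index covers and shows how the last tile row falls through to the new 1/2/3-row kernels when batch_size is not a multiple of 4. It needs only a C++ compiler and does not depend on gemma.cpp or Highway.

// Sketch only: mirrors the kTilesX/kTilesY/num_rows arithmetic of
// MatMul_4x4_Batch for hypothetical batch_size = 6, kColsBC = 8.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t batch_size = 6, kColsBC = 8;
  const std::size_t kRegRows = 4, kRegCols = 4;
  const std::size_t kTilesY = (batch_size + kRegRows - 1) / kRegRows;  // 2
  const std::size_t kTilesX = kColsBC / kRegCols;                      // 2
  for (std::size_t idx_tile = 0; idx_tile < kTilesX * kTilesY; ++idx_tile) {
    const std::size_t row_c = idx_tile / kTilesX * kRegRows;  // top row of tile
    const std::size_t col_c = idx_tile % kTilesX * kRegCols;  // left column
    // Remaining rows of A/C; the switch in MatMul_4x4_Batch sends values >= 4
    // to the full 4-row kernel via its default case.
    std::size_t num_rows = batch_size - row_c;
    if (num_rows > kRegRows) num_rows = kRegRows;
    std::printf("tile %zu -> C rows [%zu, %zu), cols [%zu, %zu)\n", idx_tile,
                row_c, row_c + num_rows, col_c, col_c + kRegCols);
  }
  return 0;
}

For the toy sizes this prints two full 4-row tiles for rows [0, 4) and two 2-row tiles for rows [4, 6), which is exactly the case the 1/2/3-row specializations of GEMM_4x4_Tile and StoreHorizontalSums exist to handle.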