Added MatMul_4x4_Batch, which is MatMul_4x4 with the first template argument moved to the first function argument, so the batch size (the number of rows of A) can vary at run time.

PiperOrigin-RevId: 643017973
Ray Smith 2024-06-13 09:05:01 -07:00 committed by Copybara-Service
parent 1b40619864
commit ea525da967
2 changed files with 202 additions and 24 deletions
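
For orientation, a minimal calling sketch of the new entry point. The extents, buffer names, and pool size below are illustrative only (not taken from this commit), and the call assumes the same per-target scope in which MatMul_4x4 already lives:

// Sketch: C (batch_size x kColsBC) = A (batch_size x kColsA_RowsB) * B,
// with B supplied transposed (column-major), as MatMul_4x4_Batch expects.
constexpr size_t kColsA_RowsB = 4096;  // example extent: columns of A == rows of B
constexpr size_t kColsBC = 2048;       // example extent: columns of B and C
hwy::ThreadPool pool(4);               // example thread count
const size_t batch_size = 3;           // rows of A and C, known only at run time
// a, b_trans, c are hypothetical float buffers of batch_size * kColsA_RowsB,
// kColsBC * kColsA_RowsB, and batch_size * kColsBC elements, respectively.
MatMul_4x4_Batch<kColsA_RowsB, kColsBC>(batch_size, a, b_trans, c, pool);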


@@ -82,7 +82,7 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
 }
 
 // Shared between f32 and bf16, which also accumulates into f32 vectors.
-template <class DF, class VF = hn::Vec<DF>>
+template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
 HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
                                     VF c10, VF c11, VF c12, VF c13,  //
                                     VF c20, VF c21, VF c22, VF c23,  //
@@ -97,16 +97,19 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
   tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01);
   tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02);
   tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03);
+  if (kNumRows == 1) return;
 
   tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10);
   tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11);
   tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12);
   tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13);
+  if (kNumRows == 2) return;
 
   tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20);
   tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21);
   tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22);
   tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23);
+  if (kNumRows == 3) return;
 
   tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30);
   tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31);
@@ -114,10 +117,10 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
   tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
 }
 
-// Accumulates a single 4x4 tile of A x B into C. B is transposed, so we can
-// iterate over both A and B with consecutive vector loads.
+// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
+// can iterate over both A and B with consecutive vector loads. kNumRows<=4.
 // Shared between parallelized and sequential (loop) callers.
-template <size_t kColsA_RowsB, typename MatT, HWY_IF_F32(MatT)>
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatT, HWY_IF_F32(MatT)>
 HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
                               const size_t idx_tile, const size_t xtiles,
@@ -129,6 +132,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
   // threads write in units of 4*N elements, which is at least one cache line.
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -175,18 +179,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -196,8 +203,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-                      c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+      c30, c31, c32, c33, tile_c, stride_c);
 }
 
 #undef GEMMA_NATIVE_BF16
@@ -209,7 +217,8 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
 #endif
 
 // As above, for MatT=bf16
-template <size_t kColsA_RowsB, typename MatT, HWY_IF_BF16(MatT)>
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatT,
+          HWY_IF_BF16(MatT)>
 HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const MatT* HWY_RESTRICT B, float* HWY_RESTRICT C,
                               const size_t idx_tile, const size_t xtiles,
@@ -217,6 +226,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -277,18 +287,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
     c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
     c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
+    if (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
     c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
     c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
     c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
+    if (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
     c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
     c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
     c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
+    if (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
@@ -311,6 +324,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const VF a1 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab));
@@ -318,6 +332,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const VF a2 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab));
@@ -325,6 +340,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const VF a3 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab));
@@ -341,12 +357,14 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
 #endif
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-                      c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
+      c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
 // Same as above, but with mixed Mat types: (f32, sfp)).
-template <size_t kColsA_RowsB, typename MatTA, HWY_IF_F32(MatTA)>
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
+          HWY_IF_F32(MatTA)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const SfpStream* HWY_RESTRICT B,
                               float* HWY_RESTRICT C, const size_t idx_tile,
@@ -354,6 +372,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -407,18 +426,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -428,12 +450,14 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-                      c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
+      c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
 // Same as above, but with mixed Mat types: (bf16, sfp)).
-template <size_t kColsA_RowsB, typename MatTA, HWY_IF_BF16(MatTA)>
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
+          HWY_IF_BF16(MatTA)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const SfpStream* HWY_RESTRICT B,
                               float* HWY_RESTRICT C, const size_t idx_tile,
@@ -441,6 +465,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -500,6 +525,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const V a1 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
@@ -507,6 +533,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const V a2 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
@@ -514,6 +541,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const V a3 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
@@ -524,12 +552,14 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
-                      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
+      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
 // Same as above, but with mixed Mat types: (f32, bf16).
-template <size_t kColsA_RowsB, typename MatTA, HWY_IF_F32(MatTA),
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
+          HWY_IF_F32(MatTA),
           typename MatTB, HWY_IF_BF16(MatTB)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
@@ -538,6 +568,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
 
   // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
   const size_t row_a = idx_tile / xtiles * kRegRows;
@@ -596,18 +627,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
 
     const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
 
     const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
 
     const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -617,8 +651,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
-                      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
+  StoreHorizontalSums<kNumRows>(
+      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
+      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
 // Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
@@ -646,8 +681,8 @@ void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
   HWY_UNROLL(1)
   for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
-    GEMM_4x4_Tile<kColsA_RowsB>(A, B, C, idx_tile, kTilesX, kStrideA, kStrideB,
-                                kStrideC);
+    GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
+        A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
   }
 }
@@ -680,8 +715,59 @@ HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
   pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
     // Computes the finished product of one 4x4N tile and writes to C.
-    GEMM_4x4_Tile<kColsA_RowsB>(A, B, C, idx_tile, kTilesX, kStrideA, kStrideB,
-                                kStrideC);
+    GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
+        A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+  });
+}
+
+// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
+// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
+// NOTE that batch_size is the number of rows of A and C.
+// This function processes tiles in parallel with a work-stealing thread pool.
+template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
+          typename MatTB, typename OutT>
+HWY_NOINLINE void MatMul_4x4_Batch(
+    size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
+    OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+  // Process reg-sized tiles of C in parallel. We currently write C directly,
+  // which touches more memory than fits in L3. TODO: add another level of loops
+  // so that we finish one L3-sized piece of C at a time.
+  const hn::ScalableTag<MatTA> d;
+  const size_t N = Lanes(d);
+  constexpr size_t kRegRows = 4;
+  constexpr size_t kRegCols = 4;  // in vectors
+  static_assert(kColsBC % kRegCols == 0);
+  HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
+  const size_t kTilesY = (batch_size + kRegRows - 1) / kRegRows;
+  const size_t kTilesX = kColsBC / kRegCols;
+  const size_t kTiles = kTilesX * kTilesY;
+  constexpr size_t kStrideA = kColsA_RowsB;
+  constexpr size_t kStrideB = kColsA_RowsB;
+  constexpr size_t kStrideC = kColsBC;
+  pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
+    // Computes the finished product of one 4x4N tile and writes to C.
+    const size_t num_rows = batch_size - idx_tile / kTilesX * kRegRows;
+    HWY_ASSERT(num_rows > 0);
+    switch (num_rows) {
+      case 1:
+        GEMM_4x4_Tile<1, kColsA_RowsB>(
+            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        break;
+      case 2:
+        GEMM_4x4_Tile<2, kColsA_RowsB>(
+            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        break;
+      case 3:
+        GEMM_4x4_Tile<3, kColsA_RowsB>(
+            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        break;
+      default:
+        GEMM_4x4_Tile<4, kColsA_RowsB>(
+            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+    }
   });
 }
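
To make the tail-tile dispatch above concrete (the numbers are an illustrative case, not taken from the commit): with batch_size = 35, kTilesY = (35 + 3) / 4 = 9, so the last row of tiles starts at A row 32 and num_rows = 35 - 8 * 4 = 3. Those tiles take the `case 3` branch; inside GEMM_4x4_Tile<3, ...> the `if (kNumRows == 3) continue;` checks skip loading the fourth row of A, and StoreHorizontalSums<3> returns before writing a fourth row of C.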
@@ -714,6 +800,35 @@ HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
   MatMulSlow<kM, kN, kK>(a, b.get(), out);
 }
 
+// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
+// ops_test across instruction sets.
+template <size_t kN, size_t kK, typename MatTA, typename MatTB>
+HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
+                                const MatTB* HWY_RESTRICT b,
+                                float* HWY_RESTRICT out) {
+  for (size_t i = 0; i < batch_size; ++i) {
+    for (size_t k = 0; k < kN; ++k) {
+      for (size_t j = 0; j < kK; ++j) {
+        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
+        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
+        out[i * kK + j] += a1 * b1;
+      }
+    }
+  }
+}
+
+template <size_t kN, size_t kK, typename MatTA>
+HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
+                                const SfpStream* HWY_RESTRICT b_sfp_stream,
+                                float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> d;
+  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
+  CompressTraits<SfpStream>::Decompress(d,
+                                        /*in_capacity=*/0, b_sfp_stream, 0,
+                                        b.get(), kK * kN);
+  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), out);
+}
+
 HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                              const size_t size, float* HWY_RESTRICT out) {
   const hn::ScalableTag<float> df;


@@ -538,8 +538,12 @@ void TestTiledMatMul() {
       GenerateMatHeap<MatTB, kN, kK>(0, pool);
   std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
       GenerateZeroMatHeap<float, kM, kK>(pool);
+  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow_batch =
+      GenerateZeroMatHeap<float, kM, kK>(pool);
   MatMulSlow<kM, kN, kK>(a->data(), b->data(), c_slow->data());
+  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow_batch->data());
+  AssertClose(c_slow->data(), c_slow_batch->data(), kM * kK);
 
   hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
   std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
@@ -565,8 +569,66 @@ void TestAllTiledMatMul() {
   TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
 
   // large-scale test
-  // TODO(philculliton): investigate rounding issues with large matrices
-  TestTiledMatMul<512, 24576, 3072, float>();
+  // TODO(philculliton): investigate rounding issues with large matrices.
+  // Causes test timeout.
+  // TestTiledMatMul<512, 24576, 3072, float>();
+}
+
+template <size_t kM, size_t kN, size_t kK, typename MatTA,
+          typename MatTB = MatTA>
+void TestTiledBatchMatMul() {
+  hwy::ThreadPool pool(3);
+  std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
+      GenerateMatHeap<MatTA, kM, kN>(0, pool);
+  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
+      GenerateMatHeap<MatTB, kN, kK>(0, pool);
+  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
+      GenerateZeroMatHeap<float, kM, kK>(pool);
+  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow->data());
+
+  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
+  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
+      GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
+  MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
+
+  AssertClose(c_slow->data(), c.get(), kM * kK);
+}
+
+void TestAllTiledBatchMatMul() {
+  // medium-sized square test
+  TestTiledBatchMatMul<512, 512, 512, float>();
+  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<512, 512, 512, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<512, 512, 512, float, SfpStream>();
+  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
+
+  // minimal non-square test
+  TestTiledBatchMatMul<35, 128, 4, float>();
+  TestTiledBatchMatMul<34, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<33, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<31, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<29, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<4, 128, 4, float>();
+  TestTiledBatchMatMul<4, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<4, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<3, 128, 4, float>();
+  TestTiledBatchMatMul<3, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<2, 128, 4, float>();
+  TestTiledBatchMatMul<2, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<2, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<1, 128, 4, float>();
+  TestTiledBatchMatMul<1, 128, 4, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
+  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
 }
 
 void TestMatVecAdd() {
@@ -675,6 +737,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);