mirror of https://github.com/google/gemma.cpp.git
Added MatMul_4x4_Batch which is MatMul_4x4, but with the first template arg moved to the first function arg, so the batch size (num A rows) can be variable at run-time.
PiperOrigin-RevId: 643017973
This commit is contained in:
parent
1b40619864
commit
ea525da967
159
gemma/ops.h
159
gemma/ops.h
|
|
@ -82,7 +82,7 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shared between f32 and bf16, which also accumulates into f32 vectors.
|
// Shared between f32 and bf16, which also accumulates into f32 vectors.
|
||||||
template <class DF, class VF = hn::Vec<DF>>
|
template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
|
||||||
HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
|
HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
|
||||||
VF c10, VF c11, VF c12, VF c13, //
|
VF c10, VF c11, VF c12, VF c13, //
|
||||||
VF c20, VF c21, VF c22, VF c23, //
|
VF c20, VF c21, VF c22, VF c23, //
|
||||||
|
|
@ -97,16 +97,19 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
|
||||||
tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01);
|
tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01);
|
||||||
tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02);
|
tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02);
|
||||||
tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03);
|
tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03);
|
||||||
|
if (kNumRows == 1) return;
|
||||||
|
|
||||||
tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10);
|
tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10);
|
||||||
tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11);
|
tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11);
|
||||||
tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12);
|
tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12);
|
||||||
tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13);
|
tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13);
|
||||||
|
if (kNumRows == 2) return;
|
||||||
|
|
||||||
tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20);
|
tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20);
|
||||||
tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21);
|
tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21);
|
||||||
tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22);
|
tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22);
|
||||||
tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23);
|
tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23);
|
||||||
|
if (kNumRows == 3) return;
|
||||||
|
|
||||||
tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30);
|
tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30);
|
||||||
tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31);
|
tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31);
|
||||||
|
|
@ -114,10 +117,10 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
|
||||||
tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
|
tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulates a single 4x4 tile of A x B into C. B is transposed, so we can
|
// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
|
||||||
// iterate over both A and B with consecutive vector loads.
|
// can iterate over both A and B with consecutive vector loads. kNumRows<=4.
|
||||||
// Shared between parallelized and sequential (loop) callers.
|
// Shared between parallelized and sequential (loop) callers.
|
||||||
template <size_t kColsA_RowsB, typename MatT, HWY_IF_F32(MatT)>
|
template <size_t kNumRows, size_t kColsA_RowsB, typename MatT, HWY_IF_F32(MatT)>
|
||||||
HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
|
const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
|
||||||
const size_t idx_tile, const size_t xtiles,
|
const size_t idx_tile, const size_t xtiles,
|
||||||
|
|
@ -129,6 +132,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
// threads write in units of 4*N elements, which is at least one cache line.
|
// threads write in units of 4*N elements, which is at least one cache line.
|
||||||
constexpr size_t kRegRows = 4;
|
constexpr size_t kRegRows = 4;
|
||||||
constexpr size_t kRegCols = 4;
|
constexpr size_t kRegCols = 4;
|
||||||
|
static_assert(kNumRows <= kRegRows);
|
||||||
|
|
||||||
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
||||||
const size_t row_a = idx_tile / xtiles * kRegRows;
|
const size_t row_a = idx_tile / xtiles * kRegRows;
|
||||||
|
|
@ -175,18 +179,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
c01 = hn::MulAdd(a0, b1, c01);
|
c01 = hn::MulAdd(a0, b1, c01);
|
||||||
c02 = hn::MulAdd(a0, b2, c02);
|
c02 = hn::MulAdd(a0, b2, c02);
|
||||||
c03 = hn::MulAdd(a0, b3, c03);
|
c03 = hn::MulAdd(a0, b3, c03);
|
||||||
|
if (kNumRows == 1) continue;
|
||||||
|
|
||||||
const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
|
const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
|
||||||
c10 = hn::MulAdd(a1, b0, c10);
|
c10 = hn::MulAdd(a1, b0, c10);
|
||||||
c11 = hn::MulAdd(a1, b1, c11);
|
c11 = hn::MulAdd(a1, b1, c11);
|
||||||
c12 = hn::MulAdd(a1, b2, c12);
|
c12 = hn::MulAdd(a1, b2, c12);
|
||||||
c13 = hn::MulAdd(a1, b3, c13);
|
c13 = hn::MulAdd(a1, b3, c13);
|
||||||
|
if (kNumRows == 2) continue;
|
||||||
|
|
||||||
const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
|
const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
|
||||||
c20 = hn::MulAdd(a2, b0, c20);
|
c20 = hn::MulAdd(a2, b0, c20);
|
||||||
c21 = hn::MulAdd(a2, b1, c21);
|
c21 = hn::MulAdd(a2, b1, c21);
|
||||||
c22 = hn::MulAdd(a2, b2, c22);
|
c22 = hn::MulAdd(a2, b2, c22);
|
||||||
c23 = hn::MulAdd(a2, b3, c23);
|
c23 = hn::MulAdd(a2, b3, c23);
|
||||||
|
if (kNumRows == 3) continue;
|
||||||
|
|
||||||
const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
|
const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
|
||||||
c30 = hn::MulAdd(a3, b0, c30);
|
c30 = hn::MulAdd(a3, b0, c30);
|
||||||
|
|
@ -196,8 +203,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
}
|
}
|
||||||
|
|
||||||
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
||||||
StoreHorizontalSums(d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
|
StoreHorizontalSums<kNumRows>(
|
||||||
c23, c30, c31, c32, c33, tile_c, stride_c);
|
d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
|
||||||
|
c30, c31, c32, c33, tile_c, stride_c);
|
||||||
}
|
}
|
||||||
|
|
||||||
#undef GEMMA_NATIVE_BF16
|
#undef GEMMA_NATIVE_BF16
|
||||||
|
|
@ -209,7 +217,8 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// As above, for MatT=bf16
|
// As above, for MatT=bf16
|
||||||
template <size_t kColsA_RowsB, typename MatT, HWY_IF_BF16(MatT)>
|
template <size_t kNumRows, size_t kColsA_RowsB, typename MatT,
|
||||||
|
HWY_IF_BF16(MatT)>
|
||||||
HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
const MatT* HWY_RESTRICT B, float* HWY_RESTRICT C,
|
const MatT* HWY_RESTRICT B, float* HWY_RESTRICT C,
|
||||||
const size_t idx_tile, const size_t xtiles,
|
const size_t idx_tile, const size_t xtiles,
|
||||||
|
|
@ -217,6 +226,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
const size_t stride_c) {
|
const size_t stride_c) {
|
||||||
constexpr size_t kRegRows = 4;
|
constexpr size_t kRegRows = 4;
|
||||||
constexpr size_t kRegCols = 4;
|
constexpr size_t kRegCols = 4;
|
||||||
|
static_assert(kNumRows <= kRegRows);
|
||||||
|
|
||||||
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
||||||
const size_t row_a = idx_tile / xtiles * kRegRows;
|
const size_t row_a = idx_tile / xtiles * kRegRows;
|
||||||
|
|
@ -277,18 +287,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
|
c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
|
||||||
c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
|
c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
|
||||||
c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
|
c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
|
||||||
|
if (kNumRows == 1) continue;
|
||||||
|
|
||||||
const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
|
const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
|
||||||
c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
|
c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
|
||||||
c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
|
c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
|
||||||
c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
|
c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
|
||||||
c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
|
c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
|
||||||
|
if (kNumRows == 2) continue;
|
||||||
|
|
||||||
const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
|
const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
|
||||||
c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
|
c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
|
||||||
c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
|
c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
|
||||||
c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
|
c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
|
||||||
c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
|
c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
|
||||||
|
if (kNumRows == 3) continue;
|
||||||
|
|
||||||
const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
|
const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
|
||||||
c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
|
c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
|
||||||
|
|
@ -311,6 +324,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
c01 = hn::MulAdd(a0, b1, c01);
|
c01 = hn::MulAdd(a0, b1, c01);
|
||||||
c02 = hn::MulAdd(a0, b2, c02);
|
c02 = hn::MulAdd(a0, b2, c02);
|
||||||
c03 = hn::MulAdd(a0, b3, c03);
|
c03 = hn::MulAdd(a0, b3, c03);
|
||||||
|
if (kNumRows == 1) continue;
|
||||||
|
|
||||||
const VF a1 =
|
const VF a1 =
|
||||||
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab));
|
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab));
|
||||||
|
|
@ -318,6 +332,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
c11 = hn::MulAdd(a1, b1, c11);
|
c11 = hn::MulAdd(a1, b1, c11);
|
||||||
c12 = hn::MulAdd(a1, b2, c12);
|
c12 = hn::MulAdd(a1, b2, c12);
|
||||||
c13 = hn::MulAdd(a1, b3, c13);
|
c13 = hn::MulAdd(a1, b3, c13);
|
||||||
|
if (kNumRows == 2) continue;
|
||||||
|
|
||||||
const VF a2 =
|
const VF a2 =
|
||||||
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab));
|
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab));
|
||||||
|
|
@ -325,6 +340,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
c21 = hn::MulAdd(a2, b1, c21);
|
c21 = hn::MulAdd(a2, b1, c21);
|
||||||
c22 = hn::MulAdd(a2, b2, c22);
|
c22 = hn::MulAdd(a2, b2, c22);
|
||||||
c23 = hn::MulAdd(a2, b3, c23);
|
c23 = hn::MulAdd(a2, b3, c23);
|
||||||
|
if (kNumRows == 3) continue;
|
||||||
|
|
||||||
const VF a3 =
|
const VF a3 =
|
||||||
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab));
|
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab));
|
||||||
|
|
@ -341,12 +357,14 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
||||||
StoreHorizontalSums(df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
|
StoreHorizontalSums<kNumRows>(
|
||||||
c23, c30, c31, c32, c33, tile_c, stride_c);
|
df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
|
||||||
|
c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as above, but with mixed Mat types: (f32, sfp)).
|
// Same as above, but with mixed Mat types: (f32, sfp)).
|
||||||
template <size_t kColsA_RowsB, typename MatTA, HWY_IF_F32(MatTA)>
|
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
|
||||||
|
HWY_IF_F32(MatTA)>
|
||||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
const SfpStream* HWY_RESTRICT B,
|
const SfpStream* HWY_RESTRICT B,
|
||||||
float* HWY_RESTRICT C, const size_t idx_tile,
|
float* HWY_RESTRICT C, const size_t idx_tile,
|
||||||
|
|
@ -354,6 +372,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
const size_t stride_b, const size_t stride_c) {
|
const size_t stride_b, const size_t stride_c) {
|
||||||
constexpr size_t kRegRows = 4;
|
constexpr size_t kRegRows = 4;
|
||||||
constexpr size_t kRegCols = 4;
|
constexpr size_t kRegCols = 4;
|
||||||
|
static_assert(kNumRows <= kRegRows);
|
||||||
|
|
||||||
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
||||||
const size_t row_a = idx_tile / xtiles * kRegRows;
|
const size_t row_a = idx_tile / xtiles * kRegRows;
|
||||||
|
|
@ -407,18 +426,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
c01 = hn::MulAdd(a0, b1, c01);
|
c01 = hn::MulAdd(a0, b1, c01);
|
||||||
c02 = hn::MulAdd(a0, b2, c02);
|
c02 = hn::MulAdd(a0, b2, c02);
|
||||||
c03 = hn::MulAdd(a0, b3, c03);
|
c03 = hn::MulAdd(a0, b3, c03);
|
||||||
|
if (kNumRows == 1) continue;
|
||||||
|
|
||||||
const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
|
const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
|
||||||
c10 = hn::MulAdd(a1, b0, c10);
|
c10 = hn::MulAdd(a1, b0, c10);
|
||||||
c11 = hn::MulAdd(a1, b1, c11);
|
c11 = hn::MulAdd(a1, b1, c11);
|
||||||
c12 = hn::MulAdd(a1, b2, c12);
|
c12 = hn::MulAdd(a1, b2, c12);
|
||||||
c13 = hn::MulAdd(a1, b3, c13);
|
c13 = hn::MulAdd(a1, b3, c13);
|
||||||
|
if (kNumRows == 2) continue;
|
||||||
|
|
||||||
const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
|
const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
|
||||||
c20 = hn::MulAdd(a2, b0, c20);
|
c20 = hn::MulAdd(a2, b0, c20);
|
||||||
c21 = hn::MulAdd(a2, b1, c21);
|
c21 = hn::MulAdd(a2, b1, c21);
|
||||||
c22 = hn::MulAdd(a2, b2, c22);
|
c22 = hn::MulAdd(a2, b2, c22);
|
||||||
c23 = hn::MulAdd(a2, b3, c23);
|
c23 = hn::MulAdd(a2, b3, c23);
|
||||||
|
if (kNumRows == 3) continue;
|
||||||
|
|
||||||
const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
|
const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
|
||||||
c30 = hn::MulAdd(a3, b0, c30);
|
c30 = hn::MulAdd(a3, b0, c30);
|
||||||
|
|
@ -428,12 +450,14 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
}
|
}
|
||||||
|
|
||||||
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
||||||
StoreHorizontalSums(d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
|
StoreHorizontalSums<kNumRows>(
|
||||||
c23, c30, c31, c32, c33, tile_c, stride_c);
|
d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
|
||||||
|
c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as above, but with mixed Mat types: (bf16, sfp)).
|
// Same as above, but with mixed Mat types: (bf16, sfp)).
|
||||||
template <size_t kColsA_RowsB, typename MatTA, HWY_IF_BF16(MatTA)>
|
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
|
||||||
|
HWY_IF_BF16(MatTA)>
|
||||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
const SfpStream* HWY_RESTRICT B,
|
const SfpStream* HWY_RESTRICT B,
|
||||||
float* HWY_RESTRICT C, const size_t idx_tile,
|
float* HWY_RESTRICT C, const size_t idx_tile,
|
||||||
|
|
@ -441,6 +465,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
const size_t stride_b, const size_t stride_c) {
|
const size_t stride_b, const size_t stride_c) {
|
||||||
constexpr size_t kRegRows = 4;
|
constexpr size_t kRegRows = 4;
|
||||||
constexpr size_t kRegCols = 4;
|
constexpr size_t kRegCols = 4;
|
||||||
|
static_assert(kNumRows <= kRegRows);
|
||||||
|
|
||||||
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
||||||
const size_t row_a = idx_tile / xtiles * kRegRows;
|
const size_t row_a = idx_tile / xtiles * kRegRows;
|
||||||
|
|
@ -500,6 +525,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
c01 = hn::MulAdd(a0, b1, c01);
|
c01 = hn::MulAdd(a0, b1, c01);
|
||||||
c02 = hn::MulAdd(a0, b2, c02);
|
c02 = hn::MulAdd(a0, b2, c02);
|
||||||
c03 = hn::MulAdd(a0, b3, c03);
|
c03 = hn::MulAdd(a0, b3, c03);
|
||||||
|
if (kNumRows == 1) continue;
|
||||||
|
|
||||||
const V a1 =
|
const V a1 =
|
||||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
|
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
|
||||||
|
|
@ -507,6 +533,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
c11 = hn::MulAdd(a1, b1, c11);
|
c11 = hn::MulAdd(a1, b1, c11);
|
||||||
c12 = hn::MulAdd(a1, b2, c12);
|
c12 = hn::MulAdd(a1, b2, c12);
|
||||||
c13 = hn::MulAdd(a1, b3, c13);
|
c13 = hn::MulAdd(a1, b3, c13);
|
||||||
|
if (kNumRows == 2) continue;
|
||||||
|
|
||||||
const V a2 =
|
const V a2 =
|
||||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
|
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
|
||||||
|
|
@ -514,6 +541,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
c21 = hn::MulAdd(a2, b1, c21);
|
c21 = hn::MulAdd(a2, b1, c21);
|
||||||
c22 = hn::MulAdd(a2, b2, c22);
|
c22 = hn::MulAdd(a2, b2, c22);
|
||||||
c23 = hn::MulAdd(a2, b3, c23);
|
c23 = hn::MulAdd(a2, b3, c23);
|
||||||
|
if (kNumRows == 3) continue;
|
||||||
|
|
||||||
const V a3 =
|
const V a3 =
|
||||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
|
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
|
||||||
|
|
@ -524,12 +552,14 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
}
|
}
|
||||||
|
|
||||||
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
||||||
StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
|
StoreHorizontalSums<kNumRows>(
|
||||||
c22, c23, c30, c31, c32, c33, tile_c, stride_c);
|
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
|
||||||
|
c22, c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as above, but with mixed Mat types: (f32, bf16).
|
// Same as above, but with mixed Mat types: (f32, bf16).
|
||||||
template <size_t kColsA_RowsB, typename MatTA, HWY_IF_F32(MatTA),
|
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
|
||||||
|
HWY_IF_F32(MatTA),
|
||||||
typename MatTB, HWY_IF_BF16(MatTB)>
|
typename MatTB, HWY_IF_BF16(MatTB)>
|
||||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
const MatTB* HWY_RESTRICT B,
|
const MatTB* HWY_RESTRICT B,
|
||||||
|
|
@ -538,6 +568,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
const size_t stride_b, const size_t stride_c) {
|
const size_t stride_b, const size_t stride_c) {
|
||||||
constexpr size_t kRegRows = 4;
|
constexpr size_t kRegRows = 4;
|
||||||
constexpr size_t kRegCols = 4;
|
constexpr size_t kRegCols = 4;
|
||||||
|
static_assert(kNumRows <= kRegRows);
|
||||||
|
|
||||||
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
||||||
const size_t row_a = idx_tile / xtiles * kRegRows;
|
const size_t row_a = idx_tile / xtiles * kRegRows;
|
||||||
|
|
@ -596,18 +627,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
c01 = hn::MulAdd(a0, b1, c01);
|
c01 = hn::MulAdd(a0, b1, c01);
|
||||||
c02 = hn::MulAdd(a0, b2, c02);
|
c02 = hn::MulAdd(a0, b2, c02);
|
||||||
c03 = hn::MulAdd(a0, b3, c03);
|
c03 = hn::MulAdd(a0, b3, c03);
|
||||||
|
if (kNumRows == 1) continue;
|
||||||
|
|
||||||
const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
|
const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
|
||||||
c10 = hn::MulAdd(a1, b0, c10);
|
c10 = hn::MulAdd(a1, b0, c10);
|
||||||
c11 = hn::MulAdd(a1, b1, c11);
|
c11 = hn::MulAdd(a1, b1, c11);
|
||||||
c12 = hn::MulAdd(a1, b2, c12);
|
c12 = hn::MulAdd(a1, b2, c12);
|
||||||
c13 = hn::MulAdd(a1, b3, c13);
|
c13 = hn::MulAdd(a1, b3, c13);
|
||||||
|
if (kNumRows == 2) continue;
|
||||||
|
|
||||||
const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
|
const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
|
||||||
c20 = hn::MulAdd(a2, b0, c20);
|
c20 = hn::MulAdd(a2, b0, c20);
|
||||||
c21 = hn::MulAdd(a2, b1, c21);
|
c21 = hn::MulAdd(a2, b1, c21);
|
||||||
c22 = hn::MulAdd(a2, b2, c22);
|
c22 = hn::MulAdd(a2, b2, c22);
|
||||||
c23 = hn::MulAdd(a2, b3, c23);
|
c23 = hn::MulAdd(a2, b3, c23);
|
||||||
|
if (kNumRows == 3) continue;
|
||||||
|
|
||||||
const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
|
const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
|
||||||
c30 = hn::MulAdd(a3, b0, c30);
|
c30 = hn::MulAdd(a3, b0, c30);
|
||||||
|
|
@ -617,8 +651,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
}
|
}
|
||||||
|
|
||||||
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
||||||
StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
|
StoreHorizontalSums<kNumRows>(
|
||||||
c22, c23, c30, c31, c32, c33, tile_c, stride_c);
|
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
|
||||||
|
c22, c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
|
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
|
||||||
|
|
@ -646,8 +681,8 @@ void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
|
||||||
|
|
||||||
HWY_UNROLL(1)
|
HWY_UNROLL(1)
|
||||||
for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
|
for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
|
||||||
GEMM_4x4_Tile<kColsA_RowsB>(A, B, C, idx_tile, kTilesX, kStrideA, kStrideB,
|
GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
|
||||||
kStrideC);
|
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -680,8 +715,59 @@ HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
|
||||||
|
|
||||||
pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
|
pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
|
||||||
// Computes the finished product of one 4x4N tile and writes to C.
|
// Computes the finished product of one 4x4N tile and writes to C.
|
||||||
GEMM_4x4_Tile<kColsA_RowsB>(A, B, C, idx_tile, kTilesX, kStrideA, kStrideB,
|
GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
|
||||||
kStrideC);
|
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
|
||||||
|
// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
|
||||||
|
// NOTE that batch_size is the number of rows of A and C.
|
||||||
|
// This function processes tiles in parallel with a work-stealing thread pool.
|
||||||
|
template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
|
||||||
|
typename MatTB, typename OutT>
|
||||||
|
HWY_NOINLINE void MatMul_4x4_Batch(
|
||||||
|
size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
|
||||||
|
OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
|
||||||
|
// Process reg-sized tiles of C in parallel. We currently write C directly,
|
||||||
|
// which touches more memory than fits in L3. TODO: add another level of loops
|
||||||
|
// so that we finish one L3-sized piece of C at a time.
|
||||||
|
const hn::ScalableTag<MatTA> d;
|
||||||
|
const size_t N = Lanes(d);
|
||||||
|
constexpr size_t kRegRows = 4;
|
||||||
|
constexpr size_t kRegCols = 4; // in vectors
|
||||||
|
|
||||||
|
static_assert(kColsBC % kRegCols == 0);
|
||||||
|
HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
|
||||||
|
const size_t kTilesY = (batch_size + kRegRows - 1) / kRegRows;
|
||||||
|
const size_t kTilesX = kColsBC / kRegCols;
|
||||||
|
const size_t kTiles = kTilesX * kTilesY;
|
||||||
|
|
||||||
|
constexpr size_t kStrideA = kColsA_RowsB;
|
||||||
|
constexpr size_t kStrideB = kColsA_RowsB;
|
||||||
|
constexpr size_t kStrideC = kColsBC;
|
||||||
|
|
||||||
|
pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
|
||||||
|
// Computes the finished product of one 4x4N tile and writes to C.
|
||||||
|
const size_t num_rows = batch_size - idx_tile / kTilesX * kRegRows;
|
||||||
|
HWY_ASSERT(num_rows > 0);
|
||||||
|
switch (num_rows) {
|
||||||
|
case 1:
|
||||||
|
GEMM_4x4_Tile<1, kColsA_RowsB>(
|
||||||
|
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
GEMM_4x4_Tile<2, kColsA_RowsB>(
|
||||||
|
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
GEMM_4x4_Tile<3, kColsA_RowsB>(
|
||||||
|
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GEMM_4x4_Tile<4, kColsA_RowsB>(
|
||||||
|
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -714,6 +800,35 @@ HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
|
||||||
MatMulSlow<kM, kN, kK>(a, b.get(), out);
|
MatMulSlow<kM, kN, kK>(a, b.get(), out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
|
||||||
|
// ops_test across instruction sets.
|
||||||
|
template <size_t kN, size_t kK, typename MatTA, typename MatTB>
|
||||||
|
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
|
||||||
|
const MatTB* HWY_RESTRICT b,
|
||||||
|
float* HWY_RESTRICT out) {
|
||||||
|
for (size_t i = 0; i < batch_size; ++i) {
|
||||||
|
for (size_t k = 0; k < kN; ++k) {
|
||||||
|
for (size_t j = 0; j < kK; ++j) {
|
||||||
|
const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
|
||||||
|
const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
|
||||||
|
out[i * kK + j] += a1 * b1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t kN, size_t kK, typename MatTA>
|
||||||
|
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
|
||||||
|
const SfpStream* HWY_RESTRICT b_sfp_stream,
|
||||||
|
float* HWY_RESTRICT out) {
|
||||||
|
const hn::ScalableTag<float> d;
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
|
||||||
|
CompressTraits<SfpStream>::Decompress(d,
|
||||||
|
/*in_capacity=*/0, b_sfp_stream, 0,
|
||||||
|
b.get(), kK * kN);
|
||||||
|
MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), out);
|
||||||
|
}
|
||||||
|
|
||||||
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
|
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
|
||||||
const size_t size, float* HWY_RESTRICT out) {
|
const size_t size, float* HWY_RESTRICT out) {
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
|
|
|
||||||
|
|
@ -538,8 +538,12 @@ void TestTiledMatMul() {
|
||||||
GenerateMatHeap<MatTB, kN, kK>(0, pool);
|
GenerateMatHeap<MatTB, kN, kK>(0, pool);
|
||||||
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
|
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
|
||||||
GenerateZeroMatHeap<float, kM, kK>(pool);
|
GenerateZeroMatHeap<float, kM, kK>(pool);
|
||||||
|
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow_batch =
|
||||||
|
GenerateZeroMatHeap<float, kM, kK>(pool);
|
||||||
|
|
||||||
MatMulSlow<kM, kN, kK>(a->data(), b->data(), c_slow->data());
|
MatMulSlow<kM, kN, kK>(a->data(), b->data(), c_slow->data());
|
||||||
|
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow_batch->data());
|
||||||
|
AssertClose(c_slow->data(), c_slow_batch->data(), kM * kK);
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
|
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
|
||||||
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
|
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
|
||||||
|
|
@ -565,8 +569,66 @@ void TestAllTiledMatMul() {
|
||||||
TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
|
TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||||
|
|
||||||
// large-scale test
|
// large-scale test
|
||||||
// TODO(philculliton): investigate rounding issues with large matrices
|
// TODO(philculliton): investigate rounding issues with large matrices.
|
||||||
TestTiledMatMul<512, 24576, 3072, float>();
|
// Causes test timeout.
|
||||||
|
// TestTiledMatMul<512, 24576, 3072, float>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t kM, size_t kN, size_t kK, typename MatTA,
|
||||||
|
typename MatTB = MatTA>
|
||||||
|
void TestTiledBatchMatMul() {
|
||||||
|
hwy::ThreadPool pool(3);
|
||||||
|
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
|
||||||
|
GenerateMatHeap<MatTA, kM, kN>(0, pool);
|
||||||
|
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
|
||||||
|
GenerateMatHeap<MatTB, kN, kK>(0, pool);
|
||||||
|
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
|
||||||
|
GenerateZeroMatHeap<float, kM, kK>(pool);
|
||||||
|
|
||||||
|
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow->data());
|
||||||
|
|
||||||
|
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
|
||||||
|
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
|
||||||
|
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
|
||||||
|
MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
|
||||||
|
|
||||||
|
AssertClose(c_slow->data(), c.get(), kM * kK);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestAllTiledBatchMatMul() {
|
||||||
|
// medium-sized square test
|
||||||
|
TestTiledBatchMatMul<512, 512, 512, float>();
|
||||||
|
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<512, 512, 512, float, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<512, 512, 512, float, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
|
||||||
|
|
||||||
|
// minimal non-square test
|
||||||
|
TestTiledBatchMatMul<35, 128, 4, float>();
|
||||||
|
TestTiledBatchMatMul<34, 128, 4, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<33, 128, 4, float, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<31, 128, 32, float, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<29, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<4, 128, 4, float>();
|
||||||
|
TestTiledBatchMatMul<4, 128, 4, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<4, 128, 4, float, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<4, 128, 32, float, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<4, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<3, 128, 4, float>();
|
||||||
|
TestTiledBatchMatMul<3, 128, 4, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<3, 128, 4, float, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<3, 128, 32, float, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<2, 128, 4, float>();
|
||||||
|
TestTiledBatchMatMul<2, 128, 4, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<2, 128, 4, float, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<2, 128, 32, float, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<2, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<1, 128, 4, float>();
|
||||||
|
TestTiledBatchMatMul<1, 128, 4, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<1, 128, 4, float, hwy::bfloat16_t>();
|
||||||
|
TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
|
||||||
|
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestMatVecAdd() {
|
void TestMatVecAdd() {
|
||||||
|
|
@ -675,6 +737,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue