mirror of https://github.com/google/gemma.cpp.git

Added bias vector addition to MatMul

PiperOrigin-RevId: 643385381

parent 2228055bb8
commit e0afdfa8fb

275  gemma/ops.h
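This commit threads an optional per-column bias ("add") vector through the tiled MatMul: StoreHorizontalSumsAdd adds one bias value per output column while writing each 4x4 tile, every GEMM_4x4_Tile overload gains a compile-time kAdd flag and an AddT* add pointer, and MatMul_4x4_Batch becomes a thin wrapper around the new MatMul_4x4_Batch_Add entry point. A minimal usage sketch based on the signatures in this diff (not part of the commit); it assumes it is compiled inside the same SIMD namespace as ops.h, and the buffer sizes are illustrative. B is supplied in transposed (column-major) layout, and add[j] is added to column j of every row of C:

    // Sketch only: C = A * B (B passed transposed), plus a per-column bias.
    #include <vector>
    constexpr size_t kColsA_RowsB = 128;  // inner (dot-product) dimension
    constexpr size_t kColsBC = 32;        // columns of B and C; length of add
    const size_t batch_size = 4;          // rows of A and C
    hwy::ThreadPool pool(3);
    std::vector<float> a(batch_size * kColsA_RowsB), b_trans(kColsBC * kColsA_RowsB);
    std::vector<float> c(batch_size * kColsBC), add(kColsBC);
    // ... fill a, b_trans, add ...
    MatMul_4x4_Batch_Add<kColsA_RowsB, kColsBC, /*kAdd=*/true>(
        batch_size, a.data(), b_trans.data(), c.data(), add.data(), pool);

Passing /*kAdd=*/false (or calling the MatMul_4x4_Batch wrapper) skips the bias entirely at compile time.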
@@ -141,12 +141,55 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
   tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
 }
 
+// Completes the tile by summing across the vectors, and adds the biases.
+template <size_t kNumRows, class DF, class VF = hn::Vec<DF>, typename AddT>
+HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
+                                       VF c10, VF c11, VF c12, VF c13,  //
+                                       VF c20, VF c21, VF c22, VF c23,  //
+                                       VF c30, VF c31, VF c32, VF c33,
+                                       const AddT* add,
+                                       float* HWY_RESTRICT tile_c,
+                                       size_t stride_c) {
+  // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
+  // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
+  // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
+  // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
+  float addon0 = hwy::ConvertScalarTo<float>(add[0]);
+  tile_c[stride_c * 0 + 0] = hn::ReduceSum(df, c00) + addon0;
+  float addon1 = hwy::ConvertScalarTo<float>(add[1]);
+  tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01) + addon1;
+  float addon2 = hwy::ConvertScalarTo<float>(add[2]);
+  tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02) + addon2;
+  float addon3 = hwy::ConvertScalarTo<float>(add[3]);
+  tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03) + addon3;
+  if (kNumRows == 1) return;
+
+  tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10) + addon0;
+  tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11) + addon1;
+  tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12) + addon2;
+  tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13) + addon3;
+  if (kNumRows == 2) return;
+
+  tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20) + addon0;
+  tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21) + addon1;
+  tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22) + addon2;
+  tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23) + addon3;
+  if (kNumRows == 3) return;
+
+  tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30) + addon0;
+  tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31) + addon1;
+  tile_c[stride_c * 3 + 2] = hn::ReduceSum(df, c32) + addon2;
+  tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33) + addon3;
+}
+
 // Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
 // can iterate over both A and B with consecutive vector loads. kNumRows<=4.
 // Shared between parallelized and sequential (loop) callers.
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatT, HWY_IF_F32(MatT)>
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatT,
+          HWY_IF_F32(MatT), typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
+                              const AddT* add,
                               const size_t idx_tile, const size_t xtiles,
                               const size_t stride_a, const size_t stride_b,
                               const size_t stride_c) {
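In scalar terms, each entry written by the new StoreHorizontalSumsAdd is a full dot product plus one bias value selected by the output column. A stand-alone reference sketch (illustrative names, not from the diff):

    #include <cstddef>

    // One output element of a biased tile: dot(A row, B column) + bias.
    // In the SIMD code the dot product lives in the lanes of c$r$c; ReduceSum
    // collapses it, and addon$c$ (converted from AddT to float) is the bias.
    float DotPlusBias(const float* a_row, const float* b_col, size_t len,
                      float bias) {
      float sum = 0.0f;
      for (size_t i = 0; i < len; ++i) sum += a_row[i] * b_col[i];
      return sum + bias;
    }

Note that the same addon0..addon3 values are reused for every row of the tile, so the bias is broadcast down the rows of C.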
@@ -203,21 +246,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
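Throughout the accumulation hunks above and below, the per-row early-outs also switch from a plain runtime if to if constexpr. Because kNumRows is a template parameter, the comparison is guaranteed to be resolved at compile time rather than left to the optimizer, and everything after an always-taken continue becomes trivially dead code for small row counts. A stand-alone illustration (not from the diff; the real code uses continue inside the loop over col_ab):

    template <size_t kNumRows>
    void StoreRows(float* out) {
      out[0] = 1.0f;
      // Resolved at compile time: for kNumRows == 1 this return is
      // unconditional and the statements below it are unreachable.
      if constexpr (kNumRows == 1) return;
      out[1] = 2.0f;
      if constexpr (kNumRows == 2) return;
      out[2] = 3.0f;
    }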
@@ -227,9 +270,16 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-      c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
 #undef GEMMA_NATIVE_BF16
@@ -241,10 +291,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
 #endif
 
 // As above, for MatT=bf16
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatT,
-          HWY_IF_BF16(MatT)>
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatT,
+          HWY_IF_BF16(MatT), typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const MatT* HWY_RESTRICT B, float* HWY_RESTRICT C,
+                              const AddT* add,
                               const size_t idx_tile, const size_t xtiles,
                               const size_t stride_a, const size_t stride_b,
                               const size_t stride_c) {
@@ -311,21 +362,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
     c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
     c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
     c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
     c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
     c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
     c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
     c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
     c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
@@ -348,7 +399,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const VF a1 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab));
@@ -356,7 +407,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const VF a2 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab));
@@ -364,7 +415,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const VF a3 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab));
@@ -381,17 +432,27 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
 #endif
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-      c23, c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
 // Same as above, but with mixed Mat types: (f32, compressed).
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
-          HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
+          HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1),
+          typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              float* HWY_RESTRICT C,
+                              const AddT* add,
+                              const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
@@ -450,21 +511,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -474,17 +535,27 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-      c23, c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
-// Same as above, but with mixed Mat types: (bf16, compressed)).
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
-          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
+// Same as above, but with mixed Mat types: (bf16, compressed).
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
+          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1),
+          typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              float* HWY_RESTRICT C,
+                              const AddT* add,
+                              const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
@@ -549,7 +620,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const V a1 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
@@ -557,7 +628,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const V a2 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
@@ -565,7 +636,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const V a3 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
@@ -576,18 +647,26 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
-      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
 // Same as above, but with mixed Mat types: (f32, bf16).
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
           HWY_IF_F32(MatTA),
-          typename MatTB, HWY_IF_BF16(MatTB)>
+          typename MatTB, HWY_IF_BF16(MatTB), typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              float* HWY_RESTRICT C, const AddT* add,
+                              const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
@@ -651,21 +730,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -675,17 +754,25 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
-      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
 // Same as above, but with mixed Mat types: (bf16, f32).
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
-          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB)>
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
+          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB), typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              float* HWY_RESTRICT C, const AddT* add,
+                              const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
@@ -746,7 +833,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const VF a1 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
@@ -754,7 +841,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const VF a2 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
@@ -762,7 +849,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const VF a3 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
@@ -773,19 +860,27 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12, c13,
-                                c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
-                                stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
 // Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
 // and kColsBC is 24k or 3k. Note: B is transposed (column-major).
 // NOTE that batch_size is the number of rows of A and C.
 // This function processes tiles in parallel with a work-stealing thread pool.
-template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
-          typename MatTB, typename OutT>
-HWY_NOINLINE void MatMul_4x4_Batch(
+template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
+          typename MatTB, typename OutT, typename AddT>
+HWY_NOINLINE void MatMul_4x4_Batch_Add(
     size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+    OutT* HWY_RESTRICT C, const AddT* add, hwy::ThreadPool& pool) {
   // Process reg-sized tiles of C in parallel. We currently write C directly,
   // which touches more memory than fits in L3. TODO: add another level of loops
   // so that we finish one L3-sized piece of C at a time.
@@ -810,51 +905,31 @@ HWY_NOINLINE void MatMul_4x4_Batch(
     HWY_ASSERT(num_rows > 0);
     switch (num_rows) {
       case 1:
-        GEMM_4x4_Tile<1, kColsA_RowsB>(
-            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(
+            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
         break;
       case 2:
-        GEMM_4x4_Tile<2, kColsA_RowsB>(
-            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(
+            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
         break;
      case 3:
-        GEMM_4x4_Tile<3, kColsA_RowsB>(
-            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(
+            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
        break;
      default:
-        GEMM_4x4_Tile<4, kColsA_RowsB>(
-            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(
+            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
     }
   });
 }
 
-// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
-// ops_test across instruction sets.
-template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
-HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
-                           const MatTB* HWY_RESTRICT b,
-                           float* HWY_RESTRICT out) {
-  for (size_t i = 0; i < kM; ++i) {
-    for (size_t k = 0; k < kN; ++k) {
-      for (size_t j = 0; j < kK; ++j) {
-        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
-        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
-        out[i * kK + j] += a1 * b1;
-      }
-    }
-  }
-}
-
-template <size_t kM, size_t kN, size_t kK, typename MatTA>
-HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
-                           const SfpStream* HWY_RESTRICT b_sfp_stream,
-                           float* HWY_RESTRICT out) {
-  const hn::ScalableTag<float> d;
-  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
-  CompressTraits<SfpStream>::Decompress(d,
-                                        /*in_capacity=*/0, b_sfp_stream, 0,
-                                        b.get(), kK * kN);
-  MatMulSlow<kM, kN, kK>(a, b.get(), out);
-}
+template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
+          typename MatTB, typename OutT>
+HWY_NOINLINE void MatMul_4x4_Batch(
+    size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
+    OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+  MatMul_4x4_Batch_Add<kColsA_RowsB, kColsBC, /*kAdd=*/false>(
+      batch_size, A, B, C, /*add=*/static_cast<OutT*>(nullptr), pool);
+}
 
 // Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
@@ -863,6 +938,7 @@ template <size_t kN, size_t kK, typename MatTA, typename MatTB,
           HWY_IF_T_SIZE_GT(MatTB, 1)>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
                                 const MatTB* HWY_RESTRICT b,
+                                const float* add,
                                 float* HWY_RESTRICT out) {
   for (size_t i = 0; i < batch_size; ++i) {
     for (size_t k = 0; k < kN; ++k) {
@@ -872,6 +948,11 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
         out[i * kK + j] += a1 * b1;
       }
     }
+    if (add != nullptr) {
+      for (size_t j = 0; j < kK; ++j) {
+        out[i * kK + j] += add[j];
+      }
+    }
   }
 }
@@ -881,13 +962,13 @@ template <size_t kN, size_t kK, typename MatTA, typename MatTB,
           HWY_IF_T_SIZE(MatTB, 1)>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
                                 const MatTB* HWY_RESTRICT b_compr,
+                                const float* add,
                                 float* HWY_RESTRICT out) {
   const hn::ScalableTag<float> d;
   hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
-  CompressTraits<MatTB>::Decompress(d,
-                                    /*in_capacity=*/0, b_compr, 0, b.get(),
+  CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
                                     kK * kN);
-  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), out);
+  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), add, out);
 }
 
 HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,

(The remaining hunks are from the accompanying ops test; the file name is not shown in this rendering.)
@@ -13,6 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <cstdio>
 #include <memory>
 
 #ifndef HWY_DISABLED_TARGETS
 #define HWY_DISABLED_TARGETS HWY_SCALAR
@@ -506,68 +507,80 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
   }
 }
 
-template <size_t kM, size_t kN, size_t kK, typename MatTA,
+template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
           typename MatTB = MatTA>
 void TestTiledBatchMatMul() {
-  fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu", kM, kN, kK);
+  fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu, add=%d ", kM, kN, kK,
+          kAdd);
   hwy::ThreadPool pool(3);
   std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
       GenerateMatHeap<MatTA, kM, kN>(0, pool);
   std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
      GenerateMatHeap<MatTB, kN, kK>(0, pool);
+  std::unique_ptr<CompressedArray<float, kK>> add =
+      GenerateMatHeap<float, 1, kK>(0, pool);
   std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
       GenerateZeroMatHeap<float, kM, kK>(pool);
 
-  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow->data());
+  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(),
+                          kAdd ? add->data() : nullptr, c_slow->data());
 
   hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
   std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
       GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
-  MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
+  if (kAdd) {
+    MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), c.get(),
+                                       add->data(), pool);
+  } else {
+    MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
+  }
 
   AssertClose(c_slow->data(), c.get(), kM * kK);
 }
 
 void TestAllTiledBatchMatMul() {
   // medium-sized square test
-  TestTiledBatchMatMul<512, 512, 512, float>();
-  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<512, 512, 512, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<512, 512, 512, float, SfpStream>();
-  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, SfpStream>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t,
+                       SfpStream>();
 
   // minimal non-square test
-  TestTiledBatchMatMul<35, 128, 32, float>();
-  TestTiledBatchMatMul<34, 128, 32, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<33, 128, 32, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<33, 128, 32, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<31, 128, 32, float, SfpStream>();
-  TestTiledBatchMatMul<29, 128, 32, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<4, 128, 8, float>();
-  TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<4, 128, 8, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<4, 128, 8, float, SfpStream>();
-  TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<3, 128, 32, float>();
-  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<3, 128, 32, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<3, 128, 32, float, SfpStream>();
-  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<2, 128, 16, float>();
-  TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<2, 128, 16, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<2, 128, 16, float, SfpStream>();
-  TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<1, 128, 32, float>();
-  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<1, 128, 32, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
-  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, float>();
+  TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, float, SfpStream>();
+  TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, hwy::bfloat16_t,
+                       SfpStream>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, SfpStream>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, SfpStream>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, SfpStream>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t,
+                       SfpStream>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, SfpStream>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
 
   // large-scale test
   // TODO(philculliton): investigate rounding issues with large matrices.