Simplify matmul: only 2 overloads

Also add a StoreHorizontalSumsMaybeAdd wrapper function and move MatMulSlowBatch into the test.

1.02-1.06x speedup.

PiperOrigin-RevId: 651394791
Jan Wassenberg 2024-07-11 06:58:07 -07:00 committed by Copybara-Service
parent 3e92088595
commit be765afce2
2 changed files with 103 additions and 465 deletions
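The two GEMM_4x4_Tile overloads that remain after this change, in outline (condensed from the diff below, with HWY_INLINE/HWY_RESTRICT omitted for brevity):

// 1) bf16 * bf16 specialization, compiled only when GEMMA_NATIVE_BF16 is set:
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd>
void GEMM_4x4_Tile(const hwy::bfloat16_t* A, const hwy::bfloat16_t* B,
                   float* C, const float* add, size_t idx_tile, size_t xtiles,
                   size_t stride_a, size_t stride_b, size_t stride_c);

// 2) General case for all other MatTA/MatTB combinations, via CompressTraits:
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
          typename MatTB>
void GEMM_4x4_Tile(const MatTA* A, const MatTB* B, float* C, const float* add,
                   size_t idx_tile, size_t xtiles, size_t stride_a,
                   size_t stride_b, size_t stride_c);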


@@ -28,7 +28,6 @@
#include <type_traits> // std::enable_if_t
#include "compression/sfp.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"
@@ -142,12 +141,12 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
}
// Completes the tile by summing across the vectors, and adds the biases.
template <size_t kNumRows, class DF, class VF = hn::Vec<DF>, typename AddT>
template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
VF c10, VF c11, VF c12, VF c13, //
VF c20, VF c21, VF c22, VF c23, //
VF c30, VF c31, VF c32, VF c33,
const AddT* add,
const float* HWY_RESTRICT add,
float* HWY_RESTRICT tile_c,
size_t stride_c) {
// We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
@@ -182,103 +181,23 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33) + addon3;
}
// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
// can iterate over both A and B with consecutive vector loads. kNumRows<=4.
// Shared between parallelized and sequential (loop) callers.
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatT,
HWY_IF_F32(MatT), typename AddT>
HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
const AddT* add,
const size_t idx_tile, const size_t xtiles,
const size_t stride_a, const size_t stride_b,
const size_t stride_c) {
// Tile size. 4x unrolling makes sense on many platforms because we can fit
// 4x4 accumulators and 8 temporaries in the 32 vectors; we have more than
// #FMA units * FMA latency (up to 2*5) independent computations in flight;
// threads write in units of 4*N elements, which is at least one cache line.
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
const hn::ScalableTag<MatT> d;
const size_t N = Lanes(d);
using V = hn::Vec<decltype(d)>;
V c00 = hn::Zero(d);
V c01 = hn::Zero(d);
V c02 = hn::Zero(d);
V c03 = hn::Zero(d);
V c10 = hn::Zero(d);
V c11 = hn::Zero(d);
V c12 = hn::Zero(d);
V c13 = hn::Zero(d);
V c20 = hn::Zero(d);
V c21 = hn::Zero(d);
V c22 = hn::Zero(d);
V c23 = hn::Zero(d);
V c30 = hn::Zero(d);
V c31 = hn::Zero(d);
V c32 = hn::Zero(d);
V c33 = hn::Zero(d);
const MatT* HWY_RESTRICT tile_a = A + stride_a * row_a;
const MatT* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
// Loop over columns of A and columns of the transposed B, in steps of N.
// Accumulates into the c## vectors.
HWY_UNROLL(1)
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);
const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
c00 = hn::MulAdd(a0, b0, c00);
c01 = hn::MulAdd(a0, b1, c01);
c02 = hn::MulAdd(a0, b2, c02);
c03 = hn::MulAdd(a0, b3, c03);
if constexpr (kNumRows == 1) continue;
const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
c10 = hn::MulAdd(a1, b0, c10);
c11 = hn::MulAdd(a1, b1, c11);
c12 = hn::MulAdd(a1, b2, c12);
c13 = hn::MulAdd(a1, b3, c13);
if constexpr (kNumRows == 2) continue;
const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
c20 = hn::MulAdd(a2, b0, c20);
c21 = hn::MulAdd(a2, b1, c21);
c22 = hn::MulAdd(a2, b2, c22);
c23 = hn::MulAdd(a2, b3, c23);
if constexpr (kNumRows == 3) continue;
const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
c30 = hn::MulAdd(a3, b0, c30);
c31 = hn::MulAdd(a3, b1, c31);
c32 = hn::MulAdd(a3, b2, c32);
c33 = hn::MulAdd(a3, b3, c33);
}
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
// Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call
// sites. If `!kAdd`, `add` is nullptr, so adding `add_offset` to it would be
// UB, hence we pass it as a separate argument.
template <bool kAdd, size_t kNumRows, class DF, class VF = hn::Vec<DF>>
HWY_INLINE void StoreHorizontalSumsMaybeAdd(
DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
const float* HWY_RESTRICT add, size_t add_offset,
float* HWY_RESTRICT tile_c, size_t stride_c) {
if constexpr (kAdd) {
const AddT* tile_add = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(d, c00, c01, c02, c03, c10, c11, c12, c13,
StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33,
tile_add, tile_c, stride_c);
add + add_offset, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_c, stride_c);
StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33,
tile_c, stride_c);
}
}
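// Illustrative scalar model (not part of this change) of what one
// GEMM_4x4_Tile call computes: for tile index idx_tile, each of the kNumRows
// rows of C starting at row_a receives, in the four columns starting at
// row_b_col_c, the dot product of the matching A row and transposed-B row,
// plus the optional bias. `ScalarTile` is a hypothetical name for this sketch.
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd>
void ScalarTile(const float* A, const float* B, float* C, const float* add,
                size_t idx_tile, size_t xtiles, size_t stride_a,
                size_t stride_b, size_t stride_c) {
  const size_t row_a = idx_tile / xtiles * 4;        // kRegRows
  const size_t row_b_col_c = idx_tile % xtiles * 4;  // kRegCols
  for (size_t r = 0; r < kNumRows; ++r) {
    for (size_t c = 0; c < 4; ++c) {
      float sum = kAdd ? add[row_b_col_c + c] : 0.0f;
      for (size_t k = 0; k < kColsA_RowsB; ++k) {
        sum += A[stride_a * (row_a + r) + k] *
               B[stride_b * (row_b_col_c + c) + k];
      }
      C[stride_c * (row_a + r) + row_b_col_c + c] = sum;
    }
  }
}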
@@ -290,12 +209,13 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
#define GEMMA_NATIVE_BF16 0
#endif
// As above, for MatT=bf16
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatT,
HWY_IF_BF16(MatT), typename AddT>
HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
const MatT* HWY_RESTRICT B, float* HWY_RESTRICT C,
const AddT* add,
#if GEMMA_NATIVE_BF16
// Specialization for f32 += bf16 * bf16 that avoids promoting to f32.
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd>
HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
const hwy::bfloat16_t* HWY_RESTRICT B,
float* HWY_RESTRICT C, const float* add,
const size_t idx_tile, const size_t xtiles,
const size_t stride_a, const size_t stride_b,
const size_t stride_c) {
@@ -309,17 +229,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
#if GEMMA_NATIVE_BF16
// ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
// bf16 vectors.
const hn::Repartition<MatT, decltype(df)> d;
const hn::Repartition<hwy::bfloat16_t, decltype(df)> d;
VF unused_sum1 = hn::Zero(df);
#else
// Emulated: use half-vectors of bf16 because we cannot afford two sums for
// each c##.
const hn::Rebind<MatT, decltype(df)> d;
HWY_DASSERT(Lanes(d) == Lanes(df));
#endif
const size_t N = Lanes(d);
@@ -343,14 +256,13 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
VF c32 = hn::Zero(df);
VF c33 = hn::Zero(df);
const MatT* HWY_RESTRICT tile_a = A + stride_a * row_a;
const MatT* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
const hwy::bfloat16_t* HWY_RESTRICT tile_a = A + stride_a * row_a;
const hwy::bfloat16_t* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
// Loop over columns of A and columns of the transposed B, in steps of N.
// Accumulates into the c## vectors.
HWY_UNROLL(1)
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
#if GEMMA_NATIVE_BF16
using V = hn::Vec<decltype(d)>;
const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
@@ -383,89 +295,44 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
c33 = hn::ReorderWidenMulAccumulate(df, a3, b3, c33, unused_sum1);
#else // Emulated: promote bf16 to f32
const VF b0 =
hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 0 + col_ab));
const VF b1 =
hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 1 + col_ab));
const VF b2 =
hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 2 + col_ab));
const VF b3 =
hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 3 + col_ab));
const VF a0 =
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 0 + col_ab));
c00 = hn::MulAdd(a0, b0, c00);
c01 = hn::MulAdd(a0, b1, c01);
c02 = hn::MulAdd(a0, b2, c02);
c03 = hn::MulAdd(a0, b3, c03);
if constexpr (kNumRows == 1) continue;
const VF a1 =
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab));
c10 = hn::MulAdd(a1, b0, c10);
c11 = hn::MulAdd(a1, b1, c11);
c12 = hn::MulAdd(a1, b2, c12);
c13 = hn::MulAdd(a1, b3, c13);
if constexpr (kNumRows == 2) continue;
const VF a2 =
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab));
c20 = hn::MulAdd(a2, b0, c20);
c21 = hn::MulAdd(a2, b1, c21);
c22 = hn::MulAdd(a2, b2, c22);
c23 = hn::MulAdd(a2, b3, c23);
if constexpr (kNumRows == 3) continue;
const VF a3 =
hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab));
c30 = hn::MulAdd(a3, b0, c30);
c31 = hn::MulAdd(a3, b1, c31);
c32 = hn::MulAdd(a3, b2, c32);
c33 = hn::MulAdd(a3, b3, c33);
#endif // !GEMMA_NATIVE_BF16
}
#if GEMMA_NATIVE_BF16
// Ensure sum1 was indeed unused.
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
#endif
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* dd = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33, dd,
tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_c, stride_c);
}
StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
c32, c33, add, row_b_col_c, tile_c, stride_c);
}
#endif // GEMMA_NATIVE_BF16
// The col_ab loop is unrolled 2x, so we have two consecutive a0/a1 and b00/b01
// etc. Multiplies a[c] with b[r,c] and adds to c[r].
template <class VF>
HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
const VF& b01, const VF& b10, const VF& b11,
const VF& b20, const VF& b21, const VF& b30,
const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
c0 = MulAdd(a0, b00, c0);
c0 = MulAdd(a1, b01, c0);
c1 = MulAdd(a0, b10, c1);
c1 = MulAdd(a1, b11, c1);
c2 = MulAdd(a0, b20, c2);
c2 = MulAdd(a1, b21, c2);
c3 = MulAdd(a0, b30, c3);
c3 = MulAdd(a1, b31, c3);
c0 = hn::MulAdd(a0, b00, c0);
c1 = hn::MulAdd(a0, b10, c1);
c2 = hn::MulAdd(a0, b20, c2);
c3 = hn::MulAdd(a0, b30, c3);
c0 = hn::MulAdd(a1, b01, c0);
c1 = hn::MulAdd(a1, b11, c1);
c2 = hn::MulAdd(a1, b21, c2);
c3 = hn::MulAdd(a1, b31, c3);
}
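// Scalar view of UpdateTileRow above (illustrative): for each lane l and each
// transposed-B row r in 0..3,
//   c[r][l] += a0[l] * b[r][0][l] + a1[l] * b[r][1][l],
// where a0/a1 are the two unrolled halves of one A row and b[r][0]/b[r][1] are
// the corresponding halves of transposed-B row r.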
// Same as above, for when there exists CompressTraits<MatTA>::Decompress2 and
// MatTB is compressed.
// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
// can iterate over both A and B with consecutive vector loads. kNumRows<=4.
// General case: uses CompressTraits to load from A and B.
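// Decompress2(d32, ptr, ofs, v0, v1) is assumed here to expand the 2*N stored
// values at element offset ofs into two f32 vectors, which is why the col_ab
// loop below advances by 2*N per iteration.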
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
typename MatTB, HWY_IF_T_SIZE(MatTB, 1), typename AddT>
typename MatTB>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C, const AddT* add,
float* HWY_RESTRICT C, const float* add,
const size_t idx_tile, const size_t xtiles,
const size_t stride_a, const size_t stride_b,
const size_t stride_c) {
@@ -473,6 +340,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
using TraitsA = CompressTraits<MatTA>;
using TraitsB = CompressTraits<MatTB>;
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
@@ -512,273 +382,42 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
HWY_UNROLL(1)
for (; col_ab <= kColsA_RowsB - 2 * N; col_ab += 2 * N) {
V b00, b01;
CompressTraits<MatTB>::Decompress2(
d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01);
TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01);
V b10, b11;
CompressTraits<MatTB>::Decompress2(
d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11);
TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11);
V b20, b21;
CompressTraits<MatTB>::Decompress2(
d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21);
TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21);
V b30, b31;
CompressTraits<MatTB>::Decompress2(
d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31);
TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31);
V a00, a01;
CompressTraits<MatTA>::Decompress2(
d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01);
TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01);
UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
c02, c03);
if constexpr (kNumRows == 1) continue;
V a10, a11;
CompressTraits<MatTA>::Decompress2(
d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11);
TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11);
UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
c12, c13);
if constexpr (kNumRows == 2) continue;
V a20, a21;
CompressTraits<MatTA>::Decompress2(
d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21);
TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21);
UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
c22, c23);
if constexpr (kNumRows == 3) continue;
V a30, a31;
CompressTraits<MatTA>::Decompress2(
d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31);
TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31);
UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
c32, c33);
}
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* dd = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
c13, c20, c21, c22, c23, c30, c31, c32,
c33, dd, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33,
tile_c, stride_c);
}
}
// Same as above, but with mixed Mat types: (f32, bf16).
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
HWY_IF_F32(MatTA),
typename MatTB, HWY_IF_BF16(MatTB), typename AddT>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C, const AddT* add,
const size_t idx_tile,
const size_t xtiles, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
const hn::ScalableTag<float> d32;
using VF = hn::Vec<decltype(d32)>;
// TODO: Using half-vectors for now, it might be faster to
// PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
// accordingly.
const hn::Rebind<MatTB, decltype(d32)> d16;
HWY_DASSERT(Lanes(d16) == Lanes(d32));
const size_t N = Lanes(d16);
VF c00 = hn::Zero(d32);
VF c01 = hn::Zero(d32);
VF c02 = hn::Zero(d32);
VF c03 = hn::Zero(d32);
VF c10 = hn::Zero(d32);
VF c11 = hn::Zero(d32);
VF c12 = hn::Zero(d32);
VF c13 = hn::Zero(d32);
VF c20 = hn::Zero(d32);
VF c21 = hn::Zero(d32);
VF c22 = hn::Zero(d32);
VF c23 = hn::Zero(d32);
VF c30 = hn::Zero(d32);
VF c31 = hn::Zero(d32);
VF c32 = hn::Zero(d32);
VF c33 = hn::Zero(d32);
const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
// Loop over columns of A and columns of the transposed B, in steps of N.
// Accumulates into the c## vectors.
HWY_UNROLL(1)
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
// Promote bf16 to f32
const VF b0 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 0 + col_ab));
const VF b1 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 1 + col_ab));
const VF b2 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 2 + col_ab));
const VF b3 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 3 + col_ab));
const VF a0 = hn::LoadU(d32, tile_a + stride_a * 0 + col_ab);
c00 = hn::MulAdd(a0, b0, c00);
c01 = hn::MulAdd(a0, b1, c01);
c02 = hn::MulAdd(a0, b2, c02);
c03 = hn::MulAdd(a0, b3, c03);
if constexpr (kNumRows == 1) continue;
const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
c10 = hn::MulAdd(a1, b0, c10);
c11 = hn::MulAdd(a1, b1, c11);
c12 = hn::MulAdd(a1, b2, c12);
c13 = hn::MulAdd(a1, b3, c13);
if constexpr (kNumRows == 2) continue;
const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
c20 = hn::MulAdd(a2, b0, c20);
c21 = hn::MulAdd(a2, b1, c21);
c22 = hn::MulAdd(a2, b2, c22);
c23 = hn::MulAdd(a2, b3, c23);
if constexpr (kNumRows == 3) continue;
const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
c30 = hn::MulAdd(a3, b0, c30);
c31 = hn::MulAdd(a3, b1, c31);
c32 = hn::MulAdd(a3, b2, c32);
c33 = hn::MulAdd(a3, b3, c33);
}
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* dd = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
c13, c20, c21, c22, c23, c30, c31, c32,
c33, dd, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_c, stride_c);
}
}
// Same as above, but with mixed Mat types: (bf16, f32).
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB), typename AddT>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C, const AddT* add,
const size_t idx_tile,
const size_t xtiles, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
const hn::ScalableTag<float> d32;
using VF = hn::Vec<decltype(d32)>;
// TODO: Using half-vectors for now, it might be faster to
// PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
// accordingly.
const hn::Rebind<MatTA, decltype(d32)> d16;
HWY_DASSERT(Lanes(d16) == Lanes(d32));
const size_t N = Lanes(d16);
VF c00 = hn::Zero(d32);
VF c01 = hn::Zero(d32);
VF c02 = hn::Zero(d32);
VF c03 = hn::Zero(d32);
VF c10 = hn::Zero(d32);
VF c11 = hn::Zero(d32);
VF c12 = hn::Zero(d32);
VF c13 = hn::Zero(d32);
VF c20 = hn::Zero(d32);
VF c21 = hn::Zero(d32);
VF c22 = hn::Zero(d32);
VF c23 = hn::Zero(d32);
VF c30 = hn::Zero(d32);
VF c31 = hn::Zero(d32);
VF c32 = hn::Zero(d32);
VF c33 = hn::Zero(d32);
const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
// Loop over columns of A and columns of the transposed B, in steps of N.
// Accumulates into the c## vectors.
HWY_UNROLL(1)
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
// Promote bf16 to f32
const VF b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab);
const VF b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab);
const VF b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab);
const VF b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab);
const VF a0 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab));
c00 = hn::MulAdd(a0, b0, c00);
c01 = hn::MulAdd(a0, b1, c01);
c02 = hn::MulAdd(a0, b2, c02);
c03 = hn::MulAdd(a0, b3, c03);
if constexpr (kNumRows == 1) continue;
const VF a1 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
c10 = hn::MulAdd(a1, b0, c10);
c11 = hn::MulAdd(a1, b1, c11);
c12 = hn::MulAdd(a1, b2, c12);
c13 = hn::MulAdd(a1, b3, c13);
if constexpr (kNumRows == 2) continue;
const VF a2 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
c20 = hn::MulAdd(a2, b0, c20);
c21 = hn::MulAdd(a2, b1, c21);
c22 = hn::MulAdd(a2, b2, c22);
c23 = hn::MulAdd(a2, b3, c23);
if constexpr (kNumRows == 3) continue;
const VF a3 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
c30 = hn::MulAdd(a3, b0, c30);
c31 = hn::MulAdd(a3, b1, c31);
c32 = hn::MulAdd(a3, b2, c32);
c33 = hn::MulAdd(a3, b3, c33);
}
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* dd = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
c13, c20, c21, c22, c23, c30, c31, c32,
c33, dd, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_c, stride_c);
}
StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
c32, c33, add, row_b_col_c, tile_c, stride_c);
}
// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
@@ -842,45 +481,7 @@ HWY_NOINLINE void MatMul_4x4_Batch(
batch_size, A, B, C, /*add=*/static_cast<OutT*>(nullptr), pool);
}
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE_GT(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b,
const float* add,
float* HWY_RESTRICT out) {
for (size_t i = 0; i < batch_size; ++i) {
for (size_t k = 0; k < kN; ++k) {
for (size_t j = 0; j < kK; ++j) {
const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
out[i * kK + j] += a1 * b1;
}
}
if (add != nullptr) {
for (size_t j = 0; j < kK; ++j) {
out[i * kK + j] += add[j];
}
}
}
}
// The above overload can handle combinations of f32 and bf16, but this one
// is required for MatTB = {SFP, NUQ}.
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b_compr,
const float* add,
float* HWY_RESTRICT out) {
const hn::ScalableTag<float> d;
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
kK * kN);
MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), add, out);
}
//------------------------------------------------------------------------------
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
const size_t size, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> df;


@@ -508,6 +508,43 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
}
}
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE_GT(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b, const float* add,
float* HWY_RESTRICT out) {
for (size_t i = 0; i < batch_size; ++i) {
for (size_t k = 0; k < kN; ++k) {
for (size_t j = 0; j < kK; ++j) {
const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
out[i * kK + j] += a1 * b1;
}
}
if (add != nullptr) {
for (size_t j = 0; j < kK; ++j) {
out[i * kK + j] += add[j];
}
}
}
}
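// For contrast (illustrative only, not part of this change): the textbook
// ordering below keeps k innermost, so each step reads b with stride kK,
// whereas MatMulSlowBatch above keeps j innermost and streams through one row
// of b and out at a time; presumably this access pattern is what the
// "reordered innermost loops" comment credits with the ~5-10x speedup.
// MatMulSlowBatchNaiveOrder is a hypothetical name used only for this sketch.
template <size_t kN, size_t kK, typename MatTA, typename MatTB>
void MatMulSlowBatchNaiveOrder(size_t batch_size, const MatTA* HWY_RESTRICT a,
                               const MatTB* HWY_RESTRICT b,
                               float* HWY_RESTRICT out) {
  for (size_t i = 0; i < batch_size; ++i) {
    for (size_t j = 0; j < kK; ++j) {
      float sum = 0.0f;
      for (size_t k = 0; k < kN; ++k) {  // k innermost: strided access to b.
        sum += hwy::ConvertScalarTo<float>(a[i * kN + k]) *
               hwy::ConvertScalarTo<float>(b[k * kK + j]);
      }
      out[i * kK + j] += sum;
    }
  }
}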
// The above overload can handle combinations of f32 and bf16, but this one
// is required for MatTB = {SFP, NUQ}.
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b_compr,
const float* add, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> d;
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
kK * kN);
MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), add, out);
}
template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
typename MatTB = MatTA>
void TestTiledBatchMatMul() {