mirror of https://github.com/google/gemma.cpp.git

Simplify matmul: only 2 overloads

Also add StoreHorizontalSumsMaybeAdd wrapper function, move MatMulSlowBatch into test. 1.02-1.06x speedup.

PiperOrigin-RevId: 651394791
parent 3e92088595
commit be765afce2

gemma/ops.h (531 changed lines)
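Note on the shape of the change, before the diff: the commit collapses the per-type GEMM_4x4_Tile copies and moves the repeated `if constexpr (kAdd)` call-site logic into one wrapper. A minimal plain-C++ sketch of that dispatch pattern follows; the scalar stubs stand in for the real Highway helpers and are not part of the commit.

#include <cstddef>

// Stand-ins for the Highway-based helpers (illustration only).
template <size_t kNumRows>
void StoreHorizontalSums(float* tile_c, size_t stride_c) {}
template <size_t kNumRows>
void StoreHorizontalSumsAdd(const float* add, float* tile_c, size_t stride_c) {}

// Mirrors the new StoreHorizontalSumsMaybeAdd: when !kAdd, `add` is nullptr,
// so the offset is passed separately and only applied in the kAdd branch.
template <bool kAdd, size_t kNumRows>
void StoreHorizontalSumsMaybeAdd(const float* add, size_t add_offset,
                                 float* tile_c, size_t stride_c) {
  if constexpr (kAdd) {
    StoreHorizontalSumsAdd<kNumRows>(add + add_offset, tile_c, stride_c);
  } else {
    StoreHorizontalSums<kNumRows>(tile_c, stride_c);
  }
}

The real wrapper also forwards the sixteen accumulator vectors; only the nullptr-versus-bias handling is shown here.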
@@ -28,7 +28,6 @@
#include <type_traits>  // std::enable_if_t
#include "compression/sfp.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"
@@ -142,14 +141,14 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
}

// Completes the tile by summing across the vectors, and adds the biases.
-template <size_t kNumRows, class DF, class VF = hn::Vec<DF>, typename AddT>
+template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
                                       VF c10, VF c11, VF c12, VF c13,  //
                                       VF c20, VF c21, VF c22, VF c23,  //
                                       VF c30, VF c31, VF c32, VF c33,
-                                      const AddT* add,
+                                      const float* HWY_RESTRICT add,
                                       float* HWY_RESTRICT tile_c,
                                       size_t stride_c) {
  // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
  // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
  // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
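Note: as the comment inside StoreHorizontalSumsAdd says, each accumulator c$r$c holds lane-wise partial products of one row of A and one column of B, so a tile entry is the horizontal sum of those lanes plus the bias for its column. A scalar model of one entry (illustrative only; hn::ReduceSum performs the lane summation in the real code):

#include <cstddef>

// Scalar model: C[r][c] = sum over lanes of the accumulator, plus bias[c].
float TileEntry(const float* lanes, size_t num_lanes, float bias) {
  float sum = 0.0f;
  for (size_t i = 0; i < num_lanes; ++i) sum += lanes[i];  // like hn::ReduceSum
  return sum + bias;                                       // like "+ addon$c"
}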
@@ -182,103 +181,23 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
  tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33) + addon3;
}

-// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
-// can iterate over both A and B with consecutive vector loads. kNumRows<=4.
-// Shared between parallelized and sequential (loop) callers.
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatT,
-          HWY_IF_F32(MatT), typename AddT>
-HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
-                              const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
-                              const AddT* add,
-                              const size_t idx_tile, const size_t xtiles,
-                              const size_t stride_a, const size_t stride_b,
-                              const size_t stride_c) {
-  // Tile size. 4x unrolling makes sense on many platforms because we can fit
-  // 4x4 accumulators and 8 temporaries in the 32 vectors; we have more than
-  // #FMA units * FMA latency (up to 2*5) independent computations in flight;
-  // threads write in units of 4*N elements, which is at least one cache line.
-  constexpr size_t kRegRows = 4;
-  constexpr size_t kRegCols = 4;
-  static_assert(kNumRows <= kRegRows);
-
-  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
-  const size_t row_a = idx_tile / xtiles * kRegRows;
-  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
-
-  const hn::ScalableTag<MatT> d;
-  const size_t N = Lanes(d);
-  using V = hn::Vec<decltype(d)>;
-
-  V c00 = hn::Zero(d);
-  V c01 = hn::Zero(d);
-  V c02 = hn::Zero(d);
-  V c03 = hn::Zero(d);
-
-  V c10 = hn::Zero(d);
-  V c11 = hn::Zero(d);
-  V c12 = hn::Zero(d);
-  V c13 = hn::Zero(d);
-
-  V c20 = hn::Zero(d);
-  V c21 = hn::Zero(d);
-  V c22 = hn::Zero(d);
-  V c23 = hn::Zero(d);
-
-  V c30 = hn::Zero(d);
-  V c31 = hn::Zero(d);
-  V c32 = hn::Zero(d);
-  V c33 = hn::Zero(d);
-
-  const MatT* HWY_RESTRICT tile_a = A + stride_a * row_a;
-  const MatT* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
-
-  // Loop over columns of A and columns of the transposed B, in steps of N.
-  // Accumulates into the c## vectors.
-  HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
-    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
-    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
-    const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
-    const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);
-
-    const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
-    c00 = hn::MulAdd(a0, b0, c00);
-    c01 = hn::MulAdd(a0, b1, c01);
-    c02 = hn::MulAdd(a0, b2, c02);
-    c03 = hn::MulAdd(a0, b3, c03);
-    if constexpr (kNumRows == 1) continue;
-
-    const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
-    c10 = hn::MulAdd(a1, b0, c10);
-    c11 = hn::MulAdd(a1, b1, c11);
-    c12 = hn::MulAdd(a1, b2, c12);
-    c13 = hn::MulAdd(a1, b3, c13);
-    if constexpr (kNumRows == 2) continue;
-
-    const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
-    c20 = hn::MulAdd(a2, b0, c20);
-    c21 = hn::MulAdd(a2, b1, c21);
-    c22 = hn::MulAdd(a2, b2, c22);
-    c23 = hn::MulAdd(a2, b3, c23);
-    if constexpr (kNumRows == 3) continue;
-
-    const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
-    c30 = hn::MulAdd(a3, b0, c30);
-    c31 = hn::MulAdd(a3, b1, c31);
-    c32 = hn::MulAdd(a3, b2, c32);
-    c33 = hn::MulAdd(a3, b3, c33);
-  }
-
-  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(d, c00, c01, c02, c03, c10, c11, c12, c13,
-                                     c20, c21, c22, c23, c30, c31, c32, c33,
-                                     tile_add, tile_c, stride_c);
-  } else {
-    StoreHorizontalSums<kNumRows>(
-        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_c, stride_c);
-  }
-}
+// Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call
+// sites. If `!kAdd`, `add` is nullptr, so adding `add_offset` to it would be
+// UB, hence we pass it as a separate argument.
+template <bool kAdd, size_t kNumRows, class DF, class VF = hn::Vec<DF>>
+HWY_INLINE void StoreHorizontalSumsMaybeAdd(
+    DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
+    VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
+    const float* HWY_RESTRICT add, size_t add_offset,
+    float* HWY_RESTRICT tile_c, size_t stride_c) {
+  if constexpr (kAdd) {
+    StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
+                                     c20, c21, c22, c23, c30, c31, c32, c33,
+                                     add + add_offset, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
+                                  c20, c21, c22, c23, c30, c31, c32, c33,
+                                  tile_c, stride_c);
+  }
+}
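Note: the wrapper comment above is the reason add_offset travels as a separate argument. Callers pass `add` unconditionally, and when kAdd is false `add` is nullptr; forming `add + add_offset` with a nonzero offset on a null pointer is undefined behavior, so the arithmetic is confined to the branch that only exists when kAdd is true. A compressed illustration of the guarded pattern (hypothetical names, not from the diff):

#include <cstddef>

void UseBias(const float* biased) { (void)biased; }  // hypothetical consumer

template <bool kAdd>
void Caller(const float* add, size_t offset) {
  // Hazard: computing `add + offset` unconditionally would be UB whenever
  // add == nullptr (the !kAdd case).
  if constexpr (kAdd) {
    UseBias(add + offset);  // only compiled and executed when kAdd is true
  }
}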
@@ -290,12 +209,13 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
#define GEMMA_NATIVE_BF16 0
#endif

-// As above, for MatT=bf16
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatT,
-          HWY_IF_BF16(MatT), typename AddT>
-HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
-                              const MatT* HWY_RESTRICT B, float* HWY_RESTRICT C,
-                              const AddT* add,
+#if GEMMA_NATIVE_BF16
+
+// Specialization for f32 += bf16 * bf16 that avoids promoting to f32.
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd>
+HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
+                              const hwy::bfloat16_t* HWY_RESTRICT B,
+                              float* HWY_RESTRICT C, const float* add,
                              const size_t idx_tile, const size_t xtiles,
                              const size_t stride_a, const size_t stride_b,
                              const size_t stride_c) {
@@ -309,17 +229,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,

  const hn::ScalableTag<float> df;
  using VF = hn::Vec<decltype(df)>;
-#if GEMMA_NATIVE_BF16
  // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
  // bf16 vectors.
-  const hn::Repartition<MatT, decltype(df)> d;
+  const hn::Repartition<hwy::bfloat16_t, decltype(df)> d;
  VF unused_sum1 = hn::Zero(df);
-#else
-  // Emulated: use half-vectors of bf16 because we cannot afford two sums for
-  // each c##.
-  const hn::Rebind<MatT, decltype(df)> d;
-  HWY_DASSERT(Lanes(d) == Lanes(df));
-#endif

  const size_t N = Lanes(d);
@@ -343,14 +256,13 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
  VF c32 = hn::Zero(df);
  VF c33 = hn::Zero(df);

-  const MatT* HWY_RESTRICT tile_a = A + stride_a * row_a;
-  const MatT* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
+  const hwy::bfloat16_t* HWY_RESTRICT tile_a = A + stride_a * row_a;
+  const hwy::bfloat16_t* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;

  // Loop over columns of A and columns of the transposed B, in steps of N.
  // Accumulates into the c## vectors.
  HWY_UNROLL(1)
  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
-#if GEMMA_NATIVE_BF16
    using V = hn::Vec<decltype(d)>;
    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
@@ -383,89 +295,44 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
    c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
    c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
    c33 = hn::ReorderWidenMulAccumulate(df, a3, b3, c33, unused_sum1);
-#else   // Emulated: promote bf16 to f32
-    const VF b0 =
-        hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 0 + col_ab));
-    const VF b1 =
-        hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 1 + col_ab));
-    const VF b2 =
-        hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 2 + col_ab));
-    const VF b3 =
-        hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 3 + col_ab));
-
-    const VF a0 =
-        hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 0 + col_ab));
-    c00 = hn::MulAdd(a0, b0, c00);
-    c01 = hn::MulAdd(a0, b1, c01);
-    c02 = hn::MulAdd(a0, b2, c02);
-    c03 = hn::MulAdd(a0, b3, c03);
-    if constexpr (kNumRows == 1) continue;
-
-    const VF a1 =
-        hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab));
-    c10 = hn::MulAdd(a1, b0, c10);
-    c11 = hn::MulAdd(a1, b1, c11);
-    c12 = hn::MulAdd(a1, b2, c12);
-    c13 = hn::MulAdd(a1, b3, c13);
-    if constexpr (kNumRows == 2) continue;
-
-    const VF a2 =
-        hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab));
-    c20 = hn::MulAdd(a2, b0, c20);
-    c21 = hn::MulAdd(a2, b1, c21);
-    c22 = hn::MulAdd(a2, b2, c22);
-    c23 = hn::MulAdd(a2, b3, c23);
-    if constexpr (kNumRows == 3) continue;
-
-    const VF a3 =
-        hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab));
-    c30 = hn::MulAdd(a3, b0, c30);
-    c31 = hn::MulAdd(a3, b1, c31);
-    c32 = hn::MulAdd(a3, b2, c32);
-    c33 = hn::MulAdd(a3, b3, c33);
-#endif  // !GEMMA_NATIVE_BF16
  }

-#if GEMMA_NATIVE_BF16
  // Ensure sum1 was indeed unused.
  HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
-#endif

  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  if constexpr (kAdd) {
-    const AddT* dd = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
-                                     c20, c21, c22, c23, c30, c31, c32, c33, dd,
-                                     tile_c, stride_c);
-  } else {
-    StoreHorizontalSums<kNumRows>(
-        df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_c, stride_c);
-  }
+  StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
+      df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
+      c32, c33, add, row_b_col_c, tile_c, stride_c);
}

+#endif  // GEMMA_NATIVE_BF16
+
// The col_ab loop is unrolled 2x, so we have two consecutive a0/a1 and b00/b01
// etc. Multiplies a[c] with b[r,c] and adds to c[r].
template <class VF>
HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
                              const VF& b01, const VF& b10, const VF& b11,
                              const VF& b20, const VF& b21, const VF& b30,
                              const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
-  c0 = MulAdd(a0, b00, c0);
-  c0 = MulAdd(a1, b01, c0);
-  c1 = MulAdd(a0, b10, c1);
-  c1 = MulAdd(a1, b11, c1);
-  c2 = MulAdd(a0, b20, c2);
-  c2 = MulAdd(a1, b21, c2);
-  c3 = MulAdd(a0, b30, c3);
-  c3 = MulAdd(a1, b31, c3);
+  c0 = hn::MulAdd(a0, b00, c0);
+  c1 = hn::MulAdd(a0, b10, c1);
+  c2 = hn::MulAdd(a0, b20, c2);
+  c3 = hn::MulAdd(a0, b30, c3);
+  c0 = hn::MulAdd(a1, b01, c0);
+  c1 = hn::MulAdd(a1, b11, c1);
+  c2 = hn::MulAdd(a1, b21, c2);
+  c3 = hn::MulAdd(a1, b31, c3);
}

-// Same as above, for when there exists CompressTraits<MatTA>::Decompress2 and
-// MatTB is compressed.
+// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
+// can iterate over both A and B with consecutive vector loads. kNumRows<=4.
+// General case: uses CompressTraits to load from A and B.
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
-          typename MatTB, HWY_IF_T_SIZE(MatTB, 1), typename AddT>
+          typename MatTB>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                              const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const AddT* add,
+                              float* HWY_RESTRICT C, const float* add,
                              const size_t idx_tile, const size_t xtiles,
                              const size_t stride_a, const size_t stride_b,
                              const size_t stride_c) {
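Note: the GEMMA_NATIVE_BF16 path above feeds full bf16 vectors to ReorderWidenMulAccumulate, whereas the emulated path removed here promoted half-vectors of bf16 to f32 and used ordinary FMA. The promotion itself is cheap because bf16 is the upper half of an IEEE binary32; a scalar model of the widening (assumes the usual bf16 bit layout, not the Highway implementation):

#include <cstdint>
#include <cstring>

// Scalar model of bf16 -> f32: bf16 keeps the sign, exponent and top 7 mantissa
// bits of binary32, so widening is a 16-bit shift; low mantissa bits become 0.
float PromoteBF16(uint16_t bf16_bits) {
  const uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
  float f;
  std::memcpy(&f, &f32_bits, sizeof(f));
  return f;
}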
@@ -473,6 +340,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
  constexpr size_t kRegCols = 4;
  static_assert(kNumRows <= kRegRows);

+  using TraitsA = CompressTraits<MatTA>;
+  using TraitsB = CompressTraits<MatTB>;
+
  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
  const size_t row_a = idx_tile / xtiles * kRegRows;
  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
@@ -512,273 +382,42 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
  HWY_UNROLL(1)
  for (; col_ab <= kColsA_RowsB - 2 * N; col_ab += 2 * N) {
    V b00, b01;
-    CompressTraits<MatTB>::Decompress2(
-        d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01);
+    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01);
    V b10, b11;
-    CompressTraits<MatTB>::Decompress2(
-        d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11);
+    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11);
    V b20, b21;
-    CompressTraits<MatTB>::Decompress2(
-        d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21);
+    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21);
    V b30, b31;
-    CompressTraits<MatTB>::Decompress2(
-        d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31);
+    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31);

    V a00, a01;
-    CompressTraits<MatTA>::Decompress2(
-        d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01);
+    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01);
    UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
                  c02, c03);

    if constexpr (kNumRows == 1) continue;

    V a10, a11;
-    CompressTraits<MatTA>::Decompress2(
-        d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11);
+    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11);
    UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
                  c12, c13);

    if constexpr (kNumRows == 2) continue;

    V a20, a21;
-    CompressTraits<MatTA>::Decompress2(
-        d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21);
+    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21);
    UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
                  c22, c23);

    if constexpr (kNumRows == 3) continue;

    V a30, a31;
-    CompressTraits<MatTA>::Decompress2(
-        d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31);
+    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31);
    UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
                  c32, c33);
  }

  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  if constexpr (kAdd) {
-    const AddT* dd = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
-                                     c13, c20, c21, c22, c23, c30, c31, c32,
-                                     c33, dd, tile_c, stride_c);
-  } else {
-    StoreHorizontalSums<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12, c13,
-                                  c20, c21, c22, c23, c30, c31, c32, c33,
-                                  tile_c, stride_c);
-  }
-}
-
-// Same as above, but with mixed Mat types: (f32, bf16).
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
-          HWY_IF_F32(MatTA),
-          typename MatTB, HWY_IF_BF16(MatTB), typename AddT>
-HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
-                              const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const AddT* add,
-                              const size_t idx_tile,
-                              const size_t xtiles, const size_t stride_a,
-                              const size_t stride_b, const size_t stride_c) {
-  constexpr size_t kRegRows = 4;
-  constexpr size_t kRegCols = 4;
-  static_assert(kNumRows <= kRegRows);
-
-  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
-  const size_t row_a = idx_tile / xtiles * kRegRows;
-  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
-
-  const hn::ScalableTag<float> d32;
-  using VF = hn::Vec<decltype(d32)>;
-
-  // TODO: Using half-vectors for now, it might be faster to
-  // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
-  // accordingly.
-  const hn::Rebind<MatTB, decltype(d32)> d16;
-  HWY_DASSERT(Lanes(d16) == Lanes(d32));
-
-  const size_t N = Lanes(d16);
-
-  VF c00 = hn::Zero(d32);
-  VF c01 = hn::Zero(d32);
-  VF c02 = hn::Zero(d32);
-  VF c03 = hn::Zero(d32);
-
-  VF c10 = hn::Zero(d32);
-  VF c11 = hn::Zero(d32);
-  VF c12 = hn::Zero(d32);
-  VF c13 = hn::Zero(d32);
-
-  VF c20 = hn::Zero(d32);
-  VF c21 = hn::Zero(d32);
-  VF c22 = hn::Zero(d32);
-  VF c23 = hn::Zero(d32);
-
-  VF c30 = hn::Zero(d32);
-  VF c31 = hn::Zero(d32);
-  VF c32 = hn::Zero(d32);
-  VF c33 = hn::Zero(d32);
-
-  const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
-  const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
-
-  // Loop over columns of A and columns of the transposed B, in steps of N.
-  // Accumulates into the c## vectors.
-  HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
-    // Promote bf16 to f32
-    const VF b0 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 0 + col_ab));
-    const VF b1 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 1 + col_ab));
-    const VF b2 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 2 + col_ab));
-    const VF b3 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 3 + col_ab));
-
-    const VF a0 = hn::LoadU(d32, tile_a + stride_a * 0 + col_ab);
-    c00 = hn::MulAdd(a0, b0, c00);
-    c01 = hn::MulAdd(a0, b1, c01);
-    c02 = hn::MulAdd(a0, b2, c02);
-    c03 = hn::MulAdd(a0, b3, c03);
-    if constexpr (kNumRows == 1) continue;
-
-    const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
-    c10 = hn::MulAdd(a1, b0, c10);
-    c11 = hn::MulAdd(a1, b1, c11);
-    c12 = hn::MulAdd(a1, b2, c12);
-    c13 = hn::MulAdd(a1, b3, c13);
-    if constexpr (kNumRows == 2) continue;
-
-    const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
-    c20 = hn::MulAdd(a2, b0, c20);
-    c21 = hn::MulAdd(a2, b1, c21);
-    c22 = hn::MulAdd(a2, b2, c22);
-    c23 = hn::MulAdd(a2, b3, c23);
-    if constexpr (kNumRows == 3) continue;
-
-    const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
-    c30 = hn::MulAdd(a3, b0, c30);
-    c31 = hn::MulAdd(a3, b1, c31);
-    c32 = hn::MulAdd(a3, b2, c32);
-    c33 = hn::MulAdd(a3, b3, c33);
-  }
-
-  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  if constexpr (kAdd) {
-    const AddT* dd = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
-                                     c13, c20, c21, c22, c23, c30, c31, c32,
-                                     c33, dd, tile_c, stride_c);
-  } else {
-    StoreHorizontalSums<kNumRows>(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_c, stride_c);
-  }
-}
-
-// Same as above, but with mixed Mat types: (bf16, f32).
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
-          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB), typename AddT>
-HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
-                              const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const AddT* add,
-                              const size_t idx_tile,
-                              const size_t xtiles, const size_t stride_a,
-                              const size_t stride_b, const size_t stride_c) {
-  constexpr size_t kRegRows = 4;
-  constexpr size_t kRegCols = 4;
-  static_assert(kNumRows <= kRegRows);
-
-  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
-  const size_t row_a = idx_tile / xtiles * kRegRows;
-  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
-
-  const hn::ScalableTag<float> d32;
-  using VF = hn::Vec<decltype(d32)>;
-
-  // TODO: Using half-vectors for now, it might be faster to
-  // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
-  // accordingly.
-  const hn::Rebind<MatTA, decltype(d32)> d16;
-  HWY_DASSERT(Lanes(d16) == Lanes(d32));
-
-  const size_t N = Lanes(d16);
-
-  VF c00 = hn::Zero(d32);
-  VF c01 = hn::Zero(d32);
-  VF c02 = hn::Zero(d32);
-  VF c03 = hn::Zero(d32);
-
-  VF c10 = hn::Zero(d32);
-  VF c11 = hn::Zero(d32);
-  VF c12 = hn::Zero(d32);
-  VF c13 = hn::Zero(d32);
-
-  VF c20 = hn::Zero(d32);
-  VF c21 = hn::Zero(d32);
-  VF c22 = hn::Zero(d32);
-  VF c23 = hn::Zero(d32);
-
-  VF c30 = hn::Zero(d32);
-  VF c31 = hn::Zero(d32);
-  VF c32 = hn::Zero(d32);
-  VF c33 = hn::Zero(d32);
-
-  const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
-  const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
-
-  // Loop over columns of A and columns of the transposed B, in steps of N.
-  // Accumulates into the c## vectors.
-  HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
-    // Promote bf16 to f32
-    const VF b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab);
-    const VF b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab);
-    const VF b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab);
-    const VF b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab);
-
-    const VF a0 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab));
-    c00 = hn::MulAdd(a0, b0, c00);
-    c01 = hn::MulAdd(a0, b1, c01);
-    c02 = hn::MulAdd(a0, b2, c02);
-    c03 = hn::MulAdd(a0, b3, c03);
-    if constexpr (kNumRows == 1) continue;
-
-    const VF a1 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
-    c10 = hn::MulAdd(a1, b0, c10);
-    c11 = hn::MulAdd(a1, b1, c11);
-    c12 = hn::MulAdd(a1, b2, c12);
-    c13 = hn::MulAdd(a1, b3, c13);
-    if constexpr (kNumRows == 2) continue;
-
-    const VF a2 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
-    c20 = hn::MulAdd(a2, b0, c20);
-    c21 = hn::MulAdd(a2, b1, c21);
-    c22 = hn::MulAdd(a2, b2, c22);
-    c23 = hn::MulAdd(a2, b3, c23);
-    if constexpr (kNumRows == 3) continue;
-
-    const VF a3 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
-    c30 = hn::MulAdd(a3, b0, c30);
-    c31 = hn::MulAdd(a3, b1, c31);
-    c32 = hn::MulAdd(a3, b2, c32);
-    c33 = hn::MulAdd(a3, b3, c33);
-  }
-
-  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  if constexpr (kAdd) {
-    const AddT* dd = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
-                                     c13, c20, c21, c22, c23, c30, c31, c32,
-                                     c33, dd, tile_c, stride_c);
-  } else {
-    StoreHorizontalSums<kNumRows>(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_c, stride_c);
-  }
+  StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
+      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
+      c32, c33, add, row_b_col_c, tile_c, stride_c);
}

// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
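Note: in the remaining general overload, each loop iteration decompresses two vectors of A and of B (Decompress2) and UpdateTileRow folds both into the same accumulators, so one pass covers 2 * N columns. A scalar model of that pairing for a single accumulator (illustration only, not the vector code):

#include <cstddef>

// Each step consumes two elements of a and b; both products land in the same
// accumulator, mirroring the a0/b00 and a1/b01 pair handled by UpdateTileRow.
float DotUnrolled2(const float* a, const float* b, size_t n /* even */) {
  float c = 0.0f;
  for (size_t i = 0; i < n; i += 2) {
    c += a[i] * b[i];          // like MulAdd(a0, b00, c0)
    c += a[i + 1] * b[i + 1];  // like MulAdd(a1, b01, c0)
  }
  return c;
}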
@@ -842,45 +481,7 @@ HWY_NOINLINE void MatMul_4x4_Batch(
      batch_size, A, B, C, /*add=*/static_cast<OutT*>(nullptr), pool);
}

-// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
-// ops_test across instruction sets.
-template <size_t kN, size_t kK, typename MatTA, typename MatTB,
-          HWY_IF_T_SIZE_GT(MatTB, 1)>
-HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
-                                const MatTB* HWY_RESTRICT b,
-                                const float* add,
-                                float* HWY_RESTRICT out) {
-  for (size_t i = 0; i < batch_size; ++i) {
-    for (size_t k = 0; k < kN; ++k) {
-      for (size_t j = 0; j < kK; ++j) {
-        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
-        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
-        out[i * kK + j] += a1 * b1;
-      }
-    }
-    if (add != nullptr) {
-      for (size_t j = 0; j < kK; ++j) {
-        out[i * kK + j] += add[j];
-      }
-    }
-  }
-}
-
-// The above overload can handle combinations of f32 and bf16, but this one
-// is required for MatTB = {SFP, NUQ}.
-template <size_t kN, size_t kK, typename MatTA, typename MatTB,
-          HWY_IF_T_SIZE(MatTB, 1)>
-HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
-                                const MatTB* HWY_RESTRICT b_compr,
-                                const float* add,
-                                float* HWY_RESTRICT out) {
-  const hn::ScalableTag<float> d;
-  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
-  CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
-                                    kK * kN);
-  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), add, out);
-}
-
//------------------------------------------------------------------------------
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                             const size_t size, float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> df;
@@ -508,6 +508,43 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
  }
}

+// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
+// ops_test across instruction sets.
+template <size_t kN, size_t kK, typename MatTA, typename MatTB,
+          HWY_IF_T_SIZE_GT(MatTB, 1)>
+HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
+                                const MatTB* HWY_RESTRICT b, const float* add,
+                                float* HWY_RESTRICT out) {
+  for (size_t i = 0; i < batch_size; ++i) {
+    for (size_t k = 0; k < kN; ++k) {
+      for (size_t j = 0; j < kK; ++j) {
+        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
+        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
+        out[i * kK + j] += a1 * b1;
+      }
+    }
+    if (add != nullptr) {
+      for (size_t j = 0; j < kK; ++j) {
+        out[i * kK + j] += add[j];
+      }
+    }
+  }
+}
+
+// The above overload can handle combinations of f32 and bf16, but this one
+// is required for MatTB = {SFP, NUQ}.
+template <size_t kN, size_t kK, typename MatTA, typename MatTB,
+          HWY_IF_T_SIZE(MatTB, 1)>
+HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
+                                const MatTB* HWY_RESTRICT b_compr,
+                                const float* add, float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> d;
+  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
+  CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
+                                    kK * kN);
+  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), add, out);
+}
+
template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
          typename MatTB = MatTA>
void TestTiledBatchMatMul() {
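Note: with MatMulSlowBatch now beside the test, it serves as the scalar reference that the tiled kernel's output is compared against. A tiny standalone worked example of the computation it performs (plain f32, batch_size = 1, kN = kK = 2; illustrative values, not the test's actual setup):

#include <cstdio>

int main() {
  const float a[2] = {1.0f, 2.0f};              // 1x2 activation row
  const float b[4] = {3.0f, 4.0f, 5.0f, 6.0f};  // 2x2 weights, b[k * kK + j]
  const float add[2] = {0.5f, -0.5f};           // per-column bias
  float out[2] = {0.0f, 0.0f};
  for (int k = 0; k < 2; ++k) {
    for (int j = 0; j < 2; ++j) out[j] += a[k] * b[k * 2 + j];  // dot products
  }
  for (int j = 0; j < 2; ++j) out[j] += add[j];  // the `add != nullptr` branch
  std::printf("%g %g\n", out[0], out[1]);        // prints: 13.5 15.5
  return 0;
}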