mirror of https://github.com/google/gemma.cpp.git
Add bf16 matmul support, update naming+test
Avoid int32, which can easily overflow for large matrices. Also fix IDE warning in sfp-inl.

PiperOrigin-RevId: 640149845
This commit is contained in:
parent 25d9c8ff30
commit 4f9155d8c6
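The overflow concern in the commit message can be made concrete with a minimal standalone sketch (not from the commit; the dimensions and names below are hypothetical): with 32-bit signed index arithmetic, row * stride exceeds INT32_MAX once a matrix has a few billion elements, while size_t arithmetic stays well defined.

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical large-matrix shape: ~2.46e9 elements.
  const std::size_t rows = 100000, cols = 24576;
  const std::size_t last_row = rows - 1;
  // With int32_t: 99999 * 24576 = 2,457,575,424 > INT32_MAX, i.e. signed
  // overflow (undefined behavior). With size_t the offset is well defined.
  const std::size_t offset = last_row * cols;
  std::printf("offset of last row: %zu (INT32_MAX = %d)\n", offset, INT32_MAX);
  return 0;
}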
@@ -199,8 +199,11 @@ class SfpCodec {
  const hn::Vec<D> tblL1 = hn::LoadU(d, kTblL1);
  const hn::Vec<D> tblH0 = hn::LoadU(d, kTblH0);
  const hn::Vec<D> tblH1 = hn::LoadU(d, kTblH1);
  // AVX-512 ignores the index MSB, no need to clear.
#if HWY_IDE // only let the IDE see portable code.
  const auto idx = hn::IndicesFromVec(hn::AndNot(k80, encoded));
#else // AVX-512-specific: index MSB is ignored, no need to clear.
  const hn::Indices512<uint8_t> idx{encoded.raw};
#endif
  hi = hn::TwoTablesLookupLanes(d, tblH0, tblH1, idx);
  lo = hn::TwoTablesLookupLanes(d, tblL0, tblL1, idx);
  hi = hn::OrAnd(hi, encoded, k80); // Insert sign bit
gemma/ops.h
@@ -94,19 +94,58 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
  return kRowsPerStrip;
}

// Processes a single 4x4 tile of A x B. Shared between static and dynamic
// versions.
template <typename MatT, size_t kColsA>
// Shared between f32 and bf16, which also accumulates into f32 vectors.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
    VF c10, VF c11, VF c12, VF c13, //
    VF c20, VF c21, VF c22, VF c23, //
    VF c30, VF c31, VF c32, VF c33,
    float* HWY_RESTRICT tile_c,
    size_t stride_c) {
  // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
  // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
  // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
  // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
  tile_c[stride_c * 0 + 0] = hn::ReduceSum(df, c00);
  tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01);
  tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02);
  tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03);

  tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10);
  tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11);
  tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12);
  tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13);

  tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20);
  tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21);
  tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22);
  tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23);

  tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30);
  tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31);
  tile_c[stride_c * 3 + 2] = hn::ReduceSum(df, c32);
  tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
}

// Accumulates a single 4x4 tile of A x B into C. B is transposed, so we can
// iterate over both A and B with consecutive vector loads.
// Shared between parallelized and sequential (loop) callers.
template <size_t kColsA_RowsB, typename MatT, HWY_IF_F32(MatT)>
HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
    const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
    size_t tile_num, const int xtiles, const int lda,
    const int ldb, const int ldc) {
  constexpr int RM = 4;
  constexpr int RN = 4;
    const size_t idx_tile, const size_t xtiles,
    const size_t stride_a, const size_t stride_b,
    const size_t stride_c) {
  // Tile size. 4x unrolling makes sense on many platforms because we can fit
  // 4x4 accumulators and 8 temporaries in the 32 vectors; we have more than
  // #FMA units * FMA latency (up to 2*5) independent computations in flight;
  // threads write in units of 4*N elements, which is at least one cache line.
  constexpr size_t kRegRows = 4;
  constexpr size_t kRegCols = 4;

  // Calculate chunk start coords.
  int ii = tile_num / xtiles * RM;
  int jj = tile_num % xtiles * RN;
  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
  const size_t row_a = idx_tile / xtiles * kRegRows;
  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;

  const hn::ScalableTag<MatT> d;
  const size_t N = Lanes(d);
@@ -132,146 +171,269 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
  V c32 = hn::Zero(d);
  V c33 = hn::Zero(d);

  // Steps down the rows of A and B, and across width (kN) in steps of
  // N (Lanes()). Accumulates into the cache vectors. hn::ReduceSum() is
  // called on each of the cache vectors to sum the partial sums into C.
  for (size_t l = 0; l < kColsA; l += N) {
    V k0 = hn::LoadU(d, B + ldb * (jj + 0) + l);
    V k1 = hn::LoadU(d, B + ldb * (jj + 1) + l);
    V k2 = hn::LoadU(d, B + ldb * (jj + 2) + l);
    V k3 = hn::LoadU(d, B + ldb * (jj + 3) + l);
  const MatT* HWY_RESTRICT tile_a = A + stride_a * row_a;
  const MatT* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;

    V a0 = hn::LoadU(d, A + lda * (ii + 0) + l);
    c00 = hn::MulAdd(a0, k0, c00);
    c01 = hn::MulAdd(a0, k1, c01);
    c02 = hn::MulAdd(a0, k2, c02);
    c03 = hn::MulAdd(a0, k3, c03);
  // Loop over columns of A and columns of the transposed B, in steps of N.
  // Accumulates into the c## vectors.
  HWY_UNROLL(1)
  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
    const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
    const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);

    V a1 = hn::LoadU(d, A + lda * (ii + 1) + l);
    c10 = hn::MulAdd(a1, k0, c10);
    c11 = hn::MulAdd(a1, k1, c11);
    c12 = hn::MulAdd(a1, k2, c12);
    c13 = hn::MulAdd(a1, k3, c13);
    const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
    c00 = hn::MulAdd(a0, b0, c00);
    c01 = hn::MulAdd(a0, b1, c01);
    c02 = hn::MulAdd(a0, b2, c02);
    c03 = hn::MulAdd(a0, b3, c03);

    V a2 = hn::LoadU(d, A + lda * (ii + 2) + l);
    c20 = hn::MulAdd(a2, k0, c20);
    c21 = hn::MulAdd(a2, k1, c21);
    c22 = hn::MulAdd(a2, k2, c22);
    c23 = hn::MulAdd(a2, k3, c23);
    const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
    c10 = hn::MulAdd(a1, b0, c10);
    c11 = hn::MulAdd(a1, b1, c11);
    c12 = hn::MulAdd(a1, b2, c12);
    c13 = hn::MulAdd(a1, b3, c13);

    V a3 = hn::LoadU(d, A + lda * (ii + 3) + l);
    c30 = hn::MulAdd(a3, k0, c30);
    c31 = hn::MulAdd(a3, k1, c31);
    c32 = hn::MulAdd(a3, k2, c32);
    c33 = hn::MulAdd(a3, k3, c33);
    const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
    c20 = hn::MulAdd(a2, b0, c20);
    c21 = hn::MulAdd(a2, b1, c21);
    c22 = hn::MulAdd(a2, b2, c22);
    c23 = hn::MulAdd(a2, b3, c23);

    const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
    c30 = hn::MulAdd(a3, b0, c30);
    c31 = hn::MulAdd(a3, b1, c31);
    c32 = hn::MulAdd(a3, b2, c32);
    c33 = hn::MulAdd(a3, b3, c33);
  }
  C[ldc * (ii + 0) + (jj + 0)] = hn::ReduceSum(d, c00);
  C[ldc * (ii + 0) + (jj + 1)] = hn::ReduceSum(d, c01);
  C[ldc * (ii + 0) + (jj + 2)] = hn::ReduceSum(d, c02);
  C[ldc * (ii + 0) + (jj + 3)] = hn::ReduceSum(d, c03);

  C[ldc * (ii + 1) + (jj + 0)] = hn::ReduceSum(d, c10);
  C[ldc * (ii + 1) + (jj + 1)] = hn::ReduceSum(d, c11);
  C[ldc * (ii + 1) + (jj + 2)] = hn::ReduceSum(d, c12);
  C[ldc * (ii + 1) + (jj + 3)] = hn::ReduceSum(d, c13);

  C[ldc * (ii + 2) + (jj + 0)] = hn::ReduceSum(d, c20);
  C[ldc * (ii + 2) + (jj + 1)] = hn::ReduceSum(d, c21);
  C[ldc * (ii + 2) + (jj + 2)] = hn::ReduceSum(d, c22);
  C[ldc * (ii + 2) + (jj + 3)] = hn::ReduceSum(d, c23);

  C[ldc * (ii + 3) + (jj + 0)] = hn::ReduceSum(d, c30);
  C[ldc * (ii + 3) + (jj + 1)] = hn::ReduceSum(d, c31);
  C[ldc * (ii + 3) + (jj + 2)] = hn::ReduceSum(d, c32);
  C[ldc * (ii + 3) + (jj + 3)] = hn::ReduceSum(d, c33);
  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
  StoreHorizontalSums(d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
      c23, c30, c31, c32, c33, tile_c, stride_c);
}

// Tiled 4x4 GEMM. Covers primary M =4..512, k = 3k/24k, n = 24k/3k use case.
// This version uses tiling suitable for static scheduling.
// Note: expects transposed / shuffled B.
template <size_t kM, size_t kColsA, size_t kK, typename MatT>
#undef GEMMA_NATIVE_BF16
#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
                defined(HWY_TARGET_TOGGLE))
#define GEMMA_NATIVE_BF16 1
#else
#define GEMMA_NATIVE_BF16 0
#endif

// As above, for MatT=bf16
template <size_t kColsA_RowsB, typename MatT, HWY_IF_BF16(MatT)>
HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
    const MatT* HWY_RESTRICT B, float* HWY_RESTRICT C,
    const size_t idx_tile, const size_t xtiles,
    const size_t stride_a, const size_t stride_b,
    const size_t stride_c) {
  constexpr size_t kRegRows = 4;
  constexpr size_t kRegCols = 4;

  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
  const size_t row_a = idx_tile / xtiles * kRegRows;
  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;

  const hn::ScalableTag<float> df;
  using VF = hn::Vec<decltype(df)>;
#if GEMMA_NATIVE_BF16
  // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
  // bf16 vectors.
  const hn::Repartition<MatT, decltype(df)> d;
  VF unused_sum1 = hn::Zero(df);
#else
  // Emulated: use half-vectors of bf16 because we cannot afford two sums for
  // each c##.
  const hn::Rebind<MatT, decltype(df)> d;
  HWY_DASSERT(Lanes(d) == Lanes(df));
#endif

  const size_t N = Lanes(d);
  VF c00 = hn::Zero(df);
  VF c01 = hn::Zero(df);
  VF c02 = hn::Zero(df);
  VF c03 = hn::Zero(df);

  VF c10 = hn::Zero(df);
  VF c11 = hn::Zero(df);
  VF c12 = hn::Zero(df);
  VF c13 = hn::Zero(df);

  VF c20 = hn::Zero(df);
  VF c21 = hn::Zero(df);
  VF c22 = hn::Zero(df);
  VF c23 = hn::Zero(df);

  VF c30 = hn::Zero(df);
  VF c31 = hn::Zero(df);
  VF c32 = hn::Zero(df);
  VF c33 = hn::Zero(df);

  const MatT* HWY_RESTRICT tile_a = A + stride_a * row_a;
  const MatT* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;

  // Loop over columns of A and columns of the transposed B, in steps of N.
  // Accumulates into the c## vectors.
  HWY_UNROLL(1)
  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
#if GEMMA_NATIVE_BF16
    using V = hn::Vec<decltype(d)>;
    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
    const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
    const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);

    const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
    c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1);
    c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
    c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
    c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);

    const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
    c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
    c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
    c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
    c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);

    const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
    c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
    c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
    c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
    c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);

    const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
    c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
    c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
    c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
    c33 = hn::ReorderWidenMulAccumulate(df, a3, b3, c33, unused_sum1);
#else // Emulated: promote bf16 to f32
    const VF b0 =
        hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 0 + col_ab));
    const VF b1 =
        hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 1 + col_ab));
    const VF b2 =
        hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 2 + col_ab));
    const VF b3 =
        hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 3 + col_ab));

    const VF a0 =
        hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 0 + col_ab));
    c00 = hn::MulAdd(a0, b0, c00);
    c01 = hn::MulAdd(a0, b1, c01);
    c02 = hn::MulAdd(a0, b2, c02);
    c03 = hn::MulAdd(a0, b3, c03);

    const VF a1 =
        hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab));
    c10 = hn::MulAdd(a1, b0, c10);
    c11 = hn::MulAdd(a1, b1, c11);
    c12 = hn::MulAdd(a1, b2, c12);
    c13 = hn::MulAdd(a1, b3, c13);

    const VF a2 =
        hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab));
    c20 = hn::MulAdd(a2, b0, c20);
    c21 = hn::MulAdd(a2, b1, c21);
    c22 = hn::MulAdd(a2, b2, c22);
    c23 = hn::MulAdd(a2, b3, c23);

    const VF a3 =
        hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab));
    c30 = hn::MulAdd(a3, b0, c30);
    c31 = hn::MulAdd(a3, b1, c31);
    c32 = hn::MulAdd(a3, b2, c32);
    c33 = hn::MulAdd(a3, b3, c33);
#endif // !GEMMA_NATIVE_BF16
  }

#if GEMMA_NATIVE_BF16
  // Ensure sum1 was indeed unused.
  HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
#endif

  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
  StoreHorizontalSums(df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
      c23, c30, c31, c32, c33, tile_c, stride_c);
}

// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
// This function loops over all tiles (static scheduling). TODO(janwas): we can
// possibly remove this if ThreadPool(0) is as efficient as the loop.
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatT>
void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
    MatT* HWY_RESTRICT C) {
  const hn::ScalableTag<MatT> d;
  const size_t N = hn::Lanes(d); // column step size
  constexpr int RM = 4; // tile height
  constexpr int RN = 4; // tile width
  constexpr size_t kRegRows = 4;
  constexpr size_t kRegCols = 4; // in vectors

  static_assert(kM % RM == 0);
  static_assert(kColsA % N == 0);
  static_assert(kColsA % RN == 0);
  static_assert(kRowsAC % kRegRows == 0);
  static_assert(kColsA_RowsB % (N * kRegCols) == 0);
  static_assert(kColsBC % kRegCols == 0);
  constexpr size_t kTilesY = kRowsAC / kRegRows;
  constexpr size_t kTilesX = kColsBC / kRegCols;
  constexpr size_t kTiles = kTilesX * kTilesY;

  int lda = kColsA;
  int ldb = kColsA; // n instead of k because we're transposing
  int ldc = kK;
  constexpr size_t kStrideA = kColsA_RowsB;
  constexpr size_t kStrideB = kColsA_RowsB; // B is column-major
  constexpr size_t kStrideC = kColsBC;

  int ytiles = (kM) / RM;
  int xtiles = (kK) / RN; // k instead of n because we're transposing
  int tiles = xtiles * ytiles;

  for (int job = 0; job < tiles; ++job) {
    GEMM_4x4_Tile<MatT, kColsA>(A, B, C, job, xtiles, lda, ldb, ldc);
  HWY_UNROLL(1)
  for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
    GEMM_4x4_Tile<kColsA_RowsB>(A, B, C, idx_tile, kTilesX, kStrideA, kStrideB,
        kStrideC);
  }
}

// Tiled 4x4 GEMM. Covers primary M =4..512, k = 3k/24k, n = 24k/3k use case.
// This version uses tiling and pooled threads.
// Note: expects transposed / shuffled B.
template <size_t kM, size_t kColsA, size_t kK, typename MatT>
HWY_NOINLINE void MatMul_4x4_Impl(const MatT* HWY_RESTRICT A,
    const MatT* HWY_RESTRICT B,
    MatT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
  // Process 4x4 chunks of C in parallel. Each pool thread handles a single A x
  // B tile. Note that C is being addressed directly without a buffer, and that
  // the cache vectors (c00, c01, etc.) are being summed directly into C. There
  // may be additional stability / speed gains to be made by using a buffer.
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
// This function processes tiles in parallel with a work-stealing thread pool.
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatT,
    typename OutT>
HWY_NOINLINE void MatMul_4x4(const MatT* HWY_RESTRICT A,
    const MatT* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
    hwy::ThreadPool& pool) {
  // Process reg-sized tiles of C in parallel. We currently write C directly,
  // which touches more memory than fits in L3. TODO: add another level of loops
  // so that we finish one L3-sized piece of C at a time.
  const hn::ScalableTag<MatT> d;
  const size_t N = Lanes(d);
  constexpr size_t kRegRows = 4;
  constexpr size_t kRegCols = 4; // in vectors

  const int lda = kColsA;
  const int ldb = kColsA; // n instead of k because we're transposing
  const int ldc = kK;
  static_assert(kRowsAC % kRegRows == 0);
  static_assert(kColsA_RowsB % (N * kRegCols) == 0);
  static_assert(kColsBC % kRegCols == 0);
  const size_t kTilesY = kRowsAC / kRegRows;
  const size_t kTilesX = kColsBC / kRegCols;
  const size_t kTiles = kTilesX * kTilesY;

  // 4x4
  const int RM = 4;
  const int RN = 4;
  constexpr size_t kStrideA = kColsA_RowsB;
  constexpr size_t kStrideB = kColsA_RowsB;
  constexpr size_t kStrideC = kColsBC;

  const int ytiles = (kM) / RM;
  const int xtiles = (kK) / RN; // k instead of n because we're transposing
  const int tiles = xtiles * ytiles;
  // 4x4 case requires kM % 4 == 0, kN % N == 0, kK % 4 == 0
  static_assert(kM % RM == 0);
  static_assert(kColsA % N == 0);
  static_assert(kColsA % RN == 0);
  static_assert(kK % RN == 0);
  static_assert(kColsA >= N);

  // Handles a single 4x4 chunk, which is completed and then written into C.
  pool.Run(0, tiles, [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
    GEMM_4x4_Tile<MatT, kColsA>(A, B, C, chunk, xtiles, lda, ldb, ldc);
  pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
    // Computes the finished product of one 4x4N tile and writes to C.
    GEMM_4x4_Tile<kColsA_RowsB>(A, B, C, idx_tile, kTilesX, kStrideA, kStrideB,
        kStrideC);
  });
}

// Requires m % 4 == 0, n % Lanes() == 0, k % 4 == 0
template <size_t kM, size_t kN, size_t kK, typename MatT>
HWY_INLINE void MatMul_4x4(const MatT* HWY_RESTRICT a,
    const MatT* HWY_RESTRICT b, MatT* HWY_RESTRICT out,
    hwy::ThreadPool& pool) {
  return MatMul_4x4_Impl<kM, kN, kK, MatT>(a, b, out, pool);
}

// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kM, size_t kN, size_t kK, typename MatT>
HWY_INLINE void MatMul(const MatT* HWY_RESTRICT a, const MatT* HWY_RESTRICT b,
    MatT* HWY_RESTRICT out) {
  int i, j, k;
  for (i = 0; i < kM; ++i) {
    for (k = 0; k < kN; ++k) {
      for (j = 0; j < kK; ++j) {
        out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
HWY_INLINE void MatMulSlow(const MatT* HWY_RESTRICT a,
    const MatT* HWY_RESTRICT b,
    float* HWY_RESTRICT out) {
  for (size_t i = 0; i < kM; ++i) {
    for (size_t k = 0; k < kN; ++k) {
      for (size_t j = 0; j < kK; ++j) {
        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
        out[i * kK + j] += a1 * b1;
      }
    }
  }
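As a reading aid for the tile-index arithmetic in GEMM_4x4_Tile, here is a minimal standalone sketch (not part of the commit; the dimensions below are hypothetical) of how idx_tile maps to the top-left coordinates (row_a, row_b_col_c) of each 4x4 tile of C, matching the row-major tile numbering used by the static and pooled callers.

#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t kRegRows = 4, kRegCols = 4;
  const std::size_t rows_c = 12, cols_c = 16;    // hypothetical C dimensions
  const std::size_t ytiles = rows_c / kRegRows;  // 3
  const std::size_t xtiles = cols_c / kRegCols;  // 4
  for (std::size_t idx_tile = 0; idx_tile < xtiles * ytiles; ++idx_tile) {
    // Same mapping as in the diff: tile row in A/C, tile row in B^T = column in C.
    const std::size_t row_a = idx_tile / xtiles * kRegRows;
    const std::size_t row_b_col_c = idx_tile % xtiles * kRegCols;
    std::printf("tile %2zu -> C[%2zu..%2zu, %2zu..%2zu]\n", idx_tile, row_a,
                row_a + kRegRows - 1, row_b_col_c, row_b_col_c + kRegCols - 1);
  }
  return 0;
}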
@@ -21,6 +21,7 @@
#include <algorithm>
#include <array>
#include <cmath>
#include <random>
#include <vector>

@@ -341,70 +342,52 @@ void TestAllCreateDistribution() {
  TestCreateDistribution<5000>();
}

template <size_t kOuter, size_t kInner>
CompressedArray<float, kOuter * kInner> GenerateMat(size_t offset) {
  hwy::ThreadPool pool(0);
template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
    hwy::ThreadPool& pool) {
  gcpp::CompressWorkingSet ws;
  CompressedArray<float, kOuter * kInner> mat;
  CompressedArray<MatT, kOuter * kInner> mat;
  std::array<float, kOuter * kInner> content;
  const float scale = 1.0f / kInner;
  for (size_t i = 0; i < kOuter; i++) {
  pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
    for (size_t j = 0; j < kInner; j++) {
      content[i * kInner + j] =
          static_cast<float>((i * kInner + j + offset) * scale);
    }
  }

  // for (size_t i = 0; i < kOuter; i++) {
  //   for (size_t j = 0; j < kInner; j++) {
  //     fprintf(stderr, "content[%lu] = %f\n", i * kInner + j,
  //             content[i * kInner + j]);
  //   }
  // }
  });

  Compress(content, ws, mat, pool);
  mat.set_scale(1.0f);
  return mat;
}

template <size_t kOuter, size_t kInner>
CompressedArray<float, kOuter * kInner> GenerateTransposeMat(size_t offset) {
  hwy::ThreadPool pool(0);
template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateTransposeMat(
    size_t offset, hwy::ThreadPool& pool) {
  gcpp::CompressWorkingSet ws;
  CompressedArray<float, kOuter * kInner> mat;
  CompressedArray<MatT, kOuter * kInner> mat;
  std::array<float, kOuter * kInner> content;
  const float scale = 1.0f / kInner;
  for (size_t i = 0; i < kOuter; i++) {
  pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
    for (size_t j = 0; j < kInner; j++) {
      content[j * kOuter + i] =
          static_cast<float>((i * kInner + j + offset) * scale);
      content[i * kInner + j] =
          static_cast<float>((j * kInner + i + offset) * scale);
    }
  }

  // for (size_t i = 0; i < kOuter; i++) {
  //   for (size_t j = 0; j < kInner; j++) {
  //     fprintf(stderr, "content[%lu] = %f (transpose)\n", i * kInner + j,
  //             content[i * kInner + j]);
  //   }
  // }
  });

  Compress(content, ws, mat, pool);
  mat.set_scale(1.0f);
  return mat;
}

template <size_t kOuter, size_t kInner>
CompressedArray<float, kOuter * kInner> GenerateZeroMat(size_t offset) {
  hwy::ThreadPool pool(static_cast<size_t>(std::clamp(
      static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 4)));
template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateZeroMat(hwy::ThreadPool& pool) {
  gcpp::CompressWorkingSet ws;
  CompressedArray<float, kOuter * kInner> mat;
  std::array<float, kOuter * kInner> content;
  CompressedArray<MatT, kOuter * kInner> mat;
  std::array<MatT, kOuter * kInner> content;

  pool.Run(0, kOuter, [&](const size_t i, size_t thread) {
    for (size_t j = 0; j < kInner; j++) {
      content[i * kInner + j] = 0.0f;
    }
    hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0]));
  });

  Compress(content, ws, mat, pool);
@@ -474,8 +457,8 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
}

template <typename MatT>
void AssertClose(const hwy::AlignedFreeUniquePtr<MatT[]>& expected,
    const hwy::AlignedFreeUniquePtr<MatT[]>& actual, size_t num) {
void AssertClose(const MatT* HWY_RESTRICT expected,
    const MatT* HWY_RESTRICT actual, size_t num) {
  for (size_t idx = 0; idx < num; idx++) {
    double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
    double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
@@ -492,69 +475,38 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<MatT[]>& expected,
  }
}

template <typename MatT>
void TestTiledMatMul() {
  hwy::ThreadPool pool(0);
  hwy::ThreadPool pool(3);
  constexpr size_t kM = 512; // 384
  constexpr size_t kN = 512; // * 5; // 6; // 768
  constexpr size_t kK = 512; // * 5; // 640

  CompressedArray<float, kM * kN> a1 = GenerateMat<kM, kN>(0);
  CompressedArray<float, kN * kK> b1 = GenerateMat<kN, kK>(0);
  CompressedArray<MatT, kM * kN> a = GenerateMat<MatT, kM, kN>(0, pool);
  CompressedArray<MatT, kN * kK> b = GenerateMat<MatT, kN, kK>(0, pool);
  CompressedArray<float, kN * kK> c_slow = GenerateZeroMat<float, kM, kK>(pool);
  MatMulSlow<kM, kN, kK>(a.data(), b.data(), c_slow.data());

  hwy::AlignedFreeUniquePtr<float[]> a = hwy::AllocateAligned<float>(kM * kN);
  Decompress(a1, 0, a.get(), kM * kN);

  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kN * kK);
  Decompress(b1, 0, b.get(), kN * kK);

  hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
      SimpleMatMul<kM, kN, kK>(a, b);

  CompressedArray<float, kM * kK> compressed_c = GenerateZeroMat<kM, kK>(0);
  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
  Decompress(compressed_c, 0, c.get(), kM * kK);
  CompressedArray<MatT, kN * kK> b_trans =
      GenerateTransposeMat<MatT, kN, kK>(0, pool);
  MatMul_4x4<kM, kN, kK>(a.data(), b_trans.data(), c.get(), pool);

  CompressedArray<float, kN * kK> b1_trans = GenerateTransposeMat<kN, kK>(0);
  hwy::AlignedFreeUniquePtr<float[]> b_trans =
      hwy::AllocateAligned<float>(kN * kK);
  Decompress(b1_trans, 0, b_trans.get(), kN * kK);
  MatMul_4x4<kM, kN, kK>(a.get(), b_trans.get(), c.get(), pool);

  AssertClose(expected_out1, c, kM * kK);
  AssertClose(c_slow.data(), c.get(), kM * kK);
}

void TestMatMul() {
  constexpr size_t kM = 512; // 384
  constexpr size_t kN = 512; // * 5; // 6; // 768
  constexpr size_t kK = 512; // * 5; // 640

  CompressedArray<float, kM * kN> a1 = GenerateMat<kM, kN>(0);
  CompressedArray<float, kN * kK> b1 = GenerateMat<kN, kK>(0);

  hwy::AlignedFreeUniquePtr<float[]> a = hwy::AllocateAligned<float>(kM * kN);
  Decompress(a1, 0, a.get(), kM * kN);

  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kN * kK);
  Decompress(b1, 0, b.get(), kN * kK);

  hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
      SimpleMatMul<kM, kN, kK>(a, b);

  CompressedArray<float, kM * kK> compressed_c = GenerateZeroMat<kM, kK>(0);
  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
  Decompress(compressed_c, 0, c.get(), kM * kK);

  Decompress(b1, 0, b.get(), kN * kK);
  MatMul<kM, kN, kK>(a.get(), b.get(), c.get());

  AssertClose(expected_out1, c, kM * kK);
void TestAllTiledMatMul() {
  TestTiledMatMul<float>();
  TestTiledMatMul<hwy::bfloat16_t>();
  // TODO(janwas): SFP
}

void TestMatVecAdd() {
  hwy::ThreadPool pool(0);
  constexpr size_t kOuter = 128 * 3;
  constexpr size_t kInner = 128 * 5;
  CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
  CompressedArray<float, kOuter * kInner> mat =
      GenerateMat<float, kOuter, kInner>(0, pool);
  hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
  hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
  hwy::AlignedFreeUniquePtr<float[]> even_odd =
@@ -573,8 +525,10 @@ void TestTwoMatVecAdd() {
  hwy::ThreadPool pool(0);
  constexpr size_t kOuter = 128 * 3;
  constexpr size_t kInner = 128 * 5;
  CompressedArray<float, kOuter * kInner> mat0 = GenerateMat<kOuter, kInner>(0);
  CompressedArray<float, kOuter * kInner> mat1 = GenerateMat<kOuter, kInner>(1);
  CompressedArray<float, kOuter * kInner> mat0 =
      GenerateMat<float, kOuter, kInner>(0, pool);
  CompressedArray<float, kOuter * kInner> mat1 =
      GenerateMat<float, kOuter, kInner>(1, pool);
  hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
  hwy::AlignedFreeUniquePtr<float[]> add0 = GenerateVec<kOuter>(0);
  hwy::AlignedFreeUniquePtr<float[]> add1 = GenerateVec<kOuter>(1);
@@ -595,9 +549,11 @@ void TestTwoMatVecAdd() {
}

void TestTwoOfsMatVecAddLoop() {
  hwy::ThreadPool pool(0);
  constexpr size_t kOuter = 128 * 3;
  constexpr size_t kInner = 128 * 5;
  CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
  CompressedArray<float, kOuter * kInner> mat =
      GenerateMat<float, kOuter, kInner>(0, pool);
  hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
  hwy::AlignedFreeUniquePtr<float[]> add0 = GenerateVec<kOuter>(0);
  hwy::AlignedFreeUniquePtr<float[]> add1 = GenerateVec<kOuter>(1);
@@ -650,8 +606,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTiledMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);