diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
index 1707bad..2c196cf 100644
--- a/backprop/optimize_test.cc
+++ b/backprop/optimize_test.cc
@@ -15,7 +15,6 @@
 #include <stdio.h>
 
-#include <limits>
 #include <random>
 #include <vector>
 
@@ -112,7 +111,6 @@ TEST(OptimizeTest, GradientDescent) {
   ReverseSequenceSampler training_task({
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
   size_t steps = 0;
-  float prev_loss = std::numeric_limits<float>::max();
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
@@ -143,7 +141,6 @@ TEST(OptimizeTest, GradientDescent) {
     if (total_loss < 0.5f) {
       break;
     }
-    prev_loss = total_loss;
   }
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
diff --git a/compression/compress.h b/compression/compress.h
index 130e1ad..23340dc 100644
--- a/compression/compress.h
+++ b/compression/compress.h
@@ -41,8 +41,10 @@ namespace gcpp {
 
+using BF16 = hwy::bfloat16_t;
+
 static inline const char* TypeName(float) { return "f32"; }
-static inline const char* TypeName(hwy::bfloat16_t) { return "b16"; }
+static inline const char* TypeName(BF16) { return "b16"; }
 
 namespace detail {
 
 // How many MatT are required to store `capacity` weights. For all but
@@ -177,11 +179,11 @@ struct CompressWorkingSet {
 
 template <typename MatT>
 hwy::uint128_t CacheKey(const char* name) {
   // Already used/retired: s, S, n, 1
-  const char prefix = hwy::IsSame<MatT, float>()             ? 'F'
-                      : hwy::IsSame<MatT, hwy::bfloat16_t>() ? 'B'
-                      : hwy::IsSame<MatT, SfpStream>()       ? '$'
-                      : hwy::IsSame<MatT, NuqStream>()       ? '2'
-                                                             : '?';
+  const char prefix = hwy::IsSame<MatT, float>()       ? 'F'
+                      : hwy::IsSame<MatT, BF16>()      ? 'B'
+                      : hwy::IsSame<MatT, SfpStream>() ? '$'
+                      : hwy::IsSame<MatT, NuqStream>() ? '2'
+                                                       : '?';
   return MakeKey((std::string(1, prefix) + name).c_str());
 }
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index de7a37b..5d1d435 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -23,7 +23,7 @@
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/profiler.h"
+#include "hwy/profiler.h"  // temporarily disabled
 
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
 
@@ -43,107 +43,65 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 
 namespace hn = hwy::HWY_NAMESPACE;
 
+// A square kernel minimizes the ratio of loads to FMA. 4x 128-bit corresponds
+// to one cache line.
+constexpr size_t kRegRows = 4;
+constexpr size_t kRegCols = 4;
+
+// Initializes a reg-tile of C: if kAdd, `add[add_ofs + c]`; otherwise 0.
+// `add` has no scale and, if `kAdd`, is a row vector with C.cols entries;
+// otherwise it is nullptr. In the latter case, adding `add_ofs` to it would be
+// UB, hence we pass the offset as a separate argument.
+template <size_t kNumRows, bool kAdd>
+HWY_INLINE void InitC(const float* HWY_RESTRICT add, size_t add_ofs,
+                      float* HWY_RESTRICT pos_c, size_t stride_c) {
+  for (size_t r = 0; r < HWY_MIN(kNumRows, kRegRows); ++r) {
+    for (size_t c = 0; c < kRegCols; ++c) {
+      if constexpr (kAdd) {
+        pos_c[r * stride_c + c] = add[add_ofs + c];
+      } else {
+        pos_c[r * stride_c + c] = 0.0f;
+      }
+    }
+  }
+}
+
 // c## are partial sums of the products of A and B; their horizontal sums are
 // the final matmul result, stored in C, which is always f32.
 template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSums(DF df,  //
-                                    VF c00, VF c01, VF c02, VF c03,  //
-                                    VF c10, VF c11, VF c12, VF c13,  //
-                                    VF c20, VF c21, VF c22, VF c23,  //
-                                    VF c30, VF c31, VF c32, VF c33,  //
-                                    float scale, float* HWY_RESTRICT tile_c,
-                                    size_t stride_c) {
+HWY_INLINE void AddHorizontalSums(DF df, float scale,  //
+                                  VF c00, VF c01, VF c02, VF c03,  //
+                                  VF c10, VF c11, VF c12, VF c13,  //
+                                  VF c20, VF c21, VF c22, VF c23,  //
+                                  VF c30, VF c31, VF c32, VF c33,  //
+                                  float* HWY_RESTRICT tile_c, size_t stride_c) {
   // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
   // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
   // expensive, but only a fraction of the A.cols/N FMAs.
-  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00);
-  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01);
-  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02);
-  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03);
-  if (kNumRows == 1) return;
-
-  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10);
-  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11);
-  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12);
-  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13);
-  if (kNumRows == 2) return;
-
-  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20);
-  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21);
-  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22);
-  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23);
-  if (kNumRows == 3) return;
-
-  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30);
-  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31);
-  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32);
-  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33);
-}
-
-// As above, but also adds `add[0..3]` to columns 0..3 of `tile_c`. `add` has no
-// scale, and points to a 1D slice of the row vector.
-template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSumsAdd(DF df,  //
-                                       VF c00, VF c01, VF c02, VF c03,  //
-                                       VF c10, VF c11, VF c12, VF c13,  //
-                                       VF c20, VF c21, VF c22, VF c23,  //
-                                       VF c30, VF c31, VF c32, VF c33,
-                                       const float scale,
-                                       const float* HWY_RESTRICT add,
-                                       float* HWY_RESTRICT tile_c,
-                                       size_t stride_c) {
-  // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
-  // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
-  // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
-  // expensive, but only a fraction of the A.cols/N FMAs.
-  const float add0 = add[0];  // TODO: 4x4 transpose, then 128-bit vector FMA?
-  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + add0;
-  const float add1 = add[1];
-  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01) + add1;
-  const float add2 = add[2];
-  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02) + add2;
-  const float add3 = add[3];
-  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03) + add3;
+  tile_c[stride_c * 0 + 0] += scale * hn::ReduceSum(df, c00);
+  tile_c[stride_c * 0 + 1] += scale * hn::ReduceSum(df, c01);
+  tile_c[stride_c * 0 + 2] += scale * hn::ReduceSum(df, c02);
+  tile_c[stride_c * 0 + 3] += scale * hn::ReduceSum(df, c03);
   if (kNumRows == 1) return;
 
-  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10) + add0;
-  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11) + add1;
-  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12) + add2;
-  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13) + add3;
+  tile_c[stride_c * 1 + 0] += scale * hn::ReduceSum(df, c10);
+  tile_c[stride_c * 1 + 1] += scale * hn::ReduceSum(df, c11);
+  tile_c[stride_c * 1 + 2] += scale * hn::ReduceSum(df, c12);
+  tile_c[stride_c * 1 + 3] += scale * hn::ReduceSum(df, c13);
   if (kNumRows == 2) return;
 
-  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20) + add0;
-  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21) + add1;
-  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22) + add2;
-  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23) + add3;
+  tile_c[stride_c * 2 + 0] += scale * hn::ReduceSum(df, c20);
+  tile_c[stride_c * 2 + 1] += scale * hn::ReduceSum(df, c21);
+  tile_c[stride_c * 2 + 2] += scale * hn::ReduceSum(df, c22);
+  tile_c[stride_c * 2 + 3] += scale * hn::ReduceSum(df, c23);
   if (kNumRows == 3) return;
 
-  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30) + add0;
-  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31) + add1;
-  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32) + add2;
-  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33) + add3;
-}
-
-// Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call
-// sites. If `!kAdd`, `add` is nullptr, so adding `add_offset` to it would be
-// UB, hence we pass it as a separate argument.
-template <size_t kNumRows, bool kAdd, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSumsMaybeAdd(
-    DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
-    VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
-    const float scale, const float* HWY_RESTRICT add, size_t add_offset,
-    float* HWY_RESTRICT tile_c, size_t stride_c) {
-  if constexpr (kAdd) {
-    StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
-                                     c20, c21, c22, c23, c30, c31, c32, c33,
-                                     scale, add + add_offset, tile_c, stride_c);
-  } else {
-    StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
-                                  c20, c21, c22, c23, c30, c31, c32, c33,
-                                  scale, tile_c, stride_c);
-  }
+  tile_c[stride_c * 3 + 0] += scale * hn::ReduceSum(df, c30);
+  tile_c[stride_c * 3 + 1] += scale * hn::ReduceSum(df, c31);
+  tile_c[stride_c * 3 + 2] += scale * hn::ReduceSum(df, c32);
+  tile_c[stride_c * 3 + 3] += scale * hn::ReduceSum(df, c33);
 }
 
 // Wrapper to simplify call sites. T can be const or non-const.
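
The effect of this hunk, together with InitC above, can be sketched in scalar form. The helper below is illustrative only (its name and the `partial` array are stand-ins, not code from this patch): InitC seeds the C tile with the unscaled bias row (or zero), and AddHorizontalSums then accumulates the scaled horizontal sums. That is what lets the vector code use `+=` unconditionally and collapse the Store/StoreAdd/MaybeAdd variants into one function.

#include <cstddef>

// Scalar model of the tile epilogue; partial[r][c] stands in for
// ReduceSum(df, c##), i.e. the dot product held in vector registers.
template <size_t kNumRows, bool kAdd>
void TileEpilogueSketch(const float* add, size_t add_ofs, float scale,
                        const float partial[][4], float* tile_c,
                        size_t stride_c) {
  for (size_t r = 0; r < kNumRows; ++r) {
    for (size_t c = 0; c < 4; ++c) {
      float init = 0.0f;  // InitC: zero ...
      if constexpr (kAdd) init = add[add_ofs + c];  // ... or unscaled bias.
      // AddHorizontalSums: accumulate the scaled horizontal sum.
      tile_c[r * stride_c + c] = init + scale * partial[r][c];
    }
  }
}
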
@@ -176,104 +134,8 @@ Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols) {
   return MakeMat(ptr, cols, cols);
 }
 
-#undef GEMMA_NATIVE_BF16
-#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
-                defined(HWY_TARGET_TOGGLE))
-#define GEMMA_NATIVE_BF16 1
-#else
-#define GEMMA_NATIVE_BF16 0
-#endif
-
-#if GEMMA_NATIVE_BF16
-
-// Specialization for f32 += bf16 * bf16 that avoids promoting to f32.
-template <size_t kNumRows, bool kAdd>
-HWY_INLINE void MatMulTile(const Mat<const hwy::bfloat16_t>& A,
-                           const Mat<const hwy::bfloat16_t>& B,
-                           const size_t row_a, const size_t row_b_col_c,
-                           const float scale, const float* HWY_RESTRICT add,
-                           const Mat<float>& C) {
-  const hn::ScalableTag<float> df;
-  using VF = hn::Vec<decltype(df)>;
-  // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
-  // bf16 vectors.
-  const hn::Repartition<hwy::bfloat16_t, decltype(df)> d;
-  const size_t N = Lanes(d);
-  VF unused_sum1 = hn::Zero(df);
-  VF c00 = hn::Zero(df);
-  VF c01 = hn::Zero(df);
-  VF c02 = hn::Zero(df);
-  VF c03 = hn::Zero(df);
-
-  VF c10 = hn::Zero(df);
-  VF c11 = hn::Zero(df);
-  VF c12 = hn::Zero(df);
-  VF c13 = hn::Zero(df);
-
-  VF c20 = hn::Zero(df);
-  VF c21 = hn::Zero(df);
-  VF c22 = hn::Zero(df);
-  VF c23 = hn::Zero(df);
-
-  VF c30 = hn::Zero(df);
-  VF c31 = hn::Zero(df);
-  VF c32 = hn::Zero(df);
-  VF c33 = hn::Zero(df);
-
-  const hwy::bfloat16_t* HWY_RESTRICT A_tile = A.ptr + A.Row(row_a);
-  const hwy::bfloat16_t* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c);
-
-  // Loop over columns of A and columns of the transposed B, in steps of N.
-  // Accumulates into the c## vectors.
-  HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < A.cols; col_ab += N) {
-    using V = hn::Vec<decltype(d)>;
-    const V b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
-    const V b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
-    const V b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
-    const V b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
-
-    const V a0 = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
-    c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1);
-    c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
-    c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
-    c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
-    if constexpr (kNumRows == 1) continue;
-
-    const V a1 = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
-    c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
-    c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
-    c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
-    c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
-    if constexpr (kNumRows == 2) continue;
-
-    const V a2 = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
-    c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
-    c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
-    c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
-    c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
-    if constexpr (kNumRows == 3) continue;
-
-    const V a3 = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
-    c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
-    c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
-    c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
-    c33 = hn::ReorderWidenMulAccumulate(df, a3, b3, c33, unused_sum1);
-  }
-
-  // Ensure sum1 was indeed unused.
-  HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
-
-  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
-  StoreHorizontalSumsMaybeAdd<kNumRows, kAdd>(
-      df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
-}
-
-#endif  // GEMMA_NATIVE_BF16
-
-// The col_ab loop is unrolled 2x, so we have two consecutive a0/a1 and b00/b01
-// etc. Multiplies a[c] with b[r,c] and adds to c[r].
+// Inner loop of the kernel, called once per kRegRows. c[r] += a[c] * b[r,c].
+// The col_ab loop is unrolled 2x, so we have a0/a1 and b00/b01 etc.
 template <class VF>
 HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
                               const VF& b01, const VF& b10, const VF& b11,
@@ -289,12 +151,153 @@ HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
   c3 = hn::MulAdd(a1, b31, c3);
 }
 
+// Special case for the first iteration: c## are not yet initialized, so set
+// them (Mul) instead of accumulating (MulAdd).
+template <class VF>
+HWY_INLINE void FirstTileRow(const VF& a0, const VF& a1, const VF& b00,
+                             const VF& b01, const VF& b10, const VF& b11,
+                             const VF& b20, const VF& b21, const VF& b30,
+                             const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
+  c0 = hn::Mul(a0, b00);
+  c1 = hn::Mul(a0, b10);
+  c2 = hn::Mul(a0, b20);
+  c3 = hn::Mul(a0, b30);
+  c0 = hn::MulAdd(a1, b01, c0);
+  c1 = hn::MulAdd(a1, b11, c1);
+  c2 = hn::MulAdd(a1, b21, c2);
+  c3 = hn::MulAdd(a1, b31, c3);
+}
+
+#undef GEMMA_NATIVE_BF16
+#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
+                defined(HWY_TARGET_TOGGLE))
+#define GEMMA_NATIVE_BF16 1
+#else
+#define GEMMA_NATIVE_BF16 0
+#endif
+
+#if GEMMA_NATIVE_BF16
+
+// Specializations for f32 += bf16 * bf16 that avoid promoting to f32.
+
+// Inner loop as above, but not unrolled. c[r] += a * b[r].
+template <class DF, class VF = hn::Vec<DF>,
+          class VBF16 = hn::Vec<hn::Repartition<BF16, DF>>>
+HWY_INLINE void UpdateTileRow(DF df, const VBF16& a, const VBF16& b0,
+                              const VBF16& b1, const VBF16& b2, const VBF16& b3,
+                              VF& c0, VF& c1, VF& c2, VF& c3) {
+  VF unused_sum1 = hn::Zero(df);
+  c0 = hn::ReorderWidenMulAccumulate(df, a, b0, c0, unused_sum1);
+  c1 = hn::ReorderWidenMulAccumulate(df, a, b1, c1, unused_sum1);
+  c2 = hn::ReorderWidenMulAccumulate(df, a, b2, c2, unused_sum1);
+  c3 = hn::ReorderWidenMulAccumulate(df, a, b3, c3, unused_sum1);
+
+  // Ensure sum1 was indeed unused.
+  HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
+}
+
+// Special case for the first iteration: c## are not yet initialized, so set
+// them instead of accumulating.
+template <class DF, class VF = hn::Vec<DF>,
+          class VBF16 = hn::Vec<hn::Repartition<BF16, DF>>>
+HWY_INLINE void FirstTileRow(DF df, const VBF16& a, const VBF16& b0,
+                             const VBF16& b1, const VBF16& b2, const VBF16& b3,
+                             VF& c0, VF& c1, VF& c2, VF& c3) {
+  c0 = hn::WidenMulPairwiseAdd(df, a, b0);
+  c1 = hn::WidenMulPairwiseAdd(df, a, b1);
+  c2 = hn::WidenMulPairwiseAdd(df, a, b2);
+  c3 = hn::WidenMulPairwiseAdd(df, a, b3);
+}
+
+template <size_t kNumRows, bool kAdd>
+HWY_INLINE void MatMulTile(const Mat<const BF16>& A, const Mat<const BF16>& B,
+                           const size_t row_ac, const size_t row_b_col_c,
+                           const float scale, const float* HWY_RESTRICT add,
+                           const Mat<float>& C) {
+  const hn::ScalableTag<float> df;
+  using VF = hn::Vec<decltype(df)>;
+  // ReorderWidenMulAccumulate does not use its sum1 arg, so we can use full
+  // bf16 vectors.
+  const hn::Repartition<BF16, decltype(df)> d;
+  const size_t N = Lanes(d);
+  using V = hn::Vec<decltype(d)>;
+  V b0, b1, b2, b3;  // one from each row
+  VF c00, c01, c02, c03;
+  VF c10, c11, c12, c13;
+  VF c20, c21, c22, c23;
+  VF c30, c31, c32, c33;
+
+  const BF16* HWY_RESTRICT A_tile = A.ptr + A.Row(row_ac);
+  const BF16* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c);
+  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
+  InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
+
+  size_t col_ab = 0;
+
+  // First iteration initializes the c## vectors.
+  {
+    b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
+    b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
+    b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
+    b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
+
+    {
+      const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
+      FirstTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03);
+    }
+    if constexpr (kNumRows > 1) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
+      FirstTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13);
+    }
+    if constexpr (kNumRows > 2) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
+      FirstTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23);
+    }
+    if constexpr (kNumRows > 3) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
+      FirstTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33);
+    }
+  }
+
+  // Loop over columns of A and columns of the transposed B, in steps of N.
+  // Accumulates into the c## vectors.
+  HWY_UNROLL(1)
+  for (col_ab += N; col_ab < A.cols; col_ab += N) {
+    b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
+    b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
+    b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
+    b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
+
+    {
+      const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
+      UpdateTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03);
+    }
+    if constexpr (kNumRows > 1) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
+      UpdateTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13);
+    }
+    if constexpr (kNumRows > 2) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
+      UpdateTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23);
+    }
+    if constexpr (kNumRows > 3) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
+      UpdateTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33);
+    }
+  }
+
+  AddHorizontalSums<kNumRows>(df, scale, c00, c01, c02, c03, c10, c11, c12,
+                              c13, c20, c21, c22, c23, c30, c31, c32, c33,
+                              C_tile, C.stride);
+}
+
+#endif  // GEMMA_NATIVE_BF16
+
 // Streams a `(kNumRows, 4)` strip of `A` and the transposed `B`, then writes a
 // finished tile of `C`.
 // General case: uses CompressTraits to load from A and B.
 template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
 HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
-                           const size_t row_a, const size_t row_b_col_c,
+                           const size_t row_ac, const size_t row_b_col_c,
                            const float scale, const float* HWY_RESTRICT add,
                            const Mat<float>& C) {
   using TraitsA = CompressTraits<hwy::RemoveConst<MatTA>>;
@@ -303,74 +306,92 @@ HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
   const hn::ScalableTag<float> d32;
   const size_t N = hn::Lanes(d32);
   using V = hn::Vec<decltype(d32)>;
-  V c00 = hn::Zero(d32);
-  V c01 = hn::Zero(d32);
-  V c02 = hn::Zero(d32);
-  V c03 = hn::Zero(d32);
+  V b00, b01, b10, b11, b20, b21, b30, b31;  // two from each row
+  V c00, c01, c02, c03;
+  V c10, c11, c12, c13;
+  V c20, c21, c22, c23;
+  V c30, c31, c32, c33;
 
-  V c10 = hn::Zero(d32);
-  V c11 = hn::Zero(d32);
-  V c12 = hn::Zero(d32);
-  V c13 = hn::Zero(d32);
-
-  V c20 = hn::Zero(d32);
-  V c21 = hn::Zero(d32);
-  V c22 = hn::Zero(d32);
-  V c23 = hn::Zero(d32);
-
-  V c30 = hn::Zero(d32);
-  V c31 = hn::Zero(d32);
-  V c32 = hn::Zero(d32);
-  V c33 = hn::Zero(d32);
-
-  const size_t A_ofs = A.Row(row_a);
+  const size_t A_ofs = A.Row(row_ac);
   const size_t B_ofs = B.Row(row_b_col_c);
+  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
+  InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
 
   // Loop over columns of A and columns of the transposed B, in steps of 2*N
   // (since we are decoding consecutive bytes at each iteration).
-  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c,
-  // col_ab) for B. Accumulates into the c## vectors.
+  // Top-left of tile is (row_ac, col_ab) for A, and (row_b_col_c,
+  // col_ab) for B. First iteration initializes the c## vectors.
   size_t col_ab = 0;
-  HWY_UNROLL(1)
-  for (; col_ab <= A.cols - 2 * N; col_ab += 2 * N) {
-    V b00, b01;
+  {
     TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
-    V b10, b11;
     TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
-    V b20, b21;
     TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
-    V b30, b31;
     TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
 
-    V a00, a01;
-    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a00, a01);
-    UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
-                  c02, c03);
-    if constexpr (kNumRows == 1) continue;
-
-    V a10, a11;
-    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a10, a11);
-    UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
-                  c12, c13);
-    if constexpr (kNumRows == 2) continue;
-
-    V a20, a21;
-    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a20, a21);
-    UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
-                  c22, c23);
-    if constexpr (kNumRows == 3) continue;
-
-    V a30, a31;
-    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a30, a31);
-    UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
-                  c32, c33);
+    {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1);
+      FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
+                   c02, c03);
+    }
+    if constexpr (kNumRows > 1) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1);
+      FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
+                   c12, c13);
+    }
+    if constexpr (kNumRows > 2) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1);
+      FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
+                   c22, c23);
+    }
+    if constexpr (kNumRows > 3) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1);
+      FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
+                   c32, c33);
+    }
   }
 
-  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
-  StoreHorizontalSumsMaybeAdd<kNumRows, kAdd>(
-      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30,
-      c31, c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
+  // Main loop: accumulates into the c## vectors.
+  HWY_UNROLL(1)
+  for (col_ab += 2 * N; col_ab <= A.cols - 2 * N; col_ab += 2 * N) {
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
+
+    {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1);
+      UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
+                    c02, c03);
+    }
+    if constexpr (kNumRows > 1) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1);
+      UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
+                    c12, c13);
+    }
+    if constexpr (kNumRows > 2) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1);
+      UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
+                    c22, c23);
+    }
+    if constexpr (kNumRows > 3) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1);
+      UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
+                    c32, c33);
+    }
+  }
+
+  AddHorizontalSums<kNumRows>(d32, scale, c00, c01, c02, c03, c10, c11, c12,
+                              c13, c20, c21, c22, c23, c30, c31, c32, c33,
+                              C_tile, C.stride);
 }
 
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
@@ -395,10 +416,7 @@ HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
                              const Mat<MatTB>& B, const float scale,
                              const float* HWY_RESTRICT add, const Mat<float>& C,
                              hwy::ThreadPool& pool) {
-  PROFILER_ZONE("Matmul");
-  constexpr size_t kRegRows = 4;  // if changing, also update the switch below.
-  constexpr size_t kRegCols = 4;
-
+  // PROFILER_ZONE("Matmul");
   HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
   HWY_DASSERT(A.cols == B.cols);
 
@@ -417,24 +435,24 @@ HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
       [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
         const size_t tx = idx_tile % tilesX;
         const size_t ty = idx_tile / tilesX;
-        const size_t row_a = ty * kRegRows;
+        const size_t row_ac = ty * kRegRows;
         const size_t row_b_col_c = tx * kRegCols;
         // How many rows of C are left to compute. If more than 4, this
         // tile still only computes 4 rows.
-        const size_t num_rows = batch_size - row_a;
+        const size_t num_rows = batch_size - row_ac;
         HWY_DASSERT(num_rows != 0);
         switch (num_rows) {
           case 1:
-            MatMulTile<1, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
+            MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
             break;
           case 2:
-            MatMulTile<2, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
+            MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
            break;
           case 3:
-            MatMulTile<3, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
+            MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
            break;
           default:
-            MatMulTile<4, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
+            MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
         }
       });
 }
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index ca025bc..b34321e 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -269,7 +269,6 @@ void TestAllMatMul() {
   }
 
   hwy::ThreadPool pool(4);
-  using BF16 = hwy::bfloat16_t;
   using F32 = float;
   using SFP = SfpStream;
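
Returning to the bf16 path in ops/matmul-inl.h: the split between FirstTileRow (WidenMulPairwiseAdd) and UpdateTileRow (ReorderWidenMulAccumulate) amounts to the dot-product pattern sketched below. This is a standalone illustration, not code from the patch: the function name is hypothetical, `num` is assumed to be a multiple of the bf16 vector length, and the code must be compiled inside a Highway target region (e.g. a static-dispatch translation unit).

#include <stddef.h>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Accumulates bf16 products into f32 lanes, then reduces to a single float.
template <class DF>
float Bf16DotSketch(DF df, const hwy::bfloat16_t* HWY_RESTRICT a,
                    const hwy::bfloat16_t* HWY_RESTRICT b, size_t num) {
  const hn::Repartition<hwy::bfloat16_t, DF> dbf;
  const size_t N = hn::Lanes(dbf);
  // First step: initialize the accumulator directly, no wasted add of zero.
  hn::Vec<DF> sum0 =
      hn::WidenMulPairwiseAdd(df, hn::LoadU(dbf, a), hn::LoadU(dbf, b));
  hn::Vec<DF> sum1 = hn::Zero(df);  // stays zero on native-bf16 targets
  for (size_t i = N; i < num; i += N) {
    sum0 = hn::ReorderWidenMulAccumulate(df, hn::LoadU(dbf, a + i),
                                         hn::LoadU(dbf, b + i), sum0, sum1);
  }
  return hn::ReduceSum(df, hn::Add(sum0, sum1));
}
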
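A call-site sketch for MatMul_4x4, mirroring how matmul_test.cc drives it, illustrates the contract: B is supplied transposed (row-major, A.cols columns per row), `add` has C.cols entries when kAdd, and batch_size need not be a multiple of 4 thanks to the switch above. Buffer names and sizes are hypothetical, the explicit kAdd template argument is an assumption based on how the tile templates are parameterized, and the function is assumed to live inside gcpp::HWY_NAMESPACE like the test does.

#include <stddef.h>

#include "hwy/aligned_allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

void MatMulSketch(hwy::ThreadPool& pool) {
  const size_t batch_size = 4;  // rows of A and C
  const size_t cols_a = 512;    // inner dimension; assumed multiple of 2 * N
  const size_t cols_c = 256;    // columns of C; multiple of kRegCols
  auto a = hwy::AllocateAligned<float>(batch_size * cols_a);
  auto b_trans = hwy::AllocateAligned<float>(cols_c * cols_a);  // B transposed
  auto bias = hwy::AllocateAligned<float>(cols_c);
  auto c = hwy::AllocateAligned<float>(batch_size * cols_c);
  // ... fill a, b_trans and bias ...
  MatMul_4x4</*kAdd=*/true>(batch_size, MakeMat(a.get(), cols_a),
                            MakeMat(b_trans.get(), cols_a), /*scale=*/1.0f,
                            bias.get(), MakeMat(c.get(), cols_c), pool);
}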