diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index b41585e..eab1a4d 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -734,7 +734,8 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved, // Hidden layer -> output layer. auto activations_mat = MakeConstMat( - hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim)); + hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim), + hidden_activations.Stride()); MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out); } @@ -773,8 +774,9 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, multiplier.Row(0), ff_hidden_dim * num_interleaved); // Hidden layer -> output layer. - auto activations_mat = MakeConstMat( - hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim)); + auto activations_mat = MakeConstMat(hidden_activations.Row(0), + Extents2D(num_interleaved, ff_hidden_dim), + hidden_activations.Stride()); MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out); } diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 984e564..b32f3dd 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -133,21 +133,22 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, // Generates inputs and prints observed throughput of MatMul. // M = A rows, K = A cols, N = C cols. -template +template void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { hwy::ThreadPool& pool = env.parallel.Pools().Pool(0); if (env.print_config || env.print_measurement) { fprintf(stderr, "\n"); } - fprintf(stderr, "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", M, K, N, - add, TypeName(), TypeName()); + fprintf(stderr, + "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n", // + M, K, N, add, TypeName(), TypeName(), TypeName()); const Extents2D A_extents(M, K); const Extents2D B_extents(N, K); // already transposed const Extents2D C_extents(M, N); - RowVectorBatch c_slow_batch = AllocateAlignedRows(C_extents); - RowVectorBatch c_batch = AllocateAlignedRows(C_extents); + RowVectorBatch c_slow_batch = AllocateAlignedRows(C_extents); + RowVectorBatch c_batch = AllocateAlignedRows(C_extents); std::unique_ptr> add_storage; if (add) { @@ -156,14 +157,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { add_storage->set_scale(1.0f); } - MatStoragePtr a = GenerateMat(A_extents, pool); - MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); + MatStoragePtr a = GenerateMat(A_extents, pool); + MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); HWY_ASSERT(a && b_trans); const auto A = ConstMatFromWeights(*a); const auto B = ConstMatFromWeights(*b_trans); const float* add_row = add ? add_storage->data_scale1() : nullptr; - const RowPtrF C = RowPtrFromBatch(c_batch); + const RowPtr C = RowPtrFromBatch(c_batch); // Fewer reps for large batch sizes, which take longer. const size_t num_samples = M < 32 ? 20 : 12; @@ -173,7 +174,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // Ensure usage conditions are set before autotuning. Both binding and // spinning may materially affect the choice of config. No harm in calling // BindB/C if there is a single package: they will be a no-op. 
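The `TCFromF32` helpers added to matmul-inl.h above choose between a pass-through for `float` C and a `DemoteTo` for BF16 C. Below is a scalar sketch of the same dispatch; the `BF16` struct and the round-to-nearest-even rounding are assumptions (the real code operates on Highway vectors, whose `DemoteTo` behavior is target-dependent, and handles NaN):

```cpp
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical stand-in for the library's BF16 type (upper 16 bits of an f32).
struct BF16 { std::uint16_t bits; };

// Scalar counterpart of TCFromF32: pass floats through, demote to BF16
// otherwise. Rounds to nearest-even; NaN handling omitted for brevity.
template <typename TC>
TC TCFromF32Scalar(float v) {
  if constexpr (std::is_same_v<TC, float>) {
    return v;
  } else {
    static_assert(std::is_same_v<TC, BF16>);
    std::uint32_t u;
    std::memcpy(&u, &v, sizeof(u));
    const std::uint32_t round = 0x7FFFu + ((u >> 16) & 1u);  // nearest-even
    return BF16{static_cast<std::uint16_t>((u + round) >> 16)};
  }
}
```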
- BindB(B_extents.rows, B, env.parallel); + BindB(B_extents.rows, sizeof(TC), B, env.parallel); BindC(A_extents.rows, C, env.parallel); Tristate use_spinning = Tristate::kDefault; @@ -191,7 +192,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { per_key = MatMul(A, B, add_row, env, C); const double t1 = hwy::platform::Now(); double elapsed = t1 - t0; - keep += C.Row(0)[hwy::Unpredictable1()]; + keep += hwy::ConvertScalarTo(C.Row(0)[hwy::Unpredictable1()]); // Only record times after autotuning finished. if (per_key->autotune.Best()) times.push_back(elapsed); @@ -229,8 +230,8 @@ void BenchAllMatMul() { for (size_t batch_size : {1, 4, 128, 512}) { constexpr bool kAdd = false; - BenchMatMul(batch_size, 24576, 3072, kAdd, env); - BenchMatMul(batch_size, 3072, 24576, kAdd, env); + BenchMatMul(batch_size, 24576, 3072, kAdd, env); + BenchMatMul(batch_size, 3072, 24576, kAdd, env); } PROFILER_PRINT_RESULTS(); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 401dcd1..782c3e7 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -46,7 +46,7 @@ namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; // Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register. -template > +template > static hn::VFromD FastPromoteOddTo(DF df, hn::VFromD vbf) { // Promoting odd means clearing the lower 16 bits. Doing this via AND // requires a second input vector, which we prefer to avoid due to high @@ -70,6 +70,16 @@ static hn::VFromD FastPromoteOddTo(DF df, hn::VFromD vbf) { #endif } +// Converts from float intermediate to MatMul output type `TC`. +template , HWY_IF_F32_D(DC)> +hn::Vec TCFromF32(DC /*dc*/, hn::Vec vf) { + return vf; +} +template , HWY_IF_BF16_D(DC)> +hn::Vec TCFromF32(DC dc, hn::Vec vf) { + return hn::DemoteTo(dc, vf); +} + // Tag classes, passed to `MMKernel::A2C0` to choose between writing one // (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the // first kc result to partial, or accumulating the next kc result into partial @@ -93,14 +103,14 @@ class MMStoreHorizontalSumsIntoC { // Thus we compute the horizontal sums of each `Crc`. The elements may be // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but // this does not change their horizontal sum. 
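The comment above notes that `ReorderWidenMulAccumulate` may permute lanes without changing the horizontal sums that become C elements. A scalar sketch of the store step for one group of four accumulators follows; the real kernel reduces a 4x4 block of accumulators at once by interleaving them into a buffer and adding vectors rather than looping per lane:

```cpp
#include <cstddef>

// Each of C0..C3 holds N partial products for one output column; their
// horizontal sums become four adjacent elements of a C row. Lane order
// within each accumulator does not affect the sum.
void StoreHorizontalSumsSketch(const float* C0, const float* C1,
                               const float* C2, const float* C3, size_t N,
                               float scale, const float* add,  // may be null
                               float* c_row, size_t col_c) {
  float sums[4] = {0.f, 0.f, 0.f, 0.f};
  for (size_t i = 0; i < N; ++i) {
    sums[0] += C0[i];
    sums[1] += C1[i];
    sums[2] += C2[i];
    sums[3] += C3[i];
  }
  for (size_t j = 0; j < 4; ++j) {
    c_row[col_c + j] = sums[j] * scale + (add ? add[col_c + j] : 0.f);
  }
}
```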
- template > + template , typename TC> HWY_INLINE void operator()(DF df, // VF C00, VF C01, VF C02, VF C03, // VF C10, VF C11, VF C12, VF C13, // VF C20, VF C21, VF C22, VF C23, // VF C30, VF C31, VF C32, VF C33, // const size_t row_c, const size_t col_c, - const MMArgs& args) const { + const MMArgs& args, const RowPtr& C) const { float buf[16 * hn::MaxLanes(df)]; const size_t N = hn::Lanes(df); // Horizontal reductions (`ReduceSum`) are rather expensive, entailing @@ -136,10 +146,10 @@ class MMStoreHorizontalSumsIntoC { if constexpr (kAdd) { vadd = hn::Load(d4, args.add + col_c); } - MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, args.C, row_c, col_c); - MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, args.C, row_c, col_c); - MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, args.C, row_c, col_c); - MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, args.C, row_c, col_c); + MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C, row_c, col_c); + MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C, row_c, col_c); + MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C, row_c, col_c); + MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C, row_c, col_c); } private: @@ -155,34 +165,36 @@ class MMStoreHorizontalSumsIntoC { } // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4. - template > - static HWY_INLINE V4 MaybeLoad(D4 d4, size_t N, - const float* HWY_RESTRICT buf) { + template > + static HWY_INLINE VF4 MaybeLoad(DF4 df4, size_t N, + const float* HWY_RESTRICT buf) { if constexpr (kRow < kRowsAC) { - return hn::Load(d4, buf + 4 * kRow * N); + return hn::Load(df4, buf + 4 * kRow * N); } else { - return hn::Zero(d4); + return hn::Zero(df4); } } - template > - static HWY_INLINE V4 MaybeAdd(D4 d4, size_t N, V4 sum, - const float* HWY_RESTRICT buf) { + template > + static HWY_INLINE VF4 MaybeAdd(DF4 df4, size_t N, VF4 sum, + const float* HWY_RESTRICT buf) { if constexpr (kRow < kRowsAC) { - return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N)); + return hn::Add(sum, hn::Load(df4, buf + 4 * kRow * N)); } else { return sum; } } - template > - static HWY_INLINE void MaybeScaleAndStore(D4 d4, V4 sum, V4 vscale, V4 vadd, - const RowPtrF& C, + template > + static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale, + VF4 vadd, const RowPtr& C, const size_t row_c, const size_t col_c) { if constexpr (kRow < kRowsAC) { - float* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c; - hn::Store(hn::MulAdd(sum, vscale, vadd), d4, pos); + TC* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c; + const hn::Rebind dc4; + const VF4 out = hn::MulAdd(sum, vscale, vadd); + hn::Store(TCFromF32(dc4, out), dc4, pos); } } }; // MMStoreHorizontalSumsIntoC @@ -340,14 +352,14 @@ class MMKernel { // or less on ISAs with fewer registers, or for the last few rows of A. static constexpr size_t kMaxMR = 4; - // Calls `LoopOverKC` for each of `mc` rows of A in steps of `mr`. `A_view` + // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. 
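`A2C0`, described above, walks the `mc` rows of A in steps of `mr` and falls back to 2- and 1-row kernels for the remainder, as the hunk below shows. A scalar sketch of just that stepping, with a print stub standing in for the `LoopKC` micro-kernel:

```cpp
#include <cstddef>
#include <cstdio>

// Stand-in for the kRows-row micro-kernel called by A2C0 (hypothetical).
template <size_t kRows>
void LoopKCSketch(size_t row0) {
  std::printf("LoopKC<%zu> at row %zu\n", kRows, row0);
}

// Process `mc` rows of A starting at `row0` in chunks of `mr` (1, 2 or 4),
// then mop up remainders with the smaller kernels.
inline void A2C0Sketch(size_t mr, size_t row0, size_t mc) {
  size_t i = 0;
  if (mr == 1) {
    for (; i < mc; ++i) LoopKCSketch<1>(row0 + i);
    return;
  }
  if (mr == 2) {
    for (; i + 2 <= mc; i += 2) LoopKCSketch<2>(row0 + i);
    if (i != mc) LoopKCSketch<1>(row0 + i);
    return;
  }
  // mr == 4
  for (; i + 4 <= mc; i += 4) LoopKCSketch<4>(row0 + i);
  if ((mc - i) & 2) { LoopKCSketch<2>(row0 + i); i += 2; }
  if ((mc - i) & 1) LoopKCSketch<1>(row0 + i);
}
```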
- template + template static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view, size_t mr, const IndexRange& range_mc, const size_t row_b, size_t kc, Tag tag, - const MMArgs& args) { + const MMArgs& args, const RowPtr& C) { HWY_DASSERT(1 <= mr && mr <= kMaxMR); const size_t row0 = range_mc.begin(); const size_t mc = range_mc.Num(); @@ -356,7 +368,7 @@ class MMKernel { // M == 1, or x86 with 8 SIMD registers: if (HWY_UNLIKELY(mr == 1)) { for (; imc < mc; ++imc) { - LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C); } return; } @@ -365,11 +377,11 @@ class MMKernel { if (HWY_UNLIKELY(mr == 2)) { if (HWY_LIKELY(mc >= 2)) { for (; imc <= mc - 2; imc += 2) { - LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C); } } if (HWY_UNLIKELY(imc != mc)) { - LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C); } return; } @@ -377,17 +389,17 @@ class MMKernel { HWY_DASSERT(mr == 4); if (HWY_LIKELY(mc >= 4)) { for (; imc <= mc - 4; imc += 4) { - LoopOverKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C); } } const size_t remainder_mc = mc - imc; HWY_DASSERT(remainder_mc < 4); if (HWY_UNLIKELY(remainder_mc & 2)) { - LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C); imc += 2; } if (HWY_UNLIKELY(remainder_mc & 1)) { - LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C); imc += 1; } HWY_DASSERT(imc == mc); @@ -484,11 +496,11 @@ class MMKernel { // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be // BF16 so we can load directly without `Decompress2`, which is expensive for // NUQ and requires 2x unrolling, which requires more loads. - template - static HWY_INLINE void LoopOverKC(const RowPtrBF& A_view, - const RowPtrBF& B_view, size_t row_ac, - size_t imc, size_t col_c, size_t kc, - Tag tag, const MMArgs& args) { + template + static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view, + size_t row_ac, size_t imc, size_t col_c, + size_t kc, Tag tag, const MMArgs& args, + const RowPtr& C) { const hn::ScalableTag dbf; using VBF = hn::Vec; const size_t NBF = hn::Lanes(dbf); @@ -602,11 +614,11 @@ class MMKernel { if (args.add) { MMStoreHorizontalSumsIntoC()( df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args); + C31, C32, C33, row_ac, col_c, args, C); } else { MMStoreHorizontalSumsIntoC()( df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args); + C31, C32, C33, row_ac, col_c, args, C); } } else { MMAddHorizontalSumsIntoPartial()( @@ -627,40 +639,44 @@ class MMScaleDemoteAdd { // TODO: fuse with subsequent operations - function pointer? // Although this region in `outputs.C` is not touched again, streaming stores // do not help on SKX and Zen4. TODO: re-check this. 
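`FillC` above converts the f64 `partial` accumulators into the output type, applying the combined A*B scale and the optional per-column bias, now with an explicit f64 to f32 step before the f32 to TC conversion. A scalar sketch of a single row, using float output for simplicity (a BF16 C would demote the f32 via `TCFromF32`):

```cpp
#include <cstddef>

// One output row: f64 partial -> scale -> optional bias -> f32 -> C.
inline void FillRowSketch(const double* partial, size_t n, double scale,
                          const float* add,  // may be nullptr
                          float* c_row) {
  for (size_t i = 0; i < n; ++i) {
    double m = partial[i] * scale;
    if (add != nullptr) m += add[i];
    c_row[i] = static_cast<float>(m);
  }
}
```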
+ template static HWY_INLINE void FillC(const IndexRange& range_mc, - const IndexRange& range_nc, const MMArgs& args) { + const IndexRange& range_nc, const MMArgs& args, + const RowPtr& C) { size_t row_c = range_mc.begin(); if (args.add) { constexpr bool kAdd = true; if (range_mc.Num() >= 4) { for (; row_c <= range_mc.end() - 4; row_c += 4) { - Do4Rows(row_c, range_nc, args); + Do4Rows(row_c, range_nc, args, C); } } for (; row_c < range_mc.end(); ++row_c) { - Do1Row(row_c, range_nc, args); + Do1Row(row_c, range_nc, args, C); } } else { constexpr bool kAdd = false; if (range_mc.Num() >= 4) { for (; row_c <= range_mc.end() - 4; row_c += 4) { - Do4Rows(row_c, range_nc, args); + Do4Rows(row_c, range_nc, args, C); } } for (; row_c < range_mc.end(); ++row_c) { - Do1Row(row_c, range_nc, args); + Do1Row(row_c, range_nc, args, C); } } } private: // Unrolled for 4 rows to reduce the number of loads from `add`. - template + template static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc, - const MMArgs& args) { + const MMArgs& args, const RowPtr& C) { const hn::ScalableTag dd; const hn::Rebind df; // result of DemoteTo + const hn::Rebind dc; using VD = hn::Vec; + using VF = hn::Vec; const size_t ND = hn::Lanes(dd); const VD vscale = hn::Set(dd, args.scale); @@ -669,10 +685,10 @@ class MMScaleDemoteAdd { const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2); const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3); - float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0); - float* HWY_RESTRICT cr1 = args.C.Row(row_c + 1); - float* HWY_RESTRICT cr2 = args.C.Row(row_c + 2); - float* HWY_RESTRICT cr3 = args.C.Row(row_c + 3); + TC* HWY_RESTRICT cr0 = C.Row(row_c + 0); + TC* HWY_RESTRICT cr1 = C.Row(row_c + 1); + TC* HWY_RESTRICT cr2 = C.Row(row_c + 2); + TC* HWY_RESTRICT cr3 = C.Row(row_c + 3); // We manually unroll 2x for higher IPC in batch=1. size_t col_c = range_nc.begin(); @@ -713,15 +729,24 @@ class MMScaleDemoteAdd { m30 = hn::Mul(d30, vscale); m31 = hn::Mul(d31, vscale); } + // First convert f64 to f32. + const VF f00 = hn::DemoteTo(df, m00); + const VF f01 = hn::DemoteTo(df, m01); + const VF f10 = hn::DemoteTo(df, m10); + const VF f11 = hn::DemoteTo(df, m11); + const VF f20 = hn::DemoteTo(df, m20); + const VF f21 = hn::DemoteTo(df, m21); + const VF f30 = hn::DemoteTo(df, m30); + const VF f31 = hn::DemoteTo(df, m31); // Note that Stream is neutral on SKX and harmful on Zen4. 
- hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c); - hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND); - hn::Store(hn::DemoteTo(df, m10), df, cr1 + col_c); - hn::Store(hn::DemoteTo(df, m11), df, cr1 + col_c + ND); - hn::Store(hn::DemoteTo(df, m20), df, cr2 + col_c); - hn::Store(hn::DemoteTo(df, m21), df, cr2 + col_c + ND); - hn::Store(hn::DemoteTo(df, m30), df, cr3 + col_c); - hn::Store(hn::DemoteTo(df, m31), df, cr3 + col_c + ND); + hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c); + hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND); + hn::Store(TCFromF32(dc, f10), dc, cr1 + col_c); + hn::Store(TCFromF32(dc, f11), dc, cr1 + col_c + ND); + hn::Store(TCFromF32(dc, f20), dc, cr2 + col_c); + hn::Store(TCFromF32(dc, f21), dc, cr2 + col_c + ND); + hn::Store(TCFromF32(dc, f30), dc, cr3 + col_c); + hn::Store(TCFromF32(dc, f31), dc, cr3 + col_c + ND); } } @@ -750,25 +775,31 @@ class MMScaleDemoteAdd { m20 = hn::Mul(d20, vscale); m30 = hn::Mul(d30, vscale); } - hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining); - hn::StoreN(hn::DemoteTo(df, m10), df, cr1 + col_c, remaining); - hn::StoreN(hn::DemoteTo(df, m20), df, cr2 + col_c, remaining); - hn::StoreN(hn::DemoteTo(df, m30), df, cr3 + col_c, remaining); + // First convert f64 to f32. + const VF f00 = hn::DemoteTo(df, m00); + const VF f10 = hn::DemoteTo(df, m10); + const VF f20 = hn::DemoteTo(df, m20); + const VF f30 = hn::DemoteTo(df, m30); + hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining); + hn::StoreN(TCFromF32(dc, f10), dc, cr1 + col_c, remaining); + hn::StoreN(TCFromF32(dc, f20), dc, cr2 + col_c, remaining); + hn::StoreN(TCFromF32(dc, f30), dc, cr3 + col_c, remaining); } } // Same as above but handles a single row (for remainder rows). - template + template static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc, - const MMArgs& args) { + const MMArgs& args, const RowPtr& C) { const hn::ScalableTag dd; const hn::Rebind df; // result of DemoteTo + const hn::Rebind dc; using VD = hn::Vec; + using VF = hn::Vec; const size_t ND = hn::Lanes(dd); const VD vscale = hn::Set(dd, args.scale); - const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); - float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0); + TC* HWY_RESTRICT cr0 = C.Row(row_c + 0); // We manually unroll 2x for higher IPC in batch=1. size_t col_c = range_nc.begin(); @@ -791,9 +822,12 @@ class MMScaleDemoteAdd { m00 = hn::Mul(d00, vscale); m01 = hn::Mul(d01, vscale); } + // First convert f64 to f32. + const VF f00 = hn::DemoteTo(df, m00); + const VF f01 = hn::DemoteTo(df, m01); // Note that Stream is neutral on SKX and harmful on Zen4. - hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c); - hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND); + hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c); + hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND); } } @@ -813,7 +847,9 @@ class MMScaleDemoteAdd { } else { m00 = hn::Mul(d00, vscale); } - hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining); + // First convert f64 to f32. + const VF f00 = hn::DemoteTo(df, m00); + hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining); } } }; // MMScaleDemoteAdd @@ -849,20 +885,21 @@ class MMPerPackage { // B is decompressed several call layers lower, but not all member functions // depend on TB, so pass it as an argument instead of templating the class. - template - HWY_NOINLINE void operator()(const ConstMat& B) const { + template + HWY_NOINLINE void operator()(const ConstMat& B, + const RowPtr& C) const { // TODO: include NUQ tables? 
NumPacked in ConstMat? const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows; switch (order_) { case MMOrder::kNT: - return DoNT(B, num_packed_B); + return DoNT(B, num_packed_B, C); case MMOrder::kNT_K: - return DoNT_K(B, num_packed_B); + return DoNT_K(B, num_packed_B, C); case MMOrder::kNT_MT: - return DoNT_MT(B, num_packed_B); + return DoNT_MT(B, num_packed_B, C); case MMOrder::kNT_MT_K: - return DoNT_MT_K(B, num_packed_B); + return DoNT_MT_K(B, num_packed_B, C); default: HWY_UNREACHABLE; } @@ -878,13 +915,14 @@ class MMPerPackage { // Granularity of `ForNP`. B rows produce C columns, so we // want a multiple of the line size to prevent false sharing. - static size_t MultipleNP() { - return HWY_MAX(kNR, Allocator::LineBytes() / sizeof(float)); + static size_t MultipleNP(size_t sizeof_TC) { + return HWY_MAX(kNR, Allocator::LineBytes() / sizeof_TC); } // Single M and K, parallel N. Fills all of C directly. - template - HWY_INLINE void DoNT(const ConstMat& B, size_t num_packed_B) const { + template + HWY_INLINE void DoNT(const ConstMat& B, size_t num_packed_B, + const RowPtr& C) const { MMZone zone; zone.MaybeEnter("MM.NT", args_); HWY_DASSERT(ranges_mc_.NumTasks() == 1); @@ -897,7 +935,7 @@ class MMPerPackage { // Similar to `loop_nc` below, but here we hoisted `A_view`. args_.env->parallel.ForNP( - range_np_, MultipleNP(), inner_tasks_, pkg_idx_, + range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_, [&](const IndexRange& range_nc) HWY_ATTR { HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS const RowPtrBF B_view(B_storage, K, B_stride); @@ -910,7 +948,7 @@ class MMPerPackage { DecompressB(B, num_packed_B, row_b, range_K, B_view); } MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(), - args_); + args_, C); } }); @@ -918,8 +956,9 @@ class MMPerPackage { } // Single M, parallel N, sequential K. Fills all of partial. - template - HWY_INLINE void DoNT_K(const ConstMat& B, size_t num_packed_B) const { + template + HWY_INLINE void DoNT_K(const ConstMat& B, size_t num_packed_B, + const RowPtr& C) const { MMZone zone; zone.MaybeEnter("MM.NT_K", args_); HWY_DASSERT(ranges_mc_.NumTasks() == 1); @@ -942,13 +981,13 @@ class MMPerPackage { zone.MaybeEnter("MM.NT_K.DecB", args_); DecompressB(B, num_packed_B, row_b, range_kc, B_view); } - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, - args_); + MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, + C); } }; args_.env->parallel.ForNP( - range_np_, MultipleNP(), inner_tasks_, pkg_idx_, + range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_, [&](const IndexRange& range_nc) HWY_ATTR { HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS @@ -965,13 +1004,13 @@ class MMPerPackage { MMZone fill_zone; if (out_ == MMOut::kCopy) { fill_zone.MaybeEnter("MM.NT_K.FillC", args_); - MMScaleDemoteAdd::FillC(range_mc, range_np_, args_); + MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C); } else if (out_ == MMOut::kParM) { fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_); args_.env->parallel.ForRangeMC( range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR { MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_, - args_); + args_, C); }); } else { HWY_UNREACHABLE; // kDirect is only used with kNT. @@ -980,8 +1019,9 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, single K. // Fills `mc x nc` sections of C directly, in parallel. 
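`MultipleNP` above now derives the per-worker column granularity from `sizeof(TC)`, so that one cache line of C is never split across workers regardless of the element size. A small standalone illustration; the line size and `kNR = 4` are assumed values:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr size_t kNR = 4;          // assumed kernel column step
constexpr size_t kLineBytes = 64;  // assumed cache line size

constexpr size_t MultipleNP(size_t sizeof_tc) {
  return std::max(kNR, kLineBytes / sizeof_tc);
}

int main() {
  std::printf("float C: %zu cols per line\n", MultipleNP(4));  // 16
  std::printf("bf16  C: %zu cols per line\n", MultipleNP(2));  // 32
}
```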
- template - HWY_INLINE void DoNT_MT(const ConstMat& B, size_t num_packed_B) const { + template + HWY_INLINE void DoNT_MT(const ConstMat& B, size_t num_packed_B, + const RowPtr& C) const { MMZone zone; zone.MaybeEnter("MM.NT_MT", args_); HWY_DASSERT(ranges_kc_.NumTasks() == 1); @@ -1006,7 +1046,7 @@ class MMPerPackage { DecompressB(B, num_packed_B, row_b, range_K, B_view); } MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(), - args_); + args_, C); } }); @@ -1015,8 +1055,9 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, sequential K. // Fills `mc x nc` sections of `partial`, then `C`, in parallel. - template - HWY_INLINE void DoNT_MT_K(const ConstMat& B, size_t num_packed_B) const { + template + HWY_INLINE void DoNT_MT_K(const ConstMat& B, size_t num_packed_B, + const RowPtr& C) const { MMZone zone; zone.MaybeEnter("MM.NT_MT_K", args_); const size_t kc_max = ranges_kc_.TaskSize(); @@ -1039,8 +1080,8 @@ class MMPerPackage { zone.MaybeEnter("MM.NT_MT_K.DecB", args_); DecompressB(B, num_packed_B, row_b, range_kc, B_view); } - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, - args_); + MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, + C); } }; // loop_nc args_.env->parallel.ForRangesMC_NC( @@ -1063,7 +1104,7 @@ class MMPerPackage { HWY_DASSERT(out_ == MMOut::kCopy); MMZone fill_zone; fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_); - MMScaleDemoteAdd::FillC(range_mc, range_nc, args_); + MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C); }); } @@ -1083,6 +1124,8 @@ class MMPerPackage { const IndexRange& range_K) HWY_ATTR { const size_t col0 = range_K.begin(); const size_t cols = range_K.Num(); + // otherwise, padding overwrites neighbors + HWY_DASSERT(cols % NBF == 0 || cols == A.extents.cols); for (size_t row_a : range_M) { const PackedSpan from = MakeSpan(A.ptr + A.Row(row_a) + col0, cols); @@ -1224,9 +1267,10 @@ struct MMImpl { // Called from `MatMul` from two places: either with the next autotune config, // or with the best config. - template + template static HWY_NOINLINE void DoMatMul(const ConstMat& A, - const ConstMat& B, const MMArgs& args, + const ConstMat& B, const RowPtr& C, + const MMArgs& args, const MMConfig& config) { MMZone matmul_zone; matmul_zone.MaybeEnter("MM.DoMatMul", args); @@ -1235,7 +1279,7 @@ struct MMImpl { args.env->parallel.ForPkg( args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) { const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A, args, config, pkg_idx, range_np)(B); + MMPerPackage(A, args, config, pkg_idx, range_np)(B, C); }); } }; @@ -1260,10 +1304,10 @@ struct MMImpl { // invalidated by the next call to `MatMul`. // // Uses considerable stack space: at least 40 KiB per thread. 
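With C now passed to `MatMul` as a `RowPtr<TC>`, the output precision follows from the storage the caller allocates. A rough usage sketch in the spirit of `TestMatMul` below; the explicit template arguments and the exact helper signatures (`GenerateMat`, `AllocateAlignedRows`, `RowPtrFromBatch`) are assumptions, not taken verbatim from this patch:

```cpp
// Illustrative only; assumes the helpers used elsewhere in this patch.
void ExampleBF16Output(MatMulEnv& env, hwy::ThreadPool& pool) {
  const Extents2D A_extents(4, 512), B_extents(128, 512);  // B transposed
  const Extents2D C_extents(4, 128);
  auto a = GenerateMat<BF16>(A_extents, pool);
  auto b_trans = GenerateTransposedMat<BF16>(B_extents, pool);
  RowVectorBatch<BF16> c_batch = AllocateAlignedRows<BF16>(C_extents);
  const auto A = ConstMatFromWeights(*a);
  const auto B = ConstMatFromWeights(*b_trans);
  const RowPtr<BF16> C = RowPtrFromBatch(c_batch);
  MatMul(A, B, /*add=*/nullptr, env, C);  // TC = BF16, taken from C
}
```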
-template +template HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, const float* HWY_RESTRICT add, MatMulEnv& env, - const RowPtrF& C) { + const RowPtr& C) { const size_t M = A.Extents().rows; const size_t K = A.Extents().cols; const size_t N = B.Extents().rows; @@ -1281,15 +1325,16 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, // invalidates `MMAutoTune::Best()` index = env.per_key.size(); - env.per_key.push_back(MMPerKey(max_packages, N, kNR, env.parallel)); + env.per_key.push_back( + MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel)); } MMPerKey& per_key = env.per_key[index]; MMAutoTune& tuner = per_key.autotune; const MMArgs args(env, per_key, static_cast(A.scale) * B.scale, add, - env.storage.Partial(), C); + env.storage.Partial()); if (HWY_LIKELY(tuner.Best())) { - MMImpl::DoMatMul(A, B, args, *tuner.Best()); + MMImpl::DoMatMul(A, B, C, args, *tuner.Best()); return &per_key; } @@ -1306,13 +1351,13 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, HWY_ASSERT(N % kNR == 0); // Negligible CPU time. - tuner.SetCandidates(MMCandidates(M, K, N, MMKernel::kMaxMR, kNR, + tuner.SetCandidates(MMCandidates(M, K, N, sizeof(TC), MMKernel::kMaxMR, kNR, per_key.ranges_np, env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMImpl::DoMatMul(A, B, args, cfg); + MMImpl::DoMatMul(A, B, C, args, cfg); const uint64_t t1 = env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / diff --git a/ops/matmul.cc b/ops/matmul.cc index 80f1d8d..678da56 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -60,10 +60,13 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, // and holds most of their arguments in member variables. class GenerateCandidates { public: - GenerateCandidates(size_t M, size_t K, size_t N, size_t max_mr, size_t nr, + GenerateCandidates(size_t M, size_t K, size_t N, size_t sizeof_TC, + size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config) : M_(M), K_(K), + N_(N), + sizeof_TC_(sizeof_TC), max_mr_(max_mr), nr_(nr), // These influence kc/nc, but are also stored in `MMConfig` for @@ -71,7 +74,7 @@ class GenerateCandidates { // is likely still in L1, but we expect K > 1000 and might as well round // up to the line size. kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))), - nc_multiple_(Allocator::StepBytes() / sizeof(float)), + nc_multiple_(Allocator::StepBytes() / sizeof_TC), ranges_np_(ranges_np), print_config_(print_config) {} @@ -88,7 +91,7 @@ class GenerateCandidates { for (size_t nc : NC(mr, mc, kc, order)) { for (int inner_tasks : all_inner_tasks) { for (MMOut out : all_outs) { - const MMConfig config(K_, mr, mc, kc, nc, kc_multiple_, + const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_, nc_multiple_, order, out, inner_tasks); const size_t M_tasks = config.RangesOfMC(M_).NumTasks(); const size_t K_tasks = config.RangesOfKC(K_).NumTasks(); @@ -114,7 +117,7 @@ class GenerateCandidates { private: using SizeVec = std::vector; - // How many rows of A per call to `MMKernel::LoopOverKC`. Lower values may + // How many rows of A per call to `MMKernel::LoopKC`. Lower values may // be better for SIMD targets with fewer registers. SizeVec MR() const { const int64_t target = hwy::DispatchedTarget(); @@ -153,7 +156,7 @@ class GenerateCandidates { // The number of A and B columns to read between updating `partial`. 
SizeVec KC(size_t mr, MMOrder order) const { - // `LoopOverKC` handles up to `mr` rows of A. + // `LoopKC` handles up to `mr` rows of A. const size_t rows_a = HWY_MIN(M_, mr); // After looping over `kc` columns, we write `mr x 4` outputs and 16 vector @@ -161,9 +164,9 @@ class GenerateCandidates { // is important that B fits in L1, because batch=1 only has a single row of // A and thus no reuse of the packed B. When L1-resident, we can use the // separate `DecompressAndZeroPad` to write `kc` columns, rather than having - // to integrate `Decompress2` into `LoopOverKC`, which is less efficient for + // to integrate `Decompress2` into `LoopKC`, which is less efficient for // TB=NUQ due to less amortization of the table loads. Due to the low L1 - // latency, the packing is still effectively fused into `LoopOverKC`. It may + // latency, the packing is still effectively fused into `LoopKC`. It may // be better to round up and accept a few L2 accesses in exchange for // fewer loops over K, and thus fewer writes to `partial`. Hence we do not // subtract the output and buf, and allow using more than the actual L1 @@ -255,7 +258,7 @@ class GenerateCandidates { SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const { const size_t np_max = ranges_np_.TaskSize(); size_t nc_max = np_max; - const size_t out_bytes = IsOneKC(order) ? sizeof(float) : sizeof(double); + const size_t out_bytes = IsOneKC(order) ? sizeof_TC_ : sizeof(double); // Only if there will be reuse of B: choose the largest `nc_max` (C cols) // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3. // Otherwise, leave it unbounded. @@ -350,6 +353,8 @@ class GenerateCandidates { const size_t M_; const size_t K_; + const size_t N_; + const size_t sizeof_TC_; const size_t max_mr_; const size_t nr_; @@ -365,23 +370,25 @@ class GenerateCandidates { } // namespace // Facade to avoid exposing `GenerateCandidates` in the header. -std::vector MMCandidates(size_t M, size_t K, size_t N, size_t max_mr, - size_t nr, +std::vector MMCandidates(size_t M, size_t K, size_t N, + size_t sizeof_TC, size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config) { - return GenerateCandidates(M, K, N, max_mr, nr, ranges_np, print_config)(); + return GenerateCandidates(M, K, N, sizeof_TC, max_mr, nr, ranges_np, + print_config)(); } // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote // memory accesses or false sharing, unless there are insufficient per-package // rows for that. -static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) { - size_t np_multiple = Allocator::QuantumBytes() / sizeof(float); +static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr, + size_t num_packages) { + size_t np_multiple = Allocator::QuantumBytes() / sizeof_TC; // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For // `N` < 4096, this can cause significant load imbalance. If split unevenly, // choose a smaller multiple. 
if (N % (np_multiple * num_packages)) { - const size_t min_multiple = Allocator::LineBytes() / sizeof(float); + const size_t min_multiple = Allocator::LineBytes() / sizeof_TC; np_multiple = PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple); if (HWY_UNLIKELY(np_multiple == 0)) { @@ -398,10 +405,10 @@ static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) { } IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N, - size_t nr) const { + size_t sizeof_TC, size_t nr) const { const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages()); return StaticPartition(IndexRange(0, N), num_packages, - NPMultiple(N, nr, num_packages)); + NPMultiple(N, sizeof_TC, nr, num_packages)); } MatMulEnv::MatMulEnv(NestedPools& pools) : parallel(pools), storage(parallel) { diff --git a/ops/matmul.h b/ops/matmul.h index 707e37b..72ccd4a 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -60,7 +60,7 @@ class MMParallel { // Initial static partitioning of B rows across packages. IndexRangePartition RangesOfNP(size_t max_packages, size_t N, - size_t nr) const; + size_t sizeof_TC, size_t nr) const; // For `BindB` and `BindC`. size_t Node(size_t pkg_idx) const { @@ -170,13 +170,13 @@ class MMParallel { NestedPools& pools_; }; -template // float for C, double for partial -void BindC(size_t M, const RowPtr& C, MMParallel& parallel) { +template // BF16/float for C, double for partial +void BindC(size_t M, const RowPtr& C, MMParallel& parallel) { if (!Allocator::ShouldBind()) return; const IndexRangePartition ranges_np = - parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), kNR); - const size_t quantum = Allocator::QuantumBytes() / sizeof(T); + parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR); + const size_t quantum = Allocator::QuantumBytes() / sizeof(TC); bool ok = true; for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { const IndexRange& cols_c = ranges_np.Range(pkg_idx); @@ -185,7 +185,7 @@ void BindC(size_t M, const RowPtr& C, MMParallel& parallel) { // BindRowsToPackageNodes may not be page-aligned. const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum); const size_t end = hwy::RoundDownTo(cols_c.end(), quantum); - ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(T), + ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC), node); } } @@ -355,11 +355,11 @@ static inline const char* StringFromParA(MMParA par_a) { class MMConfig { public: MMConfig() = default; // for std::vector - // `mr` is the number of A rows per call to `MMKernel::LoopOverKC`. + // `mr` is the number of A rows per call to `MMKernel::LoopKC`. // `MMOrder` is how to parallelize the outer loops. // `MMOut` is how/whether to parallelize filling the C result. // `inner_tasks` chooses the within-cluster task granularity in `ForNP`. 
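`BindC` above now uses `sizeof(TC)` when converting the binding quantum from bytes to elements. Since memory can only be bound at page ("quantum") granularity, each package's column range is shrunk to whole quanta before `BindMemory`. A standalone illustration with assumed sizes:

```cpp
#include <cstddef>
#include <cstdio>

constexpr size_t RoundUp(size_t x, size_t m) { return (x + m - 1) / m * m; }
constexpr size_t RoundDown(size_t x, size_t m) { return x / m * m; }

int main() {
  const size_t quantum_bytes = 4096;         // assumed page size
  const size_t quantum = quantum_bytes / 2;  // elements, sizeof(BF16) == 2
  const size_t col_begin = 3000, col_end = 9000;  // one package's C columns
  std::printf("bind cols [%zu, %zu)\n", RoundUp(col_begin, quantum),
              RoundDown(col_end, quantum));  // [4096, 8192)
}
```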
- MMConfig(size_t K, size_t mr, size_t mc, size_t kc, size_t nc, + MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc, size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out, int inner_tasks) : mr_(static_cast(mr)), @@ -381,7 +381,7 @@ class MMConfig { if (kc != K && (kc % kc_multiple) != 0) { HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple); } - if (nc % nc_multiple != 0) { + if (nc != N && (nc % nc_multiple) != 0) { HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple); } HWY_DASSERT(StringFromOrder(order_) != nullptr); @@ -428,8 +428,8 @@ class MMConfig { static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) -std::vector MMCandidates(size_t M, size_t K, size_t N, size_t max_mr, - size_t nr, +std::vector MMCandidates(size_t M, size_t K, size_t N, + size_t sizeof_TC, size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config); @@ -588,8 +588,9 @@ class MMKeys { // Per-MatMul-shape state. struct MMPerKey { - MMPerKey(size_t max_packages, size_t N, size_t nr, MMParallel& parallel) - : ranges_np(parallel.RangesOfNP(max_packages, N, nr)) {} + MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr, + MMParallel& parallel) + : ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {} // Only profile if enabled and the main autotuner finished (the par_a // autotuner is per-package and we want to avoid synchronization). @@ -623,18 +624,16 @@ struct MatMulEnv { std::vector per_key; }; -// Arguments to MatMul() that are independent of the A/B type. +// Arguments to MatMul() that are independent of the A/B/C types. // Reduces register pressure compared to individual values/references. struct MMArgs { MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale, - const float* HWY_RESTRICT add, const RowPtrD& partial, - const RowPtrF& C) + const float* HWY_RESTRICT add, const RowPtrD& partial) : env(&env), per_key(&per_key), scale(scale), add(add), - partial(partial), - C(C) {} + partial(partial) {} MatMulEnv* env; MMPerKey* per_key; @@ -643,7 +642,6 @@ struct MMArgs { const float* HWY_RESTRICT add; // Same size as C, threads write at false-sharing-free granularity. RowPtrD partial; - RowPtrF C; }; // Wrapper over hwy::Zone that is only enabled when autotuning finished. @@ -683,22 +681,22 @@ struct MMZone { // `ofs` required for compressed T. template struct ConstMat { - ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0) - : ptr(ptr), extents(extents), ofs(ofs) { + ConstMat(const T* ptr, Extents2D extents, size_t stride, size_t ofs = 0) + : ptr(ptr), extents(extents), stride(stride), ofs(ofs) { HWY_DASSERT(ptr != nullptr); + HWY_DASSERT(stride >= extents.cols); } - // TODO: support stride for page alignment. size_t Row(size_t r) const { if constexpr (HWY_IS_DEBUG_BUILD) { if (r >= extents.rows) { HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows); } } - return ofs + extents.cols * r; + return ofs + r * stride; } const Extents2D& Extents() const { return extents; } - size_t Stride() const { return extents.cols; } + size_t Stride() const { return stride; } // Shrinks the row-extent of this matrix view, i.e. reduces the view to a // subrange of the original rows starting at row 0. @@ -709,6 +707,7 @@ struct ConstMat { const T* HWY_RESTRICT ptr; Extents2D extents; + size_t stride; // `scale` allows expanding the smaller range of `SfpStream` to the original // values. MatFromWeights sets this from `MatPtr`. 
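`ConstMat` above gains an explicit `stride` (at least `extents.cols`), so views over padded, row-aligned storage index rows correctly instead of assuming densely packed rows. A minimal standalone equivalent of the `Row` computation:

```cpp
#include <cassert>
#include <cstddef>

// Element (r, c) lives at ptr[ofs + r * stride + c]; stride >= cols allows
// each row to carry alignment padding at its end.
template <typename T>
struct ConstViewSketch {
  const T* ptr;
  size_t rows, cols, stride, ofs = 0;

  size_t Row(size_t r) const {
    assert(r < rows && stride >= cols);
    return ofs + r * stride;
  }
  const T& At(size_t r, size_t c) const { return ptr[Row(r) + c]; }
};
```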
@@ -721,9 +720,9 @@ struct ConstMat { // For deducing T. template -ConstMat MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, +ConstMat MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, size_t stride, size_t ofs = 0) { - return ConstMat(ptr, extents, ofs); + return ConstMat(ptr, extents, stride, ofs); } // For A argument to MatMul (activations). @@ -732,22 +731,25 @@ ConstMat ConstMatFromBatch(size_t batch_size, const RowVectorBatch& row_vectors) { HWY_DASSERT(batch_size <= row_vectors.BatchSize()); return MakeConstMat(const_cast(row_vectors.Const()), - Extents2D(batch_size, row_vectors.Cols())); + Extents2D(batch_size, row_vectors.Cols()), + row_vectors.Stride()); } template ConstMat ConstMatFromWeights(const MatPtrT& m, size_t ofs = 0) { - ConstMat mat = MakeConstMat(const_cast(m.data()), m.Extents(), ofs); + ConstMat mat = + MakeConstMat(const_cast(m.data()), m.Extents(), m.Stride(), ofs); mat.scale = m.scale(); return mat; } template -void BindB(size_t N, const ConstMat& B, MMParallel& parallel) { +void BindB(size_t N, size_t sizeof_TC, const ConstMat& B, + MMParallel& parallel) { if (!Allocator::ShouldBind()) return; const IndexRangePartition ranges_np = - parallel.RangesOfNP(MMParallel::kMaxPackages, N, kNR); + parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR); const size_t quantum = Allocator::QuantumBytes() / sizeof(TB); for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { const IndexRange& rows_b = ranges_np.Range(pkg_idx); diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index ad57508..3d34715 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -67,7 +67,7 @@ using MatStoragePtr = std::unique_ptr>; // Generates inputs: deterministic, within max SfpStream range. template -MatStoragePtr GenerateMat(const Extents2D extents, +MatStoragePtr GenerateMat(const Extents2D& extents, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; auto mat = @@ -112,12 +112,12 @@ MatStoragePtr GenerateTransposedMat(const Extents2D extents, } // Returns 1-norm, used for estimating tolerable numerical differences. -double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) { +double MaxRowAbsSum(const RowVectorBatch& a) { double max_row_abs_sum = 0.0; - for (size_t r = 0; r < extents.rows; r++) { - const float* row = a + r * extents.cols; + for (size_t r = 0; r < a.BatchSize(); r++) { + const float* row = a.Batch(r); double row_abs_sum = 0.0; - for (size_t c = 0; c < extents.cols; c++) { + for (size_t c = 0; c < a.Cols(); c++) { row_abs_sum += hwy::ScalarAbs(row[c]); } max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum); @@ -126,41 +126,52 @@ double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) { } // Returns the maximum absolute value of `a`. -float MaxAbs(const float* HWY_RESTRICT a, const Extents2D& extents) { +float MaxAbs(const RowVectorBatch& a) { float max_abs = 0.0f; - for (size_t c = 0; c < extents.cols; c++) { - for (size_t r = 0; r < extents.rows; r++) { - max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(a[r * extents.cols + c])); + for (size_t c = 0; c < a.Cols(); c++) { + for (size_t r = 0; r < a.BatchSize(); r++) { + const float* row = a.Batch(r); + max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c])); } } return max_abs; } // B is already transposed. 
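`MaxRowAbsSum` above (the "1-norm" helper) now reads strided `RowVectorBatch` storage, but the quantity is unchanged: the maximum over rows of the sum of absolute values. Together with `MaxAbs` it scales the tolerance, since each C element accumulates one row of A against one row of B-transposed. A standalone scalar version:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

double MaxRowAbsSumSketch(const std::vector<std::vector<float>>& m) {
  double max_sum = 0.0;
  for (const auto& row : m) {
    double sum = 0.0;
    for (float v : row) sum += std::abs(v);
    max_sum = std::max(max_sum, sum);
  }
  return max_sum;
}
```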
-template +template void AssertClose(const ConstMat& A, const ConstMat& B, - const RowPtrF& C_slow, const RowPtrF& C) { + const RowPtr& C_slow, const RowPtr& C, int line) { const hn::ScalableTag df; - const size_t num_a = A.extents.Area(); - const size_t num_b = B.extents.Area(); - const size_t N = hn::Lanes(df); + const size_t cols = A.extents.cols; + const size_t B_rows = B.extents.rows; // Round up for DecompressAndZeroPad. - FloatPtr a = hwy::AllocateAligned(hwy::RoundUpTo(num_a, N)); - FloatPtr b_trans = hwy::AllocateAligned(hwy::RoundUpTo(num_b, N)); - HWY_ASSERT(a && b_trans); + RowVectorBatch a_batch = AllocateAlignedRows(A.extents); + RowVectorBatch b_trans_batch = AllocateAlignedRows(B.extents); + RowVectorBatch c_batch = + AllocateAlignedRows(Extents2D(A.extents.rows, B_rows)); + RowVectorBatch c_slow_batch = + AllocateAlignedRows(Extents2D(A.extents.rows, B_rows)); HWY_ASSERT(A.ofs == 0 && B.ofs == 0); - DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a); - DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b); + for (size_t m = 0; m < A.extents.rows; ++m) { + DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0, + a_batch.Batch(m), cols); + DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Batch(m), + B_rows); + DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0, + c_slow_batch.Batch(m), B_rows); + } + for (size_t n = 0; n < B_rows; ++n) { + DecompressAndZeroPad(df, MakeSpan(B.ptr + B.Row(n), cols), 0, + b_trans_batch.Batch(n), cols); + } // MatMul rounds inputs to BF16, so error is proportional to the max input // magnitude, but also to f32 accumulation of rows in A and B. - const double norm = MaxRowAbsSum(a.get(), A.Extents()) * - MaxRowAbsSum(b_trans.get(), B.Extents()); - const float max_abs = - MaxAbs(a.get(), A.Extents()) * MaxAbs(b_trans.get(), B.Extents()); + const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch); + const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch); const double eps_bf16 = hwy::ConvertScalarTo(hwy::Epsilon()); const double eps_f32 = hwy::ConvertScalarTo(hwy::Epsilon()); - double tolerance = 10 * norm * eps_f32; + double tolerance = 12 * norm * eps_f32; // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the // tolerance there. 
if (IsF32() && IsF32()) { @@ -169,30 +180,38 @@ void AssertClose(const ConstMat& A, const ConstMat& B, if (tolerance > 500.0) { HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs); } + const double max_rel = 1.0 + hwy::ConvertScalarTo(hwy::Epsilon()); for (size_t r = 0; r < A.extents.rows; r++) { - const float* expected_row = C_slow.Row(r); - const float* actual_row = C.Row(r); + const float* expected_row = c_slow_batch.Batch(r); + const float* actual_row = c_batch.Batch(r); for (size_t c = 0; c < B.extents.rows; c++) { const double expected_value = static_cast(expected_row[c]); const double actual_value = static_cast(actual_row[c]); + const bool in_range = expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance; - if (!(expected_value - tolerance <= actual_value && - actual_value <= expected_value + tolerance)) { - HWY_ABORT( - "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " - "tolerance %f\n", - r, c, expected_value, actual_value, norm, max_abs, tolerance); + if (!in_range) { + const double max = HWY_MAX(expected_value, actual_value); + const double min = HWY_MIN(expected_value, actual_value); + const double rel = max / HWY_MAX(min, 1E-6); + if (rel > max_rel) { + hwy::Abort(__FILE__, line, + "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " + "tolerance %f rel %E max_rel %E\n", + r, c, expected_value, actual_value, norm, max_abs, + tolerance, rel, max_rel); + } } } } } // B is already transposed. -template +template HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, const float* HWY_RESTRICT add_row, MatMulEnv& env, - const RowPtrF& C) { + const RowPtr& C) { // TA can be any Packed except NuqStream because it uses pointer // arithmetic, because it is the second argument to Dot, which does not // support a v_ofs. @@ -200,7 +219,8 @@ HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, const float scale = A.scale * B.scale; const hn::ScalableTag df; // lane type is ignored - const PackedSpan b_span = MakeSpan(B.ptr, B.ofs + B.extents.Area()); + const PackedSpan b_span = + MakeSpan(B.ptr, B.ofs + B.Stride() * B.Extents().rows); const IndexRange all_rows_c(0, A.Extents().rows); const IndexRange all_cols_c(0, C.Cols()); @@ -219,12 +239,12 @@ HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, get_col_c, all_clusters, [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR { for (size_t r : rows_c) { - float* HWY_RESTRICT C_row = C.Row(r); + TC* HWY_RESTRICT C_row = C.Row(r); for (size_t c : cols_c) { const float add = add_row ? add_row[c] : 0.0f; - C_row[c] = - add + scale * Dot(df, b_span, c * B.extents.cols, - A.ptr + A.Row(r), A.extents.cols); + C_row[c] = hwy::ConvertScalarTo( + add + scale * Dot(df, b_span, c * B.Stride(), + A.ptr + A.Row(r), A.extents.cols)); } } }); @@ -239,14 +259,15 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents, elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed); } -template +template void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, - MatMulEnv& env) { + MatMulEnv& env, int line) { hwy::ThreadPool& pool = env.parallel.Pools().Pool(); - fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac, - cols_a_rows_b, cols_bc, add, TypeName(), TypeName()); + fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n", + rows_ac, cols_a_rows_b, cols_bc, add, TypeName(), TypeName(), + TypeName()); - env.print_config = true; + env.print_config = false; // Too verbose. 
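The relaxed check added to `AssertClose` above accepts a result if it lies within the absolute tolerance, or, failing that, if the ratio of the larger to the smaller value stays within `max_rel`. As a standalone predicate mirroring that logic:

```cpp
#include <algorithm>

// True if `actual` is within `tolerance` of `expected`, or if the two values
// differ by at most a factor of `max_rel` (guarding against division by ~0).
bool CloseSketch(double expected, double actual, double tolerance,
                 double max_rel) {
  if (expected - tolerance <= actual && actual <= expected + tolerance) {
    return true;
  }
  const double hi = std::max(expected, actual);
  const double lo = std::min(expected, actual);
  return hi / std::max(lo, 1E-6) <= max_rel;
}
```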
env.print_best = true; const Extents2D A_extents(rows_ac, cols_a_rows_b); @@ -255,8 +276,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatStoragePtr a = GenerateMat(A_extents, pool); MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); - RowVectorBatch c_slow_batch = AllocateAlignedRows(C_extents); - RowVectorBatch c_batch = AllocateAlignedRows(C_extents); + RowVectorBatch c_slow_batch = AllocateAlignedRows(C_extents); + RowVectorBatch c_batch = AllocateAlignedRows(C_extents); HWY_ASSERT(a && b_trans); std::unique_ptr> add_storage; @@ -269,14 +290,14 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, const auto A = ConstMatFromWeights(*a); const auto B = ConstMatFromWeights(*b_trans); const float* add_row = add ? add_storage->data_scale1() : nullptr; - const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch); - const RowPtrF C = RowPtrFromBatch(c_batch); + const RowPtr C_slow = RowPtrFromBatch(c_slow_batch); + const RowPtr C = RowPtrFromBatch(c_batch); MatMulSlow(A, B, add_row, env, C_slow); // A few reps to get coverage of the various autotuned code paths. for (size_t rep = 0; rep < 16; ++rep) { MMPerKey* per_key = MatMul(A, B, add_row, env, C); - AssertClose(A, B, C_slow, C); + AssertClose(A, B, C_slow, C, line); if (per_key->autotune.Best()) break; } } @@ -311,7 +332,7 @@ void TestTiny() { for (size_t M = 1; M <= 12; ++M) { for (size_t K = 1; K <= 64; K *= 2) { for (size_t N = 4; N <= 64; N += max_packages * 4) { - TestMatMul(M, K, N, /*add=*/false, env); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); } } } @@ -332,56 +353,69 @@ void TestAllMatMul() { Allocator::Init(pools.Topology(), /*enable_bind=*/true); MatMulEnv env(pools); - // Sizes seen in gemma_test 2B. - TestMatMul(1, 2048, 512, /*add=*/false, env); - TestMatMul(1, 2048, 2048, /*add=*/false, env); - TestMatMul(1, 2048, 16384, /*add=*/false, env); - TestMatMul(1, 16384, 2048, /*add=*/false, env); - TestMatMul(1, 2048, 256000, /*add=*/false, env); - TestMatMul(5, 2048, 512, /*add=*/false, env); - TestMatMul(5, 2048, 2048, /*add=*/false, env); - TestMatMul(5, 2048, 16384, /*add=*/false, env); - TestMatMul(5, 16384, 2048, /*add=*/false, env); + // Sizes seen in gemma_test 2B. Too slow for CI, enable on-demand. + TestMatMul(1, 2048, 512, /*add=*/false, env, __LINE__); + // TestMatMul(1, 2048, 2048, /*add=*/false, env, __LINE__); + // TestMatMul(1, 2048, 16384, /*add=*/false, env, __LINE__); + // TestMatMul(1, 16384, 2048, /*add=*/false, env, __LINE__); + // TestMatMul(1, 2048, 256000, /*add=*/false, env, __LINE__); + // TestMatMul(5, 2048, 512, /*add=*/false, env, __LINE__); + // TestMatMul(5, 2048, 2048, /*add=*/false, env, __LINE__); + // TestMatMul(5, 2048, 16384, /*add=*/false, env, __LINE__); + // TestMatMul(5, 16384, 2048, /*add=*/false, env, __LINE__); - // medium-sized square - TestMatMul(512, 512, 512, /*add=*/false, env); - TestMatMul(512, 512, 512, /*add=*/true, env); - TestMatMul(512, 512, 512, /*add=*/false, env); - TestMatMul(512, 512, 512, /*add=*/true, env); - TestMatMul(512, 512, 512, /*add=*/false, env); - TestMatMul(512, 512, 512, /*add=*/true, env); + // medium-sized square, f32 vs bf16 for A, B, C; plus add. 
+ TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + + TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); + TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); // minimal non-square test. kColsARowsB must be at least 2 vectors. - TestMatMul(35, 128, 32, /*add=*/false, env); - TestMatMul(34, 128, 32, /*add=*/true, env); - TestMatMul(33, 128, 32, /*add=*/false, env); - TestMatMul(33, 128, 32, /*add=*/true, env); - TestMatMul(31, 128, 32, /*add=*/false, env); - TestMatMul(29, 128, 32, /*add=*/true, env); - TestMatMul(4, 128, 32, /*add=*/true, env); - TestMatMul(4, 128, 32, /*add=*/false, env); - TestMatMul(4, 128, 32, /*add=*/true, env); - TestMatMul(4, 128, 32, /*add=*/false, env); - TestMatMul(4, 128, 32, /*add=*/true, env); - TestMatMul(4, 128, 32, /*add=*/false, env); - TestMatMul(3, 128, 32, /*add=*/false, env); - TestMatMul(3, 128, 32, /*add=*/true, env); - TestMatMul(3, 128, 32, /*add=*/false, env); - TestMatMul(3, 128, 32, /*add=*/true, env); - TestMatMul(3, 128, 32, /*add=*/false, env); - TestMatMul(3, 128, 32, /*add=*/true, env); - TestMatMul(2, 128, 64, /*add=*/true, env); - TestMatMul(2, 128, 64, /*add=*/false, env); - TestMatMul(2, 128, 64, /*add=*/true, env); - TestMatMul(2, 128, 64, /*add=*/false, env); - TestMatMul(2, 128, 64, /*add=*/true, env); - TestMatMul(2, 128, 64, /*add=*/false, env); - TestMatMul(1, 128, 32, /*add=*/false, env); - TestMatMul(1, 128, 32, /*add=*/true, env); - TestMatMul(1, 128, 32, /*add=*/false, env); - TestMatMul(1, 128, 32, /*add=*/true, env); - TestMatMul(1, 128, 32, /*add=*/false, env); - TestMatMul(1, 128, 32, /*add=*/true, env); + TestMatMul(35, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(34, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(33, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(33, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(31, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(29, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(4, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(4, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(4, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(4, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(4, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(4, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(3, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(3, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(3, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(3, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(3, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(3, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(2, 128, 64, /*add=*/true, env, __LINE__); 
+ TestMatMul(2, 128, 64, /*add=*/false, env, __LINE__); + TestMatMul(2, 128, 64, /*add=*/true, env, __LINE__); + TestMatMul(2, 128, 64, /*add=*/false, env, __LINE__); + TestMatMul(2, 128, 64, /*add=*/true, env, __LINE__); + TestMatMul(2, 128, 64, /*add=*/false, env, __LINE__); + TestMatMul(1, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(1, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(1, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(1, 128, 32, /*add=*/true, env, __LINE__); + TestMatMul(1, 128, 32, /*add=*/false, env, __LINE__); + TestMatMul(1, 128, 32, /*add=*/true, env, __LINE__); } // NOLINTNEXTLINE(google-readability-namespace-comments)
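For reference, the scalar model that `MatMulSlow` above implements is C(r, c) = add[c] + scale * dot(A row r, B row c), with B stored transposed and row-strided. A minimal float-only sketch; the real helper decompresses packed TA/TB inputs and converts the result to TC via `ConvertScalarTo`:

```cpp
#include <cstddef>

// A is M x K (row stride a_stride), B is N x K transposed (row stride
// b_stride), C is M x N (row stride c_stride); add may be nullptr.
void MatMulReferenceSketch(const float* A, size_t a_stride, const float* B,
                           size_t b_stride, size_t M, size_t K, size_t N,
                           float scale, const float* add, float* C,
                           size_t c_stride) {
  for (size_t r = 0; r < M; ++r) {
    for (size_t c = 0; c < N; ++c) {
      double dot = 0.0;
      for (size_t k = 0; k < K; ++k) {
        dot += static_cast<double>(A[r * a_stride + k]) * B[c * b_stride + k];
      }
      const double bias = add ? add[c] : 0.0;
      C[r * c_stride + c] = static_cast<float>(bias + scale * dot);
    }
  }
}
```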