From 72888914391c95d8453807d0d35b8a34f955686b Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 29 Aug 2025 00:11:31 -0700 Subject: [PATCH] Remove F64 partial storage in matmul. Also remove no longer used kMaxN; row_ptrs only used for C PiperOrigin-RevId: 800774757 --- gemma/attention.cc | 4 +- ops/matmul-inl.h | 488 ++++----------------------------------------- ops/matmul.cc | 50 ++--- ops/matmul.h | 75 ++----- 4 files changed, 69 insertions(+), 548 deletions(-) diff --git a/gemma/attention.cc b/gemma/attention.cc index a04b868..c73abcb 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -275,10 +275,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const size_t batch_idx = div_qbatch.Divide(interleaved_idx); const size_t cache_pos = activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx); - env.row_ptrs[2][interleaved_idx] = reinterpret_cast( + env.row_ptrs[0][interleaved_idx] = reinterpret_cast( qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size); } - kv_rows.AttachRowPtrs(env.row_ptrs[2].get()); + kv_rows.AttachRowPtrs(env.row_ptrs[0].get()); CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, /*add=*/nullptr, env, kv_rows); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 9f279cb..56cb06f 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -71,23 +71,29 @@ static hn::VFromD FastPromoteOddTo(DF df, hn::VFromD vbf) { #endif } -// Converts from float intermediate to MatMul output type `TC`. -template , HWY_IF_F32_D(DC)> -hn::Vec TCFromF32(DC /*dc*/, hn::Vec vf) { +// Converts from float intermediate to/from MatMul output type `TC`. +template +hn::Vec TCFromF32(DC /*dc*/, hn::Vec vf) { return vf; } template , HWY_IF_BF16_D(DC)> hn::Vec TCFromF32(DC dc, hn::Vec vf) { return hn::DemoteTo(dc, vf); } +template +hn::Vec F32FromTC(DC /*dc*/, hn::Vec vc) { + return vc; +} +template , HWY_IF_BF16_D(DC)> +hn::Vec F32FromTC(DC dc, hn::Vec vc) { + return hn::PromoteTo(DF(), vc); +} // Tag classes, passed to `MMKernel::A2C0` to choose between writing one -// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the -// first kc result to partial, or accumulating the next kc result into partial -// via `MMAddHorizontalSumsIntoPartial`. +// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or accumulating the +// next kc result into `C`. struct MMSetC {}; -struct MMSetPartial {}; -struct MMAddPartial {}; +struct MMAddC {}; // Stores horizontal sums of up to 16 vectors via transpose. template @@ -143,10 +149,8 @@ class MMStoreHorizontalSumsIntoC { sum3 = MaybeAdd<3>(d4, N, sum3, buf + kNR * lane); } const V4 vscale = hn::Set(d4, args.scale); - V4 vadd = hn::Zero(d4); - if constexpr (kAdd) { - vadd = hn::Load(d4, args.add + col_c); - } + HWY_ALIGN static constexpr float kZero[4] = {}; + const V4 vadd = hn::Load(d4, args.add ? 
args.add + col_c : kZero); MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c); MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c); MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c); @@ -195,156 +199,16 @@ class MMStoreHorizontalSumsIntoC { if constexpr (kRow < kRowsAC) { TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c; const hn::Rebind dc4; + if constexpr (kAdd) { + vadd = F32FromTC(dc4, hn::Load(dc4, pos)); // load prior value + } // else: add bias (only once, the first time we store to C) + const VF4 out = hn::MulAdd(sum, vscale, vadd); hn::Store(TCFromF32(dc4, out), dc4, pos); } } }; // MMStoreHorizontalSumsIntoC -// Accumulates horizontal sums of up to 16 vectors via transpose. -template -class MMAddHorizontalSumsIntoPartial { - public: - static_assert(kNR == 4); // for `StoreInterleaved4` - - // Computes horizontal sums of `kRowsAC x kNR` vectors and accumulates - // into `partial` starting at `(row_c, col_c)`. - // - // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a - // transposed B row vector indexed by `c`. Their elements are thus a subset - // of the terms of the dot product constituting the final `C[r, c]` result. - // Thus we compute the horizontal sums of each `Crc`. The elements may be - // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but - // this does not change their horizontal sum. - template > - HWY_INLINE void operator()(DF df, // - VF F00, VF F01, VF F02, VF F03, // - VF F10, VF F11, VF F12, VF F13, // - VF F20, VF F21, VF F22, VF F23, // - VF F30, VF F31, VF F32, VF F33, // - const size_t row_c, const size_t col_c, - const StridedViewD& partial) const { - // We accumulate in 64-bit to avoid loss of precision. - static_assert(HWY_HAVE_FLOAT64, "Disable Armv7 NEON: we require fp64"); - - const hn::Repartition dd; - HWY_ALIGN double buf[16 * hn::MaxLanes(dd)]; - using VD = hn::Vec; - HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); - VD C00 = SumOfPromotedPairs(dd, F00); - VD C01 = SumOfPromotedPairs(dd, F01); - VD C02 = SumOfPromotedPairs(dd, F02); - VD C03 = SumOfPromotedPairs(dd, F03); - VD C10 = SumOfPromotedPairs(dd, F10); - VD C11 = SumOfPromotedPairs(dd, F11); - VD C12 = SumOfPromotedPairs(dd, F12); - VD C13 = SumOfPromotedPairs(dd, F13); - VD C20 = SumOfPromotedPairs(dd, F20); - VD C21 = SumOfPromotedPairs(dd, F21); - VD C22 = SumOfPromotedPairs(dd, F22); - VD C23 = SumOfPromotedPairs(dd, F23); - VD C30 = SumOfPromotedPairs(dd, F30); - VD C31 = SumOfPromotedPairs(dd, F31); - VD C32 = SumOfPromotedPairs(dd, F32); - VD C33 = SumOfPromotedPairs(dd, F33); - - // Horizontal reductions (`ReduceSum`) are rather expensive, entailing - // log(N) operations for vectors of length N. Because `kNR` == 4, we - // instead use `StoreInterleaved4` for a vector length-agnostic - // 'transpose': `buf[0, 4 * N)` holds `C00[0], C01[0], C02[0], C03[0], - // C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1], - // C03[N-1]`. - MaybeStoreInterleaved4<0>(dd, ND, C00, C01, C02, C03, buf); - MaybeStoreInterleaved4<1>(dd, ND, C10, C11, C12, C13, buf); - MaybeStoreInterleaved4<2>(dd, ND, C20, C21, C22, C23, buf); - MaybeStoreInterleaved4<3>(dd, ND, C30, C31, C32, C33, buf); - // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in - // the elements of one V4. We have four independent rows `r`, hence the - // code is effectively unrolled, which increases throughput. 
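For reference, the store/accumulate semantics introduced in `MaybeScaleAndStore` above can be written as a minimal scalar sketch, assuming `TC == float` and using illustrative names; this is not the vectorized Highway code. On the first K section the bias is added exactly once; later sections reload the prior `C` value and accumulate into it.

#include <cstddef>

// Scalar sketch only: mirrors hn::MulAdd(sum, vscale, vadd) with vadd being
// either the bias (first section) or the previously stored C value.
template <bool kAdd>  // false: first kc section; true: subsequent sections
void StoreOrAccumulate(float* c_row, size_t col, float sum, float scale,
                       const float* bias /* may be nullptr */) {
  // kAdd == false starts from the bias (or zero, mirroring the kZero fallback);
  // kAdd == true reloads what is already in C and accumulates into it.
  const float prior = kAdd ? c_row[col] : (bias ? bias[col] : 0.0f);
  c_row[col] = sum * scale + prior;
}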
- const hn::CappedTag d4; - using V4 = hn::Vec; - // Store to four elements per row of `partial`. - // Loop is required because vectors may be smaller than 4*64 bits. - for (size_t c = 0; c < kNR; c += hn::Lanes(d4)) { - V4 sum0 = MaybeLoad<0>(d4, ND, buf + c); - V4 sum1 = MaybeLoad<1>(d4, ND, buf + c); - V4 sum2 = MaybeLoad<2>(d4, ND, buf + c); - V4 sum3 = MaybeLoad<3>(d4, ND, buf + c); - - for (size_t lane = 1; lane < ND; ++lane) { - sum0 = MaybeAdd<0>(d4, ND, sum0, buf + c + kNR * lane); - sum1 = MaybeAdd<1>(d4, ND, sum1, buf + c + kNR * lane); - sum2 = MaybeAdd<2>(d4, ND, sum2, buf + c + kNR * lane); - sum3 = MaybeAdd<3>(d4, ND, sum3, buf + c + kNR * lane); - } - MaybeAddStore<0>(d4, sum0, partial, row_c, col_c + c); - MaybeAddStore<1>(d4, sum1, partial, row_c, col_c + c); - MaybeAddStore<2>(d4, sum2, partial, row_c, col_c + c); - MaybeAddStore<3>(d4, sum3, partial, row_c, col_c + c); - } - } - - private: - // Converts lanes to double and adds pairs of them to obtain a vector with the - // same horizontal sum, but element type double. - template , - class DF = hn::Repartition, class VF = hn::Vec> - static HWY_INLINE VD SumOfPromotedPairs(DD dd, VF f) { - // TODO: SVE could PromoteEvenTo. - const VD d0 = hn::PromoteLowerTo(dd, f); - const VD d1 = hn::PromoteUpperTo(dd, f); - return hn::Add(d0, d1); - } - - // These helper functions hoist if() out of the main code below. They have - // no effect if kRow >= kRowsAC. - template > - static HWY_INLINE void MaybeStoreInterleaved4(DD dd, size_t N, VD Cr0, VD Cr1, - VD Cr2, VD Cr3, - double* HWY_RESTRICT buf) { - if constexpr (kRow < kRowsAC) { - hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, dd, buf + 4 * kRow * N); - } - } - - // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4. - template > - static HWY_INLINE V4 MaybeLoad(D4 d4, size_t N, - const double* HWY_RESTRICT buf) { - if constexpr (kRow < kRowsAC) { - return hn::Load(d4, buf + 4 * kRow * N); - } else { - return hn::Zero(d4); - } - } - - template > - static HWY_INLINE V4 MaybeAdd(D4 d4, size_t N, V4 sum, - const double* HWY_RESTRICT buf) { - if constexpr (kRow < kRowsAC) { - return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N)); - } else { - return sum; - } - } - - template > - static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum, - const StridedViewD& partial, - const size_t row_c, const size_t col_c) { - if constexpr (kRow < kRowsAC) { - double* HWY_RESTRICT pos = partial.Row(row_c + kRow) + col_c; - if constexpr (hwy::IsSame()) { - hn::Store(sum, d4, pos); - } else { - static_assert(hwy::IsSame()); - const V4 prev = hn::Load(d4, pos); - hn::Store(hn::Add(sum, prev), d4, pos); - } - } - } -}; // MMAddHorizontalSumsIntoPartial - // Stateless, wraps member functions. class MMKernel { public: @@ -865,247 +729,18 @@ class MMKernel { // called frequently, so do not add a profiler zone. 
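The `StoreInterleaved4` "transpose" described in the comment above is retained by `MMStoreHorizontalSumsIntoC`; the following scalar model (plain C++, illustrative names, not the Highway code) shows why it replaces four log(N) reductions with one pass of 4-lane adds.

#include <array>
#include <cstddef>
#include <vector>

// Interleaving four length-N vectors lane by lane and then adding the N
// consecutive groups of four yields all four horizontal sums at once.
std::array<float, 4> HorizontalSums4(const std::vector<float>& v0,
                                     const std::vector<float>& v1,
                                     const std::vector<float>& v2,
                                     const std::vector<float>& v3) {
  const size_t N = v0.size();
  std::vector<float> buf(4 * N);  // emulates hn::StoreInterleaved4
  for (size_t i = 0; i < N; ++i) {
    buf[4 * i + 0] = v0[i];
    buf[4 * i + 1] = v1[i];
    buf[4 * i + 2] = v2[i];
    buf[4 * i + 3] = v3[i];
  }
  std::array<float, 4> sums = {0.0f, 0.0f, 0.0f, 0.0f};
  for (size_t i = 0; i < N; ++i) {        // one 4-lane add per group in SIMD
    for (size_t c = 0; c < 4; ++c) sums[c] += buf[4 * i + c];
  }
  return sums;
}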
if constexpr (hwy::IsSame()) { - if (args.add) { - MMStoreHorizontalSumsIntoC()( - df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args, C_rows); - } else { - MMStoreHorizontalSumsIntoC()( - df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args, C_rows); - } - } else { - MMAddHorizontalSumsIntoPartial()( + MMStoreHorizontalSumsIntoC()( df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args.partial); + C31, C32, C33, row_ac, col_c, args, C_rows); + } else { + static_assert(hwy::IsSame()); + MMStoreHorizontalSumsIntoC()( + df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, + C31, C32, C33, row_ac, col_c, args, C_rows); } } }; -// Multiply partial by scale, add bias if present, demote and store to f32 `C`. -// Stateless, wraps member functions. -class MMScaleDemoteAdd { - public: - // Fills the `range_mc/range_nc` region of `outputs.C` by multiplying the - // same region of `outputs.partial` by `outputs.scale`, which is the product - // of the scales of A and B, demoting from f64 to f32, then if `outputs.add` - // is nonzero, adding it to each row. - // TODO: fuse with subsequent operations - function pointer? - // Although this region in `outputs.C` is not touched again, streaming stores - // do not help on SKX and Zen4. TODO: re-check this. - template - static HWY_INLINE void FillC(const IndexRange& range_mc, - const IndexRange& range_nc, const MMArgs& args, - RowPtrs C_rows) { - size_t row_c = range_mc.begin(); - if (args.add) { - constexpr bool kAdd = true; - if (range_mc.Num() >= 4) { - for (; row_c <= range_mc.end() - 4; row_c += 4) { - Do4Rows(row_c, range_nc, args, C_rows); - } - } - for (; row_c < range_mc.end(); ++row_c) { - Do1Row(row_c, range_nc, args, C_rows); - } - } else { - constexpr bool kAdd = false; - if (range_mc.Num() >= 4) { - for (; row_c <= range_mc.end() - 4; row_c += 4) { - Do4Rows(row_c, range_nc, args, C_rows); - } - } - for (; row_c < range_mc.end(); ++row_c) { - Do1Row(row_c, range_nc, args, C_rows); - } - } - } - - private: - // Unrolled for 4 rows to reduce the number of loads from `add`. - template - static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc, - const MMArgs& args, RowPtrs C_rows) { - const hn::ScalableTag dd; - const hn::Rebind df; // result of DemoteTo - const hn::Rebind dc; - using VD = hn::Vec; - using VF = hn::Vec; - HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); - const VD vscale = hn::Set(dd, args.scale); - - const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); - const double* HWY_RESTRICT pr1 = args.partial.Row(row_c + 1); - const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2); - const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3); - - TC* HWY_RESTRICT cr0 = C_rows[row_c + 0]; - TC* HWY_RESTRICT cr1 = C_rows[row_c + 1]; - TC* HWY_RESTRICT cr2 = C_rows[row_c + 2]; - TC* HWY_RESTRICT cr3 = C_rows[row_c + 3]; - - // We manually unroll 2x for higher IPC in batch=1. - size_t col_c = range_nc.begin(); - if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) { - for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) { - VD a0, a1; // unused if !kAdd - if constexpr (kAdd) { - // Promoting to double lets us fuse the Add into MulAdd. 
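The MMSetC/MMAddC tag dispatch near the top of this hunk can be illustrated in isolation. The sketch below is a plain-C++ analogue with std::is_same_v standing in for hwy::IsSame and illustrative names; it is not the actual kernel, but shows the contract: the first kc section overwrites C (adding the bias once, scaled), later sections accumulate.

#include <cstddef>
#include <type_traits>

struct SetC {};  // stand-ins for the MMSetC / MMAddC tags
struct AddC {};

template <class Tag>
void StoreTile(float* c, const float* sums, size_t n, float scale,
               const float* bias /* may be nullptr */) {
  for (size_t i = 0; i < n; ++i) {
    if constexpr (std::is_same_v<Tag, SetC>) {
      c[i] = sums[i] * scale + (bias ? bias[i] : 0.0f);  // set, bias once
    } else {
      static_assert(std::is_same_v<Tag, AddC>, "unknown tag");
      c[i] += sums[i] * scale;                           // accumulate
    }
  }
}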
- a0 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c)); - a1 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c + ND)); - } - const VD d00 = hn::Load(dd, pr0 + col_c); - const VD d01 = hn::Load(dd, pr0 + col_c + ND); - const VD d10 = hn::Load(dd, pr1 + col_c); - const VD d11 = hn::Load(dd, pr1 + col_c + ND); - const VD d20 = hn::Load(dd, pr2 + col_c); - const VD d21 = hn::Load(dd, pr2 + col_c + ND); - const VD d30 = hn::Load(dd, pr3 + col_c); - const VD d31 = hn::Load(dd, pr3 + col_c + ND); - VD m00, m01, m10, m11, m20, m21, m30, m31; - if constexpr (kAdd) { - m00 = hn::MulAdd(d00, vscale, a0); - m01 = hn::MulAdd(d01, vscale, a1); - m10 = hn::MulAdd(d10, vscale, a0); - m11 = hn::MulAdd(d11, vscale, a1); - m20 = hn::MulAdd(d20, vscale, a0); - m21 = hn::MulAdd(d21, vscale, a1); - m30 = hn::MulAdd(d30, vscale, a0); - m31 = hn::MulAdd(d31, vscale, a1); - } else { - m00 = hn::Mul(d00, vscale); - m01 = hn::Mul(d01, vscale); - m10 = hn::Mul(d10, vscale); - m11 = hn::Mul(d11, vscale); - m20 = hn::Mul(d20, vscale); - m21 = hn::Mul(d21, vscale); - m30 = hn::Mul(d30, vscale); - m31 = hn::Mul(d31, vscale); - } - // First convert f64 to f32. - const VF f00 = hn::DemoteTo(df, m00); - const VF f01 = hn::DemoteTo(df, m01); - const VF f10 = hn::DemoteTo(df, m10); - const VF f11 = hn::DemoteTo(df, m11); - const VF f20 = hn::DemoteTo(df, m20); - const VF f21 = hn::DemoteTo(df, m21); - const VF f30 = hn::DemoteTo(df, m30); - const VF f31 = hn::DemoteTo(df, m31); - // Note that Stream is neutral on SKX and harmful on Zen4. - hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c); - hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND); - hn::Store(TCFromF32(dc, f10), dc, cr1 + col_c); - hn::Store(TCFromF32(dc, f11), dc, cr1 + col_c + ND); - hn::Store(TCFromF32(dc, f20), dc, cr2 + col_c); - hn::Store(TCFromF32(dc, f21), dc, cr2 + col_c + ND); - hn::Store(TCFromF32(dc, f30), dc, cr3 + col_c); - hn::Store(TCFromF32(dc, f31), dc, cr3 + col_c + ND); - } - } - - for (; col_c < range_nc.end(); col_c += ND) { - const size_t remaining = range_nc.end() - col_c; - HWY_DASSERT(remaining < 2 * ND); - - VD a0; // unused if !kAdd - if constexpr (kAdd) { - // Promoting to double lets us fuse the Add into MulAdd. - a0 = hn::PromoteTo(dd, hn::LoadN(df, args.add + col_c, remaining)); - } - const VD d00 = hn::LoadN(dd, pr0 + col_c, remaining); - const VD d10 = hn::LoadN(dd, pr1 + col_c, remaining); - const VD d20 = hn::LoadN(dd, pr2 + col_c, remaining); - const VD d30 = hn::LoadN(dd, pr3 + col_c, remaining); - VD m00, m10, m20, m30; - if constexpr (kAdd) { - m00 = hn::MulAdd(d00, vscale, a0); - m10 = hn::MulAdd(d10, vscale, a0); - m20 = hn::MulAdd(d20, vscale, a0); - m30 = hn::MulAdd(d30, vscale, a0); - } else { - m00 = hn::Mul(d00, vscale); - m10 = hn::Mul(d10, vscale); - m20 = hn::Mul(d20, vscale); - m30 = hn::Mul(d30, vscale); - } - // First convert f64 to f32. - const VF f00 = hn::DemoteTo(df, m00); - const VF f10 = hn::DemoteTo(df, m10); - const VF f20 = hn::DemoteTo(df, m20); - const VF f30 = hn::DemoteTo(df, m30); - hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining); - hn::StoreN(TCFromF32(dc, f10), dc, cr1 + col_c, remaining); - hn::StoreN(TCFromF32(dc, f20), dc, cr2 + col_c, remaining); - hn::StoreN(TCFromF32(dc, f30), dc, cr3 + col_c, remaining); - } - } - - // Same as above but handles a single row (for remainder rows). 
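The removed MMScaleDemoteAdd path amounts to the following per-row operation; this is a simplified scalar reference under the assumption that TC is float or constructible from float, not the unrolled Highway implementation shown above.

#include <cstddef>

// Reference for the removed path: `partial` held f64 sums over K sections;
// FillC scaled them by scale(A)*scale(B), optionally added the f32 bias row,
// and narrowed to the output type of C.
template <typename TC>
void FillRowReference(const double* partial_row, size_t cols, double scale,
                      const float* bias /* may be nullptr */, TC* c_row) {
  for (size_t j = 0; j < cols; ++j) {
    const double scaled = partial_row[j] * scale;
    const double biased = bias ? scaled + static_cast<double>(bias[j]) : scaled;
    c_row[j] = static_cast<TC>(static_cast<float>(biased));
  }
}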
- template - static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc, - const MMArgs& args, RowPtrs C_rows) { - const hn::ScalableTag dd; - const hn::Rebind df; // result of DemoteTo - const hn::Rebind dc; - using VD = hn::Vec; - using VF = hn::Vec; - HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); - const VD vscale = hn::Set(dd, args.scale); - const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); - TC* HWY_RESTRICT cr0 = C_rows[row_c + 0]; - - // We manually unroll 2x for higher IPC in batch=1. - size_t col_c = range_nc.begin(); - if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) { - for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) { - VD a0, a1; // unused if !kAdd - if constexpr (kAdd) { - // Promoting to double lets us fuse the Add into MulAdd. - a0 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c)); - a1 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c + ND)); - } - const VD d00 = hn::Load(dd, pr0 + col_c); - const VD d01 = hn::Load(dd, pr0 + col_c + ND); - VD m00, m01; - if constexpr (kAdd) { - m00 = hn::MulAdd(d00, vscale, a0); - m01 = hn::MulAdd(d01, vscale, a1); - } else { - m00 = hn::Mul(d00, vscale); - m01 = hn::Mul(d01, vscale); - } - // First convert f64 to f32. - const VF f00 = hn::DemoteTo(df, m00); - const VF f01 = hn::DemoteTo(df, m01); - // Note that Stream is neutral on SKX and harmful on Zen4. - hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c); - hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND); - } - } - - for (; col_c < range_nc.end(); col_c += ND) { - const size_t remaining = range_nc.end() - col_c; - HWY_DASSERT(remaining < 2 * ND); - - VD a0; // unused if !kAdd - if constexpr (kAdd) { - // Promoting to double lets us fuse the Add into MulAdd. - a0 = hn::PromoteTo(dd, hn::LoadN(df, args.add + col_c, remaining)); - } - const VD d00 = hn::LoadN(dd, pr0 + col_c, remaining); - VD m00; - if constexpr (kAdd) { - m00 = hn::MulAdd(d00, vscale, a0); - } else { - m00 = hn::Mul(d00, vscale); - } - // First convert f64 to f32. - const VF f00 = hn::DemoteTo(df, m00); - hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining); - } - } -}; // MMScaleDemoteAdd - // Called on the main thread with the entire N range, or by each package with // a static partition of N. This class contains several variants of the // outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC. @@ -1132,7 +767,6 @@ class MMPerPackage { ranges_nc_(config.RangesOfNC(range_np)), order_(config.Order()), inner_tasks_(config.InnerTasks()), - out_(config.Out()), line_bytes_(args.env->ctx.allocator.LineBytes()) {} // The size of `A` that will actually be used, for purposes of choosing the @@ -1244,11 +878,9 @@ class MMPerPackage { MMSetC(), args_, C_rows); } }); - - HWY_DASSERT(out_ == MMOut::kDirect); // already filled C } - // Single M range, parallel N, sequential K. Fills all of partial. + // Single M range, parallel N, sequential K. Sets C, then accumulates. template HWY_INLINE void DoNT_K(const StridedView A, const bool A_padded, const MatPtrT& B, RowPtrs C_rows) const { @@ -1288,32 +920,12 @@ class MMPerPackage { // Peel off the first iteration of the kc loop: avoid // zero-initializing `partial` by writing into it. 
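The peeling mentioned in the comment above now applies to C itself: the first K section writes C outright, so C never needs zero-initialization, and the remaining sections accumulate. A scalar sketch of one output element, with illustrative names and assuming kc > 0:

#include <cstddef>

void AccumulateOverK(const float* A_row, const float* B_col, size_t K,
                     size_t kc, float& c) {
  float sum = 0.0f;
  const size_t first_end = (K < kc) ? K : kc;
  for (size_t k = 0; k < first_end; ++k) sum += A_row[k] * B_col[k];
  c = sum;                                   // first section: set (MMSetC)
  for (size_t k0 = kc; k0 < K; k0 += kc) {   // remaining sections (MMAddC)
    const size_t end = (k0 + kc < K) ? (k0 + kc) : K;
    sum = 0.0f;
    for (size_t k = k0; k < end; ++k) sum += A_row[k] * B_col[k];
    c += sum;
  }
}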
ranges_kc_.VisitFirst([&](const IndexRange& range_kc) { - loop_nc(B_storage, range_kc, range_nc, MMSetPartial()); + loop_nc(B_storage, range_kc, range_nc, MMSetC()); }); ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) { - loop_nc(B_storage, range_kc, range_nc, MMAddPartial()); + loop_nc(B_storage, range_kc, range_nc, MMAddC()); }); }); - - if (out_ == MMOut::kCopy) { - static const auto zone = - args_.env->ctx.profiler.AddZone("MM.NT_K.FillC.Copy"); - MMZone fill_zone; - fill_zone.MaybeEnter(0, zone, args_); - MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows); - } else if (out_ == MMOut::kParM) { - static const auto zone = - args_.env->ctx.profiler.AddZone("MM.NT_K.FillC.ParM"); - args_.env->parallel.ForRangeMC( - range_mc, pkg_idx_, [&](size_t row_a, size_t worker) HWY_ATTR { - MMZone fill_zone; - fill_zone.MaybeEnter(worker, zone, args_); - MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_, - args_, C_rows); - }); - } else { - HWY_UNREACHABLE; // kDirect is only used with kNT. - } } // Parallel loops over mc/nc blocks of M/range_np, single K. @@ -1343,14 +955,12 @@ class MMPerPackage { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { - StridedViewBF B_view = + const StridedViewBF B_view = DecompressB(B, row_b, range_K, B_storage_view); MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K, MMSetC(), args_, C_rows); } }); - - HWY_DASSERT(out_ == MMOut::kDirect); // already filled C } // Parallel loops over mc/nc blocks of M/range_np, sequential K. @@ -1359,8 +969,6 @@ class MMPerPackage { HWY_INLINE void DoNT_MT_K(const StridedView A, const bool A_padded, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); - static const auto fill_zone = - args_.env->ctx.profiler.AddZone("MM.NT_MT_K.FillC"); const size_t kc_max = ranges_kc_.TaskSize(); HWY_DASSERT(kc_max <= MMStorage::kMaxKC); const size_t B_stride = @@ -1395,22 +1003,13 @@ class MMPerPackage { const StridedViewBF B_storage_view(B_storage, kc_max, B_stride); // Peel off the first iteration of the kc loop: avoid - // zero-initializing `partial` by writing into it. + // zero-initializing `C` by writing into it. ranges_kc_.VisitFirst([&](const IndexRange& range_kc) { - loop_nc(B_storage_view, range_mc, range_kc, range_nc, - MMSetPartial()); + loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMSetC()); }); ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) { - loop_nc(B_storage_view, range_mc, range_kc, range_nc, - MMAddPartial()); + loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMAddC()); }); - - // Already in parallel section, hence no `kParM`, and - // `kDirect` is only used with `kNT_MT`. 
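A structural sketch of the kNT_MT_K order after this change (plain C++, not the real scheduler): the thread pool parallelizes the (mc, nc) block loops, kc stays sequential per block, and each block's first kc section sets C while later sections accumulate; no separate FillC pass over a partial buffer remains.

#include <cstddef>
#include <functional>

void ForEachBlock(size_t M, size_t N, size_t K, size_t mc, size_t nc, size_t kc,
                  const std::function<void(size_t im, size_t in, size_t ik,
                                           bool first_kc)>& do_tile) {
  for (size_t im = 0; im < M; im += mc) {      // parallelized in the real code
    for (size_t in = 0; in < N; in += nc) {    // parallelized in the real code
      bool first_kc = true;
      for (size_t ik = 0; ik < K; ik += kc) {  // sequential: C block stays hot
        do_tile(im, in, ik, first_kc);
        first_kc = false;
      }
    }
  }
}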
- HWY_DASSERT(out_ == MMOut::kCopy); - MMZone fill_mm_zone; - fill_mm_zone.MaybeEnter(worker, fill_zone, args_); - MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows); }); } @@ -1559,7 +1158,6 @@ class MMPerPackage { const IndexRangePartition ranges_nc_; const MMOrder order_; const size_t inner_tasks_; - const MMOut out_; const size_t line_bytes_; }; // MMPerPackage @@ -1632,7 +1230,7 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C) { - RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[2]); + RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]); const Allocator& allocator = env.ctx.allocator; const size_t M = A.Rows(); @@ -1659,7 +1257,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMAutoTune& tuner = per_key.autotune; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), - add, env.storage.Partial()); + add); if (HWY_LIKELY(tuner.Best())) { MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best()); return &per_key; @@ -1673,7 +1271,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, HWY_ASSERT(K == B.Cols()); HWY_ASSERT(M <= MMStorage::kMaxM); HWY_ASSERT(K <= MMStorage::kMaxK); - HWY_ASSERT(N <= MMStorage::kMaxN); HWY_ASSERT(N % kNR == 0); tuner.SetCandidates( @@ -1690,10 +1287,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, hwy::platform::InvariantTicksPerSecond(); const double flops = 2 * M * K * N / min_elapsed; // * 2 for FMA if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) { - fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%s\n", flops * 1E-9, + fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9, min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(), - StringFromOrder(cfg.Order()), cfg.InnerTasks(), - StringFromOut(cfg.Out())); + StringFromOrder(cfg.Order()), cfg.InnerTasks()); } if (HWY_UNLIKELY(env.print_best && tuner.Best())) { const auto ratio = [per_key](uint64_t ticks) -> double { @@ -1702,11 +1298,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, }; const MMConfig& best = *tuner.Best(); fprintf(stderr, - "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%s,%.2f,%.2f\n", - M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), + "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", M, + K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), best.KC(), best.NC(), StringFromOrder(best.Order()), - best.InnerTasks(), StringFromOut(best.Out()), - ratio(tuner.WorstMinTicks()), ratio(tuner.FirstConfigTicks())); + best.InnerTasks(), ratio(tuner.WorstMinTicks()), + ratio(tuner.FirstConfigTicks())); } return &per_key; diff --git a/ops/matmul.cc b/ops/matmul.cc index c9ddfb6..71f2efe 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -91,24 +91,21 @@ class GenerateCandidates { for (size_t mr : MR()) { for (MMOrder order : Orders(mr)) { const std::vector& all_inner_tasks = InnerTasks(order); - const std::vector& all_outs = Outs(order); for (size_t kc : KC(mr, order)) { for (size_t mc : MC(mr, kc, order)) { for (size_t nc : NC(mr, mc, kc, order)) { for (int inner_tasks : all_inner_tasks) { - for (MMOut out : all_outs) { - const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_, - nc_multiple_, order, out, inner_tasks); - const size_t M_tasks = config.RangesOfMC(M_).NumTasks(); - const size_t K_tasks = config.RangesOfKC(K_).NumTasks(); + const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_, + nc_multiple_, order, 
inner_tasks); + const size_t M_tasks = config.RangesOfMC(M_).NumTasks(); + const size_t K_tasks = config.RangesOfKC(K_).NumTasks(); - // Blocks only make sense when there are multiple M tasks. - if (IsBlock(order) != (M_tasks > 1)) continue; - // Single KC only makes sense when there is a single K task. - if (IsOneKC(order) != (K_tasks == 1)) continue; + // Blocks only make sense when there are multiple M tasks. + if (IsBlock(order) != (M_tasks > 1)) continue; + // Single KC only makes sense when there is a single K task. + if (IsOneKC(order) != (K_tasks == 1)) continue; - candidates.push_back(config); - } + candidates.push_back(config); } } } @@ -265,14 +262,13 @@ class GenerateCandidates { SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const { const size_t np_max = ranges_np_.TaskSize(); size_t nc_max = np_max; - const size_t out_bytes = IsOneKC(order) ? sizeof_TC_ : sizeof(double); // Only if there will be reuse of B: choose the largest `nc_max` (C cols) // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3. // Otherwise, leave it unbounded. if (M_ > mr) { - const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes); - nc_max = hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc); - nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max); + const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_); + nc_max = + HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), np_max); } HWY_DASSERT(nc_max != 0); nc_max = RoundDownWithFloor(nc_max, nc_multiple_); @@ -340,24 +336,6 @@ class GenerateCandidates { return inner_tasks; } - // Whether to parallelize FillC or enable direct writes to C. - std::vector Outs(MMOrder order) const { - std::vector outs; - for (size_t out_idx = 0;; ++out_idx) { - const MMOut out = static_cast(out_idx); - if (StringFromOut(out) == nullptr) return outs; // done - // kParM only makes sense if we have more than one row of A. - if (out == MMOut::kParM && M_ == 1) continue; - // Blocks are already parallelized. - if (out == MMOut::kParM && IsBlock(order)) continue; - // Direct only works for a single kc range. - if ((out == MMOut::kDirect) != IsOneKC(order)) continue; - // For non-block, kCopy does not beat kDirect. - if (out == MMOut::kCopy && IsOneKC(order) && !IsBlock(order)) continue; - outs.push_back(out); - } - } - const Allocator& allocator_; const size_t M_; const size_t K_; @@ -432,8 +410,6 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx) char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); - row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); // A - row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxN)); // B row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); // C } @@ -461,7 +437,7 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { } } -// C is BF16/float, or double for partial +// C is BF16/float void BindC(MatPtr& C, MMParallel& parallel) { Allocator& allocator = parallel.allocator(); if (!allocator.ShouldBind()) return; diff --git a/ops/matmul.h b/ops/matmul.h index 99290c1..e4c436f 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -43,7 +43,7 @@ namespace gcpp { // at least the product of the FMA latency (3..5) times the throughput (2). // This and `mr` are limited by the number of registers, which is generally // 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in -// `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`. +// `MMStoreHorizontalSumsIntoC`. We ensure `C.Cols() % kNR == 0`. 
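Returning to the NC() bound in matmul.cc above: with the f64 partial buffer gone, each C column costs kc bytes of decompressed BF16 B plus mc bytes of the C output itself. The standalone sketch below models that bound; names are illustrative and the rounding approximates RoundDownWithFloor, so it is not the exact helper.

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Cap nc so that one (nc x kc) B panel plus the (mc x nc) C block fit in L3.
// Assumes mc, kc, nc_multiple > 0.
size_t MaxNC(size_t l3_bytes, size_t mc, size_t kc, size_t sizeof_tc,
             size_t np_max, size_t nc_multiple) {
  const size_t bytes_per_nc = kc * sizeof(std::uint16_t) /* BF16 */ +
                              mc * sizeof_tc;
  size_t nc_max = std::min(l3_bytes / bytes_per_nc, np_max);
  // Round down to the required multiple, but never below one multiple.
  nc_max = std::max(nc_multiple, nc_max - (nc_max % nc_multiple));
  return nc_max;
}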
constexpr size_t kNR = 4; // Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because @@ -195,7 +195,7 @@ class MMParallel { }; void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel); -// C is BF16/float, or double for partial. +// C is BF16/float. void BindC(MatPtr& C, MMParallel& parallel); // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. @@ -236,8 +236,7 @@ class StridedView { using StridedViewBF = StridedView; using StridedViewD = StridedView; -// Per-package storage for packed A, and one global C-shaped `partial` for -// accumulating partial dot products (sections of K). +// Per-package storage for packed A. class MMStorage { public: // Compile-time bounds on matrix dimensions to enable pre-allocating storage @@ -245,21 +244,13 @@ class MMStorage { // per package and 512 MiB, respectively. static constexpr size_t kMaxM = 4096; static constexpr size_t kMaxK = 64 * 1024; - static constexpr size_t kMaxN = 256 * 1024; // Upper bound for per-worker B storage on the stack. Chosen such that one row // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. static constexpr size_t kMaxKC = 8 * 1024; // Internally threaded; must not be called concurrently with the same // `ThreadingContext` (used via `parallel`). - MMStorage(const Allocator& allocator, MMParallel& parallel) - : // Per-worker copies of `partial` would be wasteful. We instead - // allocate one instance of the maximum matrix extents because threads - // write at false-sharing-free granularity. - partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), allocator, - MatPadding::kOdd), - // Same stride independent of the actual C.Cols() so we can pre-bind. - partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) { + MMStorage(const Allocator& allocator, MMParallel& parallel) { // Per-package allocation so each can decompress A into its own copy. // Must be padded, see `DoDecompressA`. parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) { @@ -276,9 +267,6 @@ class MMStorage { } } }); - - // Avoid cross-package accesses. - BindC(partial_storage_, parallel); } // Returns per-package matrix view. Converting A=F32 to BF16 up-front is @@ -291,12 +279,8 @@ class MMStorage { extents.cols, pkg_A_[pkg_idx]->Stride()); } - StridedViewD Partial() const { return partial_; } - private: std::unique_ptr> pkg_A_[kMaxPackages]; - MatStorageT partial_storage_; - StridedViewD partial_; }; //------------------------------------------------------------------------------ @@ -349,29 +333,6 @@ static inline const char* StringFromOrder(MMOrder order) { } } -// How/where to write the A2C0 result. This determines the `tag` argument to -// that function, which governs whether we call `MMStoreHorizontalSumsIntoC` or -// `MMAddHorizontalSumsIntoPartial`. -enum class MMOut : uint8_t { - kCopy, // accumulate into partial, scale/add to C - kDirect, // single kc task, write directly to C - kParM // kCopy but parallel over M - // kParN is not better on SKX/Zen4. -}; - -static inline const char* StringFromOut(MMOut out) { - switch (out) { - case MMOut::kDirect: - return "Direct"; - case MMOut::kCopy: - return "Copy"; - case MMOut::kParM: - return "ParM"; - default: - return nullptr; - } -} - // How to parallelize the per-package `DecompressA`. To reduce combinatorial // explosion, we tune this separately from `MMConfig`. 
enum class MMParA : uint8_t { kNone, kK1 = 1, kK2 = 2, kK4 = 4, kM }; @@ -405,10 +366,9 @@ class MMConfig { MMConfig() = default; // for std::vector // `mr` is the number of A rows per call to `MMKernel::LoopKC`. // `MMOrder` is how to parallelize the outer loops. - // `MMOut` is how/whether to parallelize filling the C result. // `inner_tasks` chooses the within-cluster task granularity in `ForNP`. MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc, - size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out, + size_t kc_multiple, size_t nc_multiple, MMOrder order, int inner_tasks) : mr_(static_cast(mr)), mc_(static_cast(mc)), @@ -417,7 +377,6 @@ class MMConfig { nc_multiple_(static_cast(nc_multiple)), kc_multiple_(static_cast(kc_multiple)), order_(order), - out_(out), inner_tasks_(static_cast(inner_tasks)), reserved_{} { HWY_DASSERT(mr == 1 || mr == 2 || mr == 4); @@ -433,7 +392,6 @@ class MMConfig { HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple); } HWY_DASSERT(StringFromOrder(order_) != nullptr); - HWY_DASSERT(StringFromOut(out_) != nullptr); HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); } @@ -450,7 +408,6 @@ class MMConfig { } MMOrder Order() const { return order_; } - MMOut Out() const { return out_; } // No `OuterTasks` because static partitioning across clusters is sufficient. size_t InnerTasks() const { return static_cast(inner_tasks_); } @@ -469,9 +426,8 @@ class MMConfig { uint32_t nc_multiple_; uint32_t kc_multiple_; MMOrder order_; - MMOut out_; uint8_t inner_tasks_; - HWY_MAYBE_UNUSED uint8_t reserved_[5]; + HWY_MAYBE_UNUSED uint8_t reserved_[6]; }; static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) @@ -691,11 +647,10 @@ struct MatMulEnv { // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`. // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV // writes to differing KV positions per query / output row. - // The first three allocations are sufficient for any A, B, C, respectively, - // but also potentially overwritten by each MatMul. Subsequent entries are - // precomputed for tensors and not overwritten. Per-tensor allocations make - // it likelier that asan detects bugs such as use after free, overrun, and - // dangling references. + // The first entry is sufficient for any C argument, but also potentially + // overwritten by each MatMul. Subsequent entries are precomputed for tensors + // and not overwritten. Per-tensor allocations make it likelier that asan + // detects bugs such as use after free, overrun, and dangling references. std::vector> row_ptrs; }; @@ -703,20 +658,14 @@ struct MatMulEnv { // Reduces register pressure compared to individual values/references. struct MMArgs { MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale, - const float* HWY_RESTRICT add, const StridedViewD& partial) - : env(&env), - per_key(&per_key), - scale(scale), - add(add), - partial(partial) {} + const float* HWY_RESTRICT add) + : env(&env), per_key(&per_key), scale(scale), add(add) {} MatMulEnv* env; MMPerKey* per_key; double scale; const float* HWY_RESTRICT add; - // Same size as C, threads write at false-sharing-free granularity. - StridedViewD partial; }; // Wrapper over hwy::Zone that is only enabled when autotuning finished.
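A self-contained illustration of the layout consequence in MMConfig (field names copied from above; the struct name is illustrative): removing the one-byte MMOut member is compensated by growing reserved_ from 5 to 6 bytes, so the 32-byte size assertion used for fast indexing still holds.

#include <cstdint>

#pragma pack(push, 1)
struct PackedConfigSketch {
  uint32_t mr_, mc_, kc_, nc_;           // 16 bytes
  uint32_t nc_multiple_, kc_multiple_;   // 8 bytes
  uint8_t order_;                        // MMOrder
  uint8_t inner_tasks_;
  uint8_t reserved_[6];                  // was [5] alongside the MMOut byte
};
#pragma pack(pop)
static_assert(sizeof(PackedConfigSketch) == 32, "keep 32 bytes for indexing");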