Remove F64 partial storage in matmul.

Also remove the no-longer-used kMaxN; row_ptrs is now only used for C.
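
Sketch of the resulting accumulation flow (simplified from the DoNT_K and
DoNT_MT_K hunks below; loop_nc stands in for the per-nc kernel invocation):

    // The first kc range writes C directly (MMSetC); the remaining kc ranges
    // accumulate into C (MMAddC). The f64 `partial` buffer, MMScaleDemoteAdd
    // and the separate FillC pass are gone.
    ranges_kc_.VisitFirst([&](const IndexRange& kc) { loop_nc(kc, MMSetC()); });
    ranges_kc_.VisitRemaining([&](const IndexRange& kc) { loop_nc(kc, MMAddC()); });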

PiperOrigin-RevId: 800774757
Jan Wassenberg 2025-08-29 00:11:31 -07:00 committed by Copybara-Service
parent 31c09cca4c
commit 7288891439
4 changed files with 69 additions and 548 deletions


@ -275,10 +275,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos =
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
env.row_ptrs[2][interleaved_idx] = reinterpret_cast<uint8_t*>(
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
}
kv_rows.AttachRowPtrs(env.row_ptrs[2].get());
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);


@ -71,23 +71,29 @@ static hn::VFromD<DF> FastPromoteOddTo(DF df, hn::VFromD<DBF> vbf) {
#endif
}
// Converts from float intermediate to MatMul output type `TC`.
template <class DC, class DF = hn::Rebind<float, DC>, HWY_IF_F32_D(DC)>
hn::Vec<DC> TCFromF32(DC /*dc*/, hn::Vec<DF> vf) {
// Converts between the float intermediate and the MatMul output type `TC`.
template <class DC, HWY_IF_F32_D(DC)>
hn::Vec<DC> TCFromF32(DC /*dc*/, hn::Vec<DC> vf) {
return vf;
}
template <class DC, class DF = hn::Rebind<float, DC>, HWY_IF_BF16_D(DC)>
hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
return hn::DemoteTo(dc, vf);
}
template <class DC, HWY_IF_F32_D(DC)>
hn::Vec<DC> F32FromTC(DC /*dc*/, hn::Vec<DC> vc) {
return vc;
}
template <class DC, class DF = hn::Rebind<float, DC>, HWY_IF_BF16_D(DC)>
hn::Vec<DF> F32FromTC(DC dc, hn::Vec<DC> vc) {
return hn::PromoteTo(DF(), vc);
}
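// Usage sketch (illustrative only, not part of this change): round-trip a
// float vector through a BF16 output type `TC` using the helpers above.
//   const hn::ScalableTag<float> df;
//   const hn::Rebind<BF16, decltype(df)> dc;
//   const auto vf = hn::Set(df, 1.5f);
//   const auto vc = TCFromF32(dc, vf);    // f32 -> bf16 (DemoteTo)
//   const auto back = F32FromTC(dc, vc);  // bf16 -> f32 (PromoteTo)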
// Tag classes, passed to `MMKernel::A2C0` to choose between writing one
// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
// first kc result to partial, or accumulating the next kc result into partial
// via `MMAddHorizontalSumsIntoPartial`.
// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or accumulating the
// next kc result into `C`.
struct MMSetC {};
struct MMSetPartial {};
struct MMAddPartial {};
struct MMAddC {};
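// Dispatch sketch: in `MMKernel::A2C0` (see the MMKernel hunk below), MMSetC
// selects `MMStoreHorizontalSumsIntoC` with kAdd=false, overwriting C (the
// optional bias, or zero, is folded in there), while MMAddC selects kAdd=true
// so the next kc slice's sums are added to the values already stored in C.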
// Stores horizontal sums of up to 16 vectors via transpose.
template <size_t kRowsAC, bool kAdd>
@ -143,10 +149,8 @@ class MMStoreHorizontalSumsIntoC {
sum3 = MaybeAdd<3>(d4, N, sum3, buf + kNR * lane);
}
const V4 vscale = hn::Set(d4, args.scale);
V4 vadd = hn::Zero(d4);
if constexpr (kAdd) {
vadd = hn::Load(d4, args.add + col_c);
}
HWY_ALIGN static constexpr float kZero[4] = {};
const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c);
MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c);
MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c);
@ -195,156 +199,16 @@ class MMStoreHorizontalSumsIntoC {
if constexpr (kRow < kRowsAC) {
TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
const hn::Rebind<TC, DF4> dc4;
if constexpr (kAdd) {
vadd = F32FromTC(dc4, hn::Load(dc4, pos)); // load prior value
} // else: add bias (only once, the first time we store to C)
const VF4 out = hn::MulAdd(sum, vscale, vadd);
hn::Store(TCFromF32(dc4, out), dc4, pos);
}
}
}; // MMStoreHorizontalSumsIntoC
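// Worked example of the transpose trick used above (illustrative, N = 4
// lanes): given per-row sum vectors S0..S3 with S0 = {a0,a1,a2,a3}, S1 =
// {b0,..}, etc., StoreInterleaved4 writes buf = {a0,b0,c0,d0, a1,b1,c1,d1,
// a2,b2,c2,d2, a3,b3,c3,d3}. Summing the N consecutive 4-wide groups yields
// one V4 = {sum(a), sum(b), sum(c), sum(d)}, i.e. all four horizontal sums,
// without a per-vector ReduceSum.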
// Accumulates horizontal sums of up to 16 vectors via transpose.
template <size_t kRowsAC, class Tag>
class MMAddHorizontalSumsIntoPartial {
public:
static_assert(kNR == 4); // for `StoreInterleaved4`
// Computes horizontal sums of `kRowsAC x kNR` vectors and accumulates
// into `partial` starting at `(row_c, col_c)`.
//
// `Crc` are the 16 combinations of an A row vector indexed by `r`, times a
// transposed B row vector indexed by `c`. Their elements are thus a subset
// of the terms of the dot product constituting the final `C[r, c]` result.
// Thus we compute the horizontal sums of each `Crc`. The elements may be
// permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but
// this does not change their horizontal sum.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void operator()(DF df, //
VF F00, VF F01, VF F02, VF F03, //
VF F10, VF F11, VF F12, VF F13, //
VF F20, VF F21, VF F22, VF F23, //
VF F30, VF F31, VF F32, VF F33, //
const size_t row_c, const size_t col_c,
const StridedViewD& partial) const {
// We accumulate in 64-bit to avoid loss of precision.
static_assert(HWY_HAVE_FLOAT64, "Disable Armv7 NEON: we require fp64");
const hn::Repartition<double, DF> dd;
HWY_ALIGN double buf[16 * hn::MaxLanes(dd)];
using VD = hn::Vec<decltype(dd)>;
HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd);
VD C00 = SumOfPromotedPairs(dd, F00);
VD C01 = SumOfPromotedPairs(dd, F01);
VD C02 = SumOfPromotedPairs(dd, F02);
VD C03 = SumOfPromotedPairs(dd, F03);
VD C10 = SumOfPromotedPairs(dd, F10);
VD C11 = SumOfPromotedPairs(dd, F11);
VD C12 = SumOfPromotedPairs(dd, F12);
VD C13 = SumOfPromotedPairs(dd, F13);
VD C20 = SumOfPromotedPairs(dd, F20);
VD C21 = SumOfPromotedPairs(dd, F21);
VD C22 = SumOfPromotedPairs(dd, F22);
VD C23 = SumOfPromotedPairs(dd, F23);
VD C30 = SumOfPromotedPairs(dd, F30);
VD C31 = SumOfPromotedPairs(dd, F31);
VD C32 = SumOfPromotedPairs(dd, F32);
VD C33 = SumOfPromotedPairs(dd, F33);
// Horizontal reductions (`ReduceSum`) are rather expensive, entailing
// log(N) operations for vectors of length N. Because `kNR` == 4, we
// instead use `StoreInterleaved4` for a vector length-agnostic
// 'transpose': `buf[0, 4 * N)` holds `C00[0], C01[0], C02[0], C03[0],
// C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1],
// C03[N-1]`.
MaybeStoreInterleaved4<0>(dd, ND, C00, C01, C02, C03, buf);
MaybeStoreInterleaved4<1>(dd, ND, C10, C11, C12, C13, buf);
MaybeStoreInterleaved4<2>(dd, ND, C20, C21, C22, C23, buf);
MaybeStoreInterleaved4<3>(dd, ND, C30, C31, C32, C33, buf);
// Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in
// the elements of one V4. We have four independent rows `r`, hence the
// code is effectively unrolled, which increases throughput.
const hn::CappedTag<double, kNR> d4;
using V4 = hn::Vec<decltype(d4)>;
// Store to four elements per row of `partial`.
// Loop is required because vectors may be smaller than 4*64 bits.
for (size_t c = 0; c < kNR; c += hn::Lanes(d4)) {
V4 sum0 = MaybeLoad<0>(d4, ND, buf + c);
V4 sum1 = MaybeLoad<1>(d4, ND, buf + c);
V4 sum2 = MaybeLoad<2>(d4, ND, buf + c);
V4 sum3 = MaybeLoad<3>(d4, ND, buf + c);
for (size_t lane = 1; lane < ND; ++lane) {
sum0 = MaybeAdd<0>(d4, ND, sum0, buf + c + kNR * lane);
sum1 = MaybeAdd<1>(d4, ND, sum1, buf + c + kNR * lane);
sum2 = MaybeAdd<2>(d4, ND, sum2, buf + c + kNR * lane);
sum3 = MaybeAdd<3>(d4, ND, sum3, buf + c + kNR * lane);
}
MaybeAddStore<0>(d4, sum0, partial, row_c, col_c + c);
MaybeAddStore<1>(d4, sum1, partial, row_c, col_c + c);
MaybeAddStore<2>(d4, sum2, partial, row_c, col_c + c);
MaybeAddStore<3>(d4, sum3, partial, row_c, col_c + c);
}
}
private:
// Converts lanes to double and adds pairs of them to obtain a vector with the
// same horizontal sum, but element type double.
template <class DD, class VD = hn::Vec<DD>,
class DF = hn::Repartition<float, DD>, class VF = hn::Vec<DF>>
static HWY_INLINE VD SumOfPromotedPairs(DD dd, VF f) {
// TODO: SVE could PromoteEvenTo.
const VD d0 = hn::PromoteLowerTo(dd, f);
const VD d1 = hn::PromoteUpperTo(dd, f);
return hn::Add(d0, d1);
}
// These helper functions hoist if() out of the main code below. They have
// no effect if kRow >= kRowsAC.
template <size_t kRow, class DD, class VD = hn::Vec<DD>>
static HWY_INLINE void MaybeStoreInterleaved4(DD dd, size_t N, VD Cr0, VD Cr1,
VD Cr2, VD Cr3,
double* HWY_RESTRICT buf) {
if constexpr (kRow < kRowsAC) {
hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, dd, buf + 4 * kRow * N);
}
}
// Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4.
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static HWY_INLINE V4 MaybeLoad(D4 d4, size_t N,
const double* HWY_RESTRICT buf) {
if constexpr (kRow < kRowsAC) {
return hn::Load(d4, buf + 4 * kRow * N);
} else {
return hn::Zero(d4);
}
}
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static HWY_INLINE V4 MaybeAdd(D4 d4, size_t N, V4 sum,
const double* HWY_RESTRICT buf) {
if constexpr (kRow < kRowsAC) {
return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N));
} else {
return sum;
}
}
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum,
const StridedViewD& partial,
const size_t row_c, const size_t col_c) {
if constexpr (kRow < kRowsAC) {
double* HWY_RESTRICT pos = partial.Row(row_c + kRow) + col_c;
if constexpr (hwy::IsSame<Tag, MMSetPartial>()) {
hn::Store(sum, d4, pos);
} else {
static_assert(hwy::IsSame<Tag, MMAddPartial>());
const V4 prev = hn::Load(d4, pos);
hn::Store(hn::Add(sum, prev), d4, pos);
}
}
}
}; // MMAddHorizontalSumsIntoPartial
// Stateless, wraps member functions.
class MMKernel {
public:
@ -865,247 +729,18 @@ class MMKernel {
// called frequently, so do not add a profiler zone.
if constexpr (hwy::IsSame<Tag, MMSetC>()) {
if (args.add) {
MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
C31, C32, C33, row_ac, col_c, args, C_rows);
} else {
MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
C31, C32, C33, row_ac, col_c, args, C_rows);
}
} else {
MMAddHorizontalSumsIntoPartial<kRowsAC, Tag>()(
MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
C31, C32, C33, row_ac, col_c, args.partial);
C31, C32, C33, row_ac, col_c, args, C_rows);
} else {
static_assert(hwy::IsSame<Tag, MMAddC>());
MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
C31, C32, C33, row_ac, col_c, args, C_rows);
}
}
};
// Multiply partial by scale, add bias if present, demote and store to f32 `C`.
// Stateless, wraps member functions.
class MMScaleDemoteAdd {
public:
// Fills the `range_mc/range_nc` region of `outputs.C` by multiplying the
// same region of `outputs.partial` by `outputs.scale`, which is the product
// of the scales of A and B, demoting from f64 to f32, then if `outputs.add`
// is nonzero, adding it to each row.
// TODO: fuse with subsequent operations - function pointer?
// Although this region in `outputs.C` is not touched again, streaming stores
// do not help on SKX and Zen4. TODO: re-check this.
template <typename TC>
static HWY_INLINE void FillC(const IndexRange& range_mc,
const IndexRange& range_nc, const MMArgs& args,
RowPtrs<TC> C_rows) {
size_t row_c = range_mc.begin();
if (args.add) {
constexpr bool kAdd = true;
if (range_mc.Num() >= 4) {
for (; row_c <= range_mc.end() - 4; row_c += 4) {
Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
}
}
for (; row_c < range_mc.end(); ++row_c) {
Do1Row<kAdd>(row_c, range_nc, args, C_rows);
}
} else {
constexpr bool kAdd = false;
if (range_mc.Num() >= 4) {
for (; row_c <= range_mc.end() - 4; row_c += 4) {
Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
}
}
for (; row_c < range_mc.end(); ++row_c) {
Do1Row<kAdd>(row_c, range_nc, args, C_rows);
}
}
}
private:
// Unrolled for 4 rows to reduce the number of loads from `add`.
template <bool kAdd, typename TC>
static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
const MMArgs& args, RowPtrs<TC> C_rows) {
const hn::ScalableTag<double> dd;
const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo
const hn::Rebind<TC, decltype(dd)> dc;
using VD = hn::Vec<decltype(dd)>;
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd);
const VD vscale = hn::Set(dd, args.scale);
const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
const double* HWY_RESTRICT pr1 = args.partial.Row(row_c + 1);
const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2);
const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3);
TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
TC* HWY_RESTRICT cr1 = C_rows[row_c + 1];
TC* HWY_RESTRICT cr2 = C_rows[row_c + 2];
TC* HWY_RESTRICT cr3 = C_rows[row_c + 3];
// We manually unroll 2x for higher IPC in batch=1.
size_t col_c = range_nc.begin();
if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) {
for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) {
VD a0, a1; // unused if !kAdd
if constexpr (kAdd) {
// Promoting to double lets us fuse the Add into MulAdd.
a0 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c));
a1 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c + ND));
}
const VD d00 = hn::Load(dd, pr0 + col_c);
const VD d01 = hn::Load(dd, pr0 + col_c + ND);
const VD d10 = hn::Load(dd, pr1 + col_c);
const VD d11 = hn::Load(dd, pr1 + col_c + ND);
const VD d20 = hn::Load(dd, pr2 + col_c);
const VD d21 = hn::Load(dd, pr2 + col_c + ND);
const VD d30 = hn::Load(dd, pr3 + col_c);
const VD d31 = hn::Load(dd, pr3 + col_c + ND);
VD m00, m01, m10, m11, m20, m21, m30, m31;
if constexpr (kAdd) {
m00 = hn::MulAdd(d00, vscale, a0);
m01 = hn::MulAdd(d01, vscale, a1);
m10 = hn::MulAdd(d10, vscale, a0);
m11 = hn::MulAdd(d11, vscale, a1);
m20 = hn::MulAdd(d20, vscale, a0);
m21 = hn::MulAdd(d21, vscale, a1);
m30 = hn::MulAdd(d30, vscale, a0);
m31 = hn::MulAdd(d31, vscale, a1);
} else {
m00 = hn::Mul(d00, vscale);
m01 = hn::Mul(d01, vscale);
m10 = hn::Mul(d10, vscale);
m11 = hn::Mul(d11, vscale);
m20 = hn::Mul(d20, vscale);
m21 = hn::Mul(d21, vscale);
m30 = hn::Mul(d30, vscale);
m31 = hn::Mul(d31, vscale);
}
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
const VF f01 = hn::DemoteTo(df, m01);
const VF f10 = hn::DemoteTo(df, m10);
const VF f11 = hn::DemoteTo(df, m11);
const VF f20 = hn::DemoteTo(df, m20);
const VF f21 = hn::DemoteTo(df, m21);
const VF f30 = hn::DemoteTo(df, m30);
const VF f31 = hn::DemoteTo(df, m31);
// Note that Stream is neutral on SKX and harmful on Zen4.
hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c);
hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND);
hn::Store(TCFromF32(dc, f10), dc, cr1 + col_c);
hn::Store(TCFromF32(dc, f11), dc, cr1 + col_c + ND);
hn::Store(TCFromF32(dc, f20), dc, cr2 + col_c);
hn::Store(TCFromF32(dc, f21), dc, cr2 + col_c + ND);
hn::Store(TCFromF32(dc, f30), dc, cr3 + col_c);
hn::Store(TCFromF32(dc, f31), dc, cr3 + col_c + ND);
}
}
for (; col_c < range_nc.end(); col_c += ND) {
const size_t remaining = range_nc.end() - col_c;
HWY_DASSERT(remaining < 2 * ND);
VD a0; // unused if !kAdd
if constexpr (kAdd) {
// Promoting to double lets us fuse the Add into MulAdd.
a0 = hn::PromoteTo(dd, hn::LoadN(df, args.add + col_c, remaining));
}
const VD d00 = hn::LoadN(dd, pr0 + col_c, remaining);
const VD d10 = hn::LoadN(dd, pr1 + col_c, remaining);
const VD d20 = hn::LoadN(dd, pr2 + col_c, remaining);
const VD d30 = hn::LoadN(dd, pr3 + col_c, remaining);
VD m00, m10, m20, m30;
if constexpr (kAdd) {
m00 = hn::MulAdd(d00, vscale, a0);
m10 = hn::MulAdd(d10, vscale, a0);
m20 = hn::MulAdd(d20, vscale, a0);
m30 = hn::MulAdd(d30, vscale, a0);
} else {
m00 = hn::Mul(d00, vscale);
m10 = hn::Mul(d10, vscale);
m20 = hn::Mul(d20, vscale);
m30 = hn::Mul(d30, vscale);
}
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
const VF f10 = hn::DemoteTo(df, m10);
const VF f20 = hn::DemoteTo(df, m20);
const VF f30 = hn::DemoteTo(df, m30);
hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining);
hn::StoreN(TCFromF32(dc, f10), dc, cr1 + col_c, remaining);
hn::StoreN(TCFromF32(dc, f20), dc, cr2 + col_c, remaining);
hn::StoreN(TCFromF32(dc, f30), dc, cr3 + col_c, remaining);
}
}
// Same as above but handles a single row (for remainder rows).
template <bool kAdd, typename TC>
static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
const MMArgs& args, RowPtrs<TC> C_rows) {
const hn::ScalableTag<double> dd;
const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo
const hn::Rebind<TC, decltype(dd)> dc;
using VD = hn::Vec<decltype(dd)>;
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd);
const VD vscale = hn::Set(dd, args.scale);
const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
// We manually unroll 2x for higher IPC in batch=1.
size_t col_c = range_nc.begin();
if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) {
for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) {
VD a0, a1; // unused if !kAdd
if constexpr (kAdd) {
// Promoting to double lets us fuse the Add into MulAdd.
a0 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c));
a1 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c + ND));
}
const VD d00 = hn::Load(dd, pr0 + col_c);
const VD d01 = hn::Load(dd, pr0 + col_c + ND);
VD m00, m01;
if constexpr (kAdd) {
m00 = hn::MulAdd(d00, vscale, a0);
m01 = hn::MulAdd(d01, vscale, a1);
} else {
m00 = hn::Mul(d00, vscale);
m01 = hn::Mul(d01, vscale);
}
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
const VF f01 = hn::DemoteTo(df, m01);
// Note that Stream is neutral on SKX and harmful on Zen4.
hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c);
hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND);
}
}
for (; col_c < range_nc.end(); col_c += ND) {
const size_t remaining = range_nc.end() - col_c;
HWY_DASSERT(remaining < 2 * ND);
VD a0; // unused if !kAdd
if constexpr (kAdd) {
// Promoting to double lets us fuse the Add into MulAdd.
a0 = hn::PromoteTo(dd, hn::LoadN(df, args.add + col_c, remaining));
}
const VD d00 = hn::LoadN(dd, pr0 + col_c, remaining);
VD m00;
if constexpr (kAdd) {
m00 = hn::MulAdd(d00, vscale, a0);
} else {
m00 = hn::Mul(d00, vscale);
}
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining);
}
}
}; // MMScaleDemoteAdd
// Called on the main thread with the entire N range, or by each package with
// a static partition of N. This class contains several variants of the
// outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC.
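// Loop-structure sketch (pseudocode; the actual nesting and parallelism are
// selected by MMOrder, see DoNT/DoNT_K/DoNT_MT_K below):
//   for each range_nc in ranges_nc_ (parallel):           // columns of C
//     ranges_kc_.VisitFirst(kc):     A2C0(..., MMSetC())   // sets C
//     ranges_kc_.VisitRemaining(kc): A2C0(..., MMAddC())   // accumulates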
@ -1132,7 +767,6 @@ class MMPerPackage {
ranges_nc_(config.RangesOfNC(range_np)),
order_(config.Order()),
inner_tasks_(config.InnerTasks()),
out_(config.Out()),
line_bytes_(args.env->ctx.allocator.LineBytes()) {}
// The size of `A` that will actually be used, for purposes of choosing the
@ -1244,11 +878,9 @@ class MMPerPackage {
MMSetC(), args_, C_rows);
}
});
HWY_DASSERT(out_ == MMOut::kDirect); // already filled C
}
// Single M range, parallel N, sequential K. Fills all of partial.
// Single M range, parallel N, sequential K. Sets C, then accumulates.
template <typename TA, typename TB, typename TC>
HWY_INLINE void DoNT_K(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
@ -1288,32 +920,12 @@ class MMPerPackage {
// Peel off the first iteration of the kc loop: avoid
// zero-initializing `C` by writing into it.
ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
loop_nc(B_storage, range_kc, range_nc, MMSetPartial());
loop_nc(B_storage, range_kc, range_nc, MMSetC());
});
ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
loop_nc(B_storage, range_kc, range_nc, MMAddPartial());
loop_nc(B_storage, range_kc, range_nc, MMAddC());
});
});
if (out_ == MMOut::kCopy) {
static const auto zone =
args_.env->ctx.profiler.AddZone("MM.NT_K.FillC.Copy");
MMZone fill_zone;
fill_zone.MaybeEnter(0, zone, args_);
MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
} else if (out_ == MMOut::kParM) {
static const auto zone =
args_.env->ctx.profiler.AddZone("MM.NT_K.FillC.ParM");
args_.env->parallel.ForRangeMC(
range_mc, pkg_idx_, [&](size_t row_a, size_t worker) HWY_ATTR {
MMZone fill_zone;
fill_zone.MaybeEnter(worker, zone, args_);
MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
args_, C_rows);
});
} else {
HWY_UNREACHABLE; // kDirect is only used with kNT.
}
}
// Parallel loops over mc/nc blocks of M/range_np, single K.
@ -1343,14 +955,12 @@ class MMPerPackage {
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
StridedViewBF B_view =
const StridedViewBF B_view =
DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K,
MMSetC(), args_, C_rows);
}
});
HWY_DASSERT(out_ == MMOut::kDirect); // already filled C
}
// Parallel loops over mc/nc blocks of M/range_np, sequential K.
@ -1359,8 +969,6 @@ class MMPerPackage {
HWY_INLINE void DoNT_MT_K(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
static const auto fill_zone =
args_.env->ctx.profiler.AddZone("MM.NT_MT_K.FillC");
const size_t kc_max = ranges_kc_.TaskSize();
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
const size_t B_stride =
@ -1395,22 +1003,13 @@ class MMPerPackage {
const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
// Peel off the first iteration of the kc loop: avoid
// zero-initializing `partial` by writing into it.
// zero-initializing `C` by writing into it.
ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
loop_nc(B_storage_view, range_mc, range_kc, range_nc,
MMSetPartial());
loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMSetC());
});
ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
loop_nc(B_storage_view, range_mc, range_kc, range_nc,
MMAddPartial());
loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMAddC());
});
// Already in parallel section, hence no `kParM`, and
// `kDirect` is only used with `kNT_MT`.
HWY_DASSERT(out_ == MMOut::kCopy);
MMZone fill_mm_zone;
fill_mm_zone.MaybeEnter(worker, fill_zone, args_);
MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
});
}
@ -1559,7 +1158,6 @@ class MMPerPackage {
const IndexRangePartition ranges_nc_;
const MMOrder order_;
const size_t inner_tasks_;
const MMOut out_;
const size_t line_bytes_;
}; // MMPerPackage
@ -1632,7 +1230,7 @@ template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
MatPtrT<TC>& C) {
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[2]);
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
const Allocator& allocator = env.ctx.allocator;
const size_t M = A.Rows();
@ -1659,7 +1257,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
MMAutoTune<MMConfig>& tuner = per_key.autotune;
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
add, env.storage.Partial());
add);
if (HWY_LIKELY(tuner.Best())) {
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best());
return &per_key;
@ -1673,7 +1271,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
HWY_ASSERT(K == B.Cols());
HWY_ASSERT(M <= MMStorage::kMaxM);
HWY_ASSERT(K <= MMStorage::kMaxK);
HWY_ASSERT(N <= MMStorage::kMaxN);
HWY_ASSERT(N % kNR == 0);
tuner.SetCandidates(
@ -1690,10 +1287,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
hwy::platform::InvariantTicksPerSecond();
const double flops = 2 * M * K * N / min_elapsed; // * 2 for FMA
if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) {
fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%s\n", flops * 1E-9,
fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9,
min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(),
StringFromOrder(cfg.Order()), cfg.InnerTasks(),
StringFromOut(cfg.Out()));
StringFromOrder(cfg.Order()), cfg.InnerTasks());
}
if (HWY_UNLIKELY(env.print_best && tuner.Best())) {
const auto ratio = [per_key](uint64_t ticks) -> double {
@ -1702,11 +1298,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
};
const MMConfig& best = *tuner.Best();
fprintf(stderr,
"\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%s,%.2f,%.2f\n",
M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
"\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", M,
K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
best.KC(), best.NC(), StringFromOrder(best.Order()),
best.InnerTasks(), StringFromOut(best.Out()),
ratio(tuner.WorstMinTicks()), ratio(tuner.FirstConfigTicks()));
best.InnerTasks(), ratio(tuner.WorstMinTicks()),
ratio(tuner.FirstConfigTicks()));
}
return &per_key;


@ -91,24 +91,21 @@ class GenerateCandidates {
for (size_t mr : MR()) {
for (MMOrder order : Orders(mr)) {
const std::vector<int>& all_inner_tasks = InnerTasks(order);
const std::vector<MMOut>& all_outs = Outs(order);
for (size_t kc : KC(mr, order)) {
for (size_t mc : MC(mr, kc, order)) {
for (size_t nc : NC(mr, mc, kc, order)) {
for (int inner_tasks : all_inner_tasks) {
for (MMOut out : all_outs) {
const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_,
nc_multiple_, order, out, inner_tasks);
const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_,
nc_multiple_, order, inner_tasks);
const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
// Blocks only make sense when there are multiple M tasks.
if (IsBlock(order) != (M_tasks > 1)) continue;
// Single KC only makes sense when there is a single K task.
if (IsOneKC(order) != (K_tasks == 1)) continue;
// Blocks only make sense when there are multiple M tasks.
if (IsBlock(order) != (M_tasks > 1)) continue;
// Single KC only makes sense when there is a single K task.
if (IsOneKC(order) != (K_tasks == 1)) continue;
candidates.push_back(config);
}
candidates.push_back(config);
}
}
}
@ -265,14 +262,13 @@ class GenerateCandidates {
SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
const size_t np_max = ranges_np_.TaskSize();
size_t nc_max = np_max;
const size_t out_bytes = IsOneKC(order) ? sizeof_TC_ : sizeof(double);
// Only if there will be reuse of B: choose the largest `nc_max` (C cols)
// such that `nc x kc` of B and `mc x nc` of `C` fit in L3.
// Otherwise, leave it unbounded.
if (M_ > mr) {
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes);
nc_max = hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc);
nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max);
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_);
nc_max =
HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), np_max);
}
HWY_DASSERT(nc_max != 0);
nc_max = RoundDownWithFloor(nc_max, nc_multiple_);
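// Worked example (hypothetical sizes): with L3Bytes() = 32 MiB, kc = 2048,
// mc = 64 and sizeof_TC_ = 4 (float C), bytes_per_nc = 2048 * sizeof(BF16) +
// 64 * 4 = 4352, so nc_max = DivCeil(33554432, 4352) = 7711 before the
// HWY_MIN with np_max and the rounding down to nc_multiple_ above.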
@ -340,24 +336,6 @@ class GenerateCandidates {
return inner_tasks;
}
// Whether to parallelize FillC or enable direct writes to C.
std::vector<MMOut> Outs(MMOrder order) const {
std::vector<MMOut> outs;
for (size_t out_idx = 0;; ++out_idx) {
const MMOut out = static_cast<MMOut>(out_idx);
if (StringFromOut(out) == nullptr) return outs; // done
// kParM only makes sense if we have more than one row of A.
if (out == MMOut::kParM && M_ == 1) continue;
// Blocks are already parallelized.
if (out == MMOut::kParM && IsBlock(order)) continue;
// Direct only works for a single kc range.
if ((out == MMOut::kDirect) != IsOneKC(order)) continue;
// For non-block, kCopy does not beat kDirect.
if (out == MMOut::kCopy && IsOneKC(order) && !IsBlock(order)) continue;
outs.push_back(out);
}
}
const Allocator& allocator_;
const size_t M_;
const size_t K_;
@ -432,8 +410,6 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // A
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxN)); // B
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // C
}
@ -461,7 +437,7 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
}
}
// C is BF16/float, or double for partial
// C is BF16/float
void BindC(MatPtr& C, MMParallel& parallel) {
Allocator& allocator = parallel.allocator();
if (!allocator.ShouldBind()) return;


@ -43,7 +43,7 @@ namespace gcpp {
// at least the product of the FMA latency (3..5) times the throughput (2).
// This and `mr` are limited by the number of registers, which is generally
// 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in
// `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`.
// `MMStoreHorizontalSumsIntoC`. We ensure `C.Cols() % kNR == 0`.
constexpr size_t kNR = 4;
// Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because
@ -195,7 +195,7 @@ class MMParallel {
};
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
// C is BF16/float, or double for partial.
// C is BF16/float.
void BindC(MatPtr& C, MMParallel& parallel);
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
@ -236,8 +236,7 @@ class StridedView {
using StridedViewBF = StridedView<BF16>;
using StridedViewD = StridedView<double>;
// Per-package storage for packed A, and one global C-shaped `partial` for
// accumulating partial dot products (sections of K).
// Per-package storage for packed A.
class MMStorage {
public:
// Compile-time bounds on matrix dimensions to enable pre-allocating storage
@ -245,21 +244,13 @@ class MMStorage {
// per package and 512 MiB, respectively.
static constexpr size_t kMaxM = 4096;
static constexpr size_t kMaxK = 64 * 1024;
static constexpr size_t kMaxN = 256 * 1024;
// Upper bound for per-worker B storage on the stack. Chosen such that one row
// of BF16 A and B fit in 32 KiB L1, though there may be up to `kMaxMR` rows
// of A and `kNR` rows of B.
static constexpr size_t kMaxKC = 8 * 1024;
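// (8 Ki BF16 elements = 16 KiB per row; one such row of A plus one of B
// together fill the 32 KiB L1.)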
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext` (used via `parallel`).
MMStorage(const Allocator& allocator, MMParallel& parallel)
: // Per-worker copies of `partial` would be wasteful. We instead
// allocate one instance of the maximum matrix extents because threads
// write at false-sharing-free granularity.
partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), allocator,
MatPadding::kOdd),
// Same stride independent of the actual C.Cols() so we can pre-bind.
partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
MMStorage(const Allocator& allocator, MMParallel& parallel) {
// Per-package allocation so each can decompress A into its own copy.
// Must be padded, see `DoDecompressA`.
parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) {
@ -276,9 +267,6 @@ class MMStorage {
}
}
});
// Avoid cross-package accesses.
BindC(partial_storage_, parallel);
}
// Returns per-package matrix view. Converting A=F32 to BF16 up-front is
@ -291,12 +279,8 @@ class MMStorage {
extents.cols, pkg_A_[pkg_idx]->Stride());
}
StridedViewD Partial() const { return partial_; }
private:
std::unique_ptr<MatStorageT<BF16>> pkg_A_[kMaxPackages];
MatStorageT<double> partial_storage_;
StridedViewD partial_;
};
//------------------------------------------------------------------------------
@ -349,29 +333,6 @@ static inline const char* StringFromOrder(MMOrder order) {
}
}
// How/where to write the A2C0 result. This determines the `tag` argument to
// that function, which governs whether we call `MMStoreHorizontalSumsIntoC` or
// `MMAddHorizontalSumsIntoPartial`.
enum class MMOut : uint8_t {
kCopy, // accumulate into partial, scale/add to C
kDirect, // single kc task, write directly to C
kParM // kCopy but parallel over M
// kParN is not better on SKX/Zen4.
};
static inline const char* StringFromOut(MMOut out) {
switch (out) {
case MMOut::kDirect:
return "Direct";
case MMOut::kCopy:
return "Copy";
case MMOut::kParM:
return "ParM";
default:
return nullptr;
}
}
// How to parallelize the per-package `DecompressA`. To reduce combinatorial
// explosion, we tune this separately from `MMConfig`.
enum class MMParA : uint8_t { kNone, kK1 = 1, kK2 = 2, kK4 = 4, kM };
@ -405,10 +366,9 @@ class MMConfig {
MMConfig() = default; // for std::vector
// `mr` is the number of A rows per call to `MMKernel::LoopKC`.
// `MMOrder` is how to parallelize the outer loops.
// `MMOut` is how/whether to parallelize filling the C result.
// `inner_tasks` chooses the within-cluster task granularity in `ForNP`.
MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc,
size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out,
size_t kc_multiple, size_t nc_multiple, MMOrder order,
int inner_tasks)
: mr_(static_cast<uint32_t>(mr)),
mc_(static_cast<uint32_t>(mc)),
@ -417,7 +377,6 @@ class MMConfig {
nc_multiple_(static_cast<uint32_t>(nc_multiple)),
kc_multiple_(static_cast<uint32_t>(kc_multiple)),
order_(order),
out_(out),
inner_tasks_(static_cast<uint8_t>(inner_tasks)),
reserved_{} {
HWY_DASSERT(mr == 1 || mr == 2 || mr == 4);
@ -433,7 +392,6 @@ class MMConfig {
HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple);
}
HWY_DASSERT(StringFromOrder(order_) != nullptr);
HWY_DASSERT(StringFromOut(out_) != nullptr);
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
}
@ -450,7 +408,6 @@ class MMConfig {
}
MMOrder Order() const { return order_; }
MMOut Out() const { return out_; }
// No `OuterTasks` because static partitioning across clusters is sufficient.
size_t InnerTasks() const { return static_cast<size_t>(inner_tasks_); }
@ -469,9 +426,8 @@ class MMConfig {
uint32_t nc_multiple_;
uint32_t kc_multiple_;
MMOrder order_;
MMOut out_;
uint8_t inner_tasks_;
HWY_MAYBE_UNUSED uint8_t reserved_[5];
HWY_MAYBE_UNUSED uint8_t reserved_[6];
};
static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop)
@ -691,11 +647,10 @@ struct MatMulEnv {
// Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`.
// Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV
// writes to differing KV positions per query / output row.
// The first three allocations are sufficient for any A, B, C, respectively,
// but also potentially overwritten by each MatMul. Subsequent entries are
// precomputed for tensors and not overwritten. Per-tensor allocations make
// it likelier that asan detects bugs such as use after free, overrun, and
// dangling references.
// The first entry is sufficient for any C argument, but also potentially
// overwritten by each MatMul. Subsequent entries are precomputed for tensors
// and not overwritten. Per-tensor allocations make it likelier that asan
// detects bugs such as use after free, overrun, and dangling references.
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
};
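// Usage sketch for the first entry (mirrors the ComputeQKV hunk above;
// `kv_row_for()` is a placeholder for the caller's per-row destination):
//   for (size_t r = 0; r < num_rows; ++r) {
//     env.row_ptrs[0][r] = reinterpret_cast<uint8_t*>(kv_row_for(r));
//   }
//   kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
//   CallMatMul(pre_att_rms_out, qkv_einsum_w2, /*add=*/nullptr, env, kv_rows);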
@ -703,20 +658,14 @@ struct MatMulEnv {
// Reduces register pressure compared to individual values/references.
struct MMArgs {
MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
const float* HWY_RESTRICT add, const StridedViewD& partial)
: env(&env),
per_key(&per_key),
scale(scale),
add(add),
partial(partial) {}
const float* HWY_RESTRICT add)
: env(&env), per_key(&per_key), scale(scale), add(add) {}
MatMulEnv* env;
MMPerKey* per_key;
double scale;
const float* HWY_RESTRICT add;
// Same size as C, threads write at false-sharing-free granularity.
StridedViewD partial;
};
// Wrapper over hwy::Zone that is only enabled when autotuning finished.