mirror of https://github.com/google/gemma.cpp.git

Add support for arbitrary output row pointers

Useful for writing directly to KV cache. PiperOrigin-RevId: 765615147

commit 0023ff8770, parent 9c3e089b09

ops/matmul-inl.h | 151
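
To illustrate the intent of the change (not part of the commit): with a constant-stride `RowPtr`, output row `i` must live at `base + i * stride`, whereas a per-row pointer table lets each output row land at an arbitrary address, e.g. a query-dependent position in a KV cache. A minimal sketch of building such a table follows; `kv_cache`, `kv_stride`, and `kv_pos` are hypothetical names, not gemma.cpp API.

#include <cstddef>
#include <cstdint>
#include <vector>

// One destination pointer per output row; rows need not be contiguous or
// evenly strided, which a single base pointer plus stride cannot express.
std::vector<uint8_t*> OutputRowsForKV(float* kv_cache, size_t kv_stride,
                                      const std::vector<size_t>& kv_pos) {
  std::vector<uint8_t*> rows(kv_pos.size());
  for (size_t i = 0; i < kv_pos.size(); ++i) {
    rows[i] = reinterpret_cast<uint8_t*>(kv_cache + kv_pos[i] * kv_stride);
  }
  return rows;
}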
@@ -80,6 +80,19 @@ hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
   return hn::DemoteTo(dc, vf);
 }
 
+template <typename TC>
+class CRows {
+ public:
+  CRows(uint8_t** C_rows) : C_rows_(C_rows) {}
+
+  TC* HWY_RESTRICT operator[](size_t row_idx) const {
+    return HWY_RCAST_ALIGNED(TC*, C_rows_[row_idx]);
+  }
+
+ private:
+  uint8_t** C_rows_;
+};
+
 // Tag classes, passed to `MMKernel::A2C0` to choose between writing one
 // (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
 // first kc result to partial, or accumulating the next kc result into partial
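
For reference, a minimal usage sketch of the new `CRows` accessor (not part of the diff): it indexes a caller-provided array of `uint8_t*` and casts each entry back to the output type, so consecutive rows may point anywhere. The sketch assumes rows are allocated with Highway's aligned allocator so the alignment assumed by `HWY_RCAST_ALIGNED` holds.

#include "hwy/aligned_allocator.h"  // hwy::AllocateAligned

void CRowsExample() {
  // Two independently allocated output rows addressed through one CRows<float>.
  hwy::AlignedFreeUniquePtr<float[]> row0 = hwy::AllocateAligned<float>(64);
  hwy::AlignedFreeUniquePtr<float[]> row1 = hwy::AllocateAligned<float>(64);
  uint8_t* table[2] = {reinterpret_cast<uint8_t*>(row0.get()),
                       reinterpret_cast<uint8_t*>(row1.get())};
  CRows<float> c_rows(table);
  c_rows[0][0] = 1.0f;  // writes row0[0]
  c_rows[1][3] = 2.0f;  // writes row1[3]
}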
@@ -110,7 +123,7 @@ class MMStoreHorizontalSumsIntoC {
                   VF C20, VF C21, VF C22, VF C23,  //
                   VF C30, VF C31, VF C32, VF C33,  //
                   const size_t row_c, const size_t col_c,
-                  const MMArgs& args, const RowPtr<TC>& C) const {
+                  const MMArgs& args, CRows<TC> C_rows) const {
     HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
     const size_t N = hn::Lanes(df);
     // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@@ -146,10 +159,10 @@ class MMStoreHorizontalSumsIntoC {
     if constexpr (kAdd) {
       vadd = hn::Load(d4, args.add + col_c);
     }
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C, row_c, col_c);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C_rows, row_c, col_c);
   }
 
  private:
@@ -185,13 +198,14 @@ class MMStoreHorizontalSumsIntoC {
     }
   }
 
-  template <size_t kRow, typename TC, class DF4, class VF4 = hn::Vec<DF4>>
+  template <size_t kRow, /*deduced:*/ class DF4, class VF4 = hn::Vec<DF4>,
+            typename TC>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-                                            VF4 vadd, const RowPtr<TC>& C,
+                                            VF4 vadd, CRows<TC> C_rows,
                                             const size_t row_c,
                                             const size_t col_c) {
     if constexpr (kRow < kRowsAC) {
-      TC* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c;
+      TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
       const hn::Rebind<TC, DF4> dc4;
       const VF4 out = hn::MulAdd(sum, vscale, vadd);
       hn::Store(TCFromF32(dc4, out), dc4, pos);
@@ -359,7 +373,7 @@ class MMKernel {
   static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view,
                               size_t mr, const IndexRange& range_mc,
                               const size_t row_b, size_t kc, Tag tag,
-                              const MMArgs& args, const RowPtr<TC>& C) {
+                              const MMArgs& args, CRows<TC> C_rows) {
     HWY_DASSERT(1 <= mr && mr <= kMaxMR);
     const size_t row0 = range_mc.begin();
     const size_t mc = range_mc.Num();
@@ -368,7 +382,8 @@ class MMKernel {
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
       for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -377,11 +392,13 @@ class MMKernel {
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                    C_rows);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -389,17 +406,18 @@ class MMKernel {
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
      imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
      imc += 1;
     }
     HWY_DASSERT(imc == mc);
@@ -496,11 +514,11 @@ class MMKernel {
   // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be
   // BF16 so we can load directly without `Decompress2`, which is expensive for
   // NUQ and requires 2x unrolling, which requires more loads.
-  template <size_t kRowsAC, class Tag, typename TC>
+  template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
   static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view,
                                 size_t row_ac, size_t imc, size_t col_c,
                                 size_t kc, Tag tag, const MMArgs& args,
-                                const RowPtr<TC>& C) {
+                                CRows<TC> C_rows) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     const size_t NBF = hn::Lanes(dbf);
@@ -614,11 +632,11 @@ class MMKernel {
     if (args.add) {
       MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
           df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-          C31, C32, C33, row_ac, col_c, args, C);
+          C31, C32, C33, row_ac, col_c, args, C_rows);
     } else {
       MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
           df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-          C31, C32, C33, row_ac, col_c, args, C);
+          C31, C32, C33, row_ac, col_c, args, C_rows);
     }
   } else {
     MMAddHorizontalSumsIntoPartial<kRowsAC, Tag>()(
@@ -642,27 +660,27 @@ class MMScaleDemoteAdd {
   template <typename TC>
   static HWY_INLINE void FillC(const IndexRange& range_mc,
                                const IndexRange& range_nc, const MMArgs& args,
-                               const RowPtr<TC>& C) {
+                               CRows<TC> C_rows) {
     size_t row_c = range_mc.begin();
     if (args.add) {
       constexpr bool kAdd = true;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args, C);
+          Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args, C);
+        Do1Row<kAdd>(row_c, range_nc, args, C_rows);
       }
     } else {
       constexpr bool kAdd = false;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args, C);
+          Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args, C);
+        Do1Row<kAdd>(row_c, range_nc, args, C_rows);
       }
     }
   }
@@ -671,7 +689,7 @@ class MMScaleDemoteAdd {
   // Unrolled for 4 rows to reduce the number of loads from `add`.
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
-                                 const MMArgs& args, const RowPtr<TC>& C) {
+                                 const MMArgs& args, CRows<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -685,10 +703,10 @@ class MMScaleDemoteAdd {
     const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2);
     const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3);
 
-    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
-    TC* HWY_RESTRICT cr1 = C.Row(row_c + 1);
-    TC* HWY_RESTRICT cr2 = C.Row(row_c + 2);
-    TC* HWY_RESTRICT cr3 = C.Row(row_c + 3);
+    TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
+    TC* HWY_RESTRICT cr1 = C_rows[row_c + 1];
+    TC* HWY_RESTRICT cr2 = C_rows[row_c + 2];
+    TC* HWY_RESTRICT cr3 = C_rows[row_c + 3];
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
@@ -789,7 +807,7 @@ class MMScaleDemoteAdd {
   // Same as above but handles a single row (for remainder rows).
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
-                                const MMArgs& args, const RowPtr<TC>& C) {
+                                const MMArgs& args, CRows<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -798,7 +816,7 @@ class MMScaleDemoteAdd {
     const size_t ND = hn::Lanes(dd);
     const VD vscale = hn::Set(dd, args.scale);
     const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
-    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
+    TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
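
As a reading aid (not part of the diff), the element-wise operation that `FillC`, `Do4Rows`, and `Do1Row` vectorize can be written as the scalar sketch below, with TC = float for simplicity; the real code also demotes to BF16 via `TCFromF32`, and `partial`, `scale`, and `add` correspond to the `MMArgs` members.

// Scalar sketch, assuming the declarations in this file (CRows) and a
// row-major `partial` accumulator with the given stride.
void FillCScalar(const double* partial, size_t stride, double scale,
                 const float* add, size_t rows, size_t cols,
                 CRows<float> C_rows) {
  for (size_t r = 0; r < rows; ++r) {
    float* out = C_rows[r];
    for (size_t c = 0; c < cols; ++c) {
      const double bias = add ? static_cast<double>(add[c]) : 0.0;
      out[c] = static_cast<float>(partial[r * stride + c] * scale + bias);
    }
  }
}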
@@ -867,8 +885,8 @@ class MMPerPackage {
         A_(args_.env->storage.A(pkg_idx, A.Extents())),
         range_np_(range_np),
         mr_(config.MR()),
-        ranges_mc_(config.RangesOfMC(A.Extents().rows)),
-        ranges_kc_(config.RangesOfKC(A.Extents().cols)),
+        ranges_mc_(config.RangesOfMC(A.Rows())),
+        ranges_kc_(config.RangesOfKC(A.Cols())),
         ranges_nc_(config.RangesOfNC(range_np)),
         order_(config.Order()),
         inner_tasks_(config.InnerTasks()),
@@ -882,17 +900,16 @@ class MMPerPackage {
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
   template <typename TB, typename TC>
-  HWY_NOINLINE void operator()(const MatPtrT<TB>& B,
-                               const RowPtr<TC>& C) const {
+  HWY_NOINLINE void operator()(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(B, C);
+        return DoNT(B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K(B, C);
+        return DoNT_K(B, C_rows);
       case MMOrder::kNT_MT:
-        return DoNT_MT(B, C);
+        return DoNT_MT(B, C_rows);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(B, C);
+        return DoNT_MT_K(B, C_rows);
       default:
         HWY_UNREACHABLE;
     }
@@ -913,7 +930,7 @@ class MMPerPackage {
 
   // Single M and K, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -940,7 +957,7 @@ class MMPerPackage {
             DecompressB(B, row_b, range_K, B_view);
           }
           MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
-                         args_, C);
+                         args_, C_rows);
         }
       });
 
@@ -949,7 +966,7 @@ class MMPerPackage {
 
   // Single M, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -975,7 +992,7 @@ class MMPerPackage {
           DecompressB(B, row_b, range_kc, B_view);
         }
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                       C);
+                       C_rows);
       }
     };
 
@@ -997,13 +1014,13 @@ class MMPerPackage {
     MMZone fill_zone;
     if (out_ == MMOut::kCopy) {
       fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
-      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C);
+      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
     } else if (out_ == MMOut::kParM) {
       fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
       args_.env->parallel.ForRangeMC(
          range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
            MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
-                                    args_, C);
+                                    args_, C_rows);
          });
     } else {
       HWY_UNREACHABLE;  // kDirect is only used with kNT.
@@ -1013,7 +1030,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -1039,7 +1056,7 @@ class MMPerPackage {
             DecompressB(B, row_b, range_K, B_view);
           }
           MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
-                         args_, C);
+                         args_, C_rows);
         }
       });
 
@@ -1049,7 +1066,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
@@ -1074,7 +1091,7 @@ class MMPerPackage {
           DecompressB(B, row_b, range_kc, B_view);
         }
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                       C);
+                       C_rows);
       }
     };  // loop_nc
     args_.env->parallel.ForRangesMC_NC(
@@ -1097,7 +1114,7 @@ class MMPerPackage {
           HWY_DASSERT(out_ == MMOut::kCopy);
           MMZone fill_zone;
           fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
-          MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C);
+          MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
         });
   }
 
@@ -1258,7 +1275,7 @@ struct MMImpl {
   // or with the best config.
   template <typename TA, typename TB, typename TC>
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                                    const RowPtr<TC>& C, const MMArgs& args,
+                                    CRows<TC> C_rows, const MMArgs& args,
                                     const MMConfig& config) {
     MMZone matmul_zone;
     matmul_zone.MaybeEnter("MM.DoMatMul", args);
@@ -1267,7 +1284,7 @@ struct MMImpl {
     args.env->parallel.ForPkg(
         args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
          const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-          MMPerPackage(A, args, config, pkg_idx, range_np)(B, C);
+          MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
        });
  }
 };
@@ -1275,7 +1292,7 @@ struct MMImpl {
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
 //
 // `A` is a row-major matrix with `M` rows and `B` is transposed. The latter's
-// `K = B.Extents().cols`, which must match `A.Extents().cols`, is the number
+// `K = B.Cols()`, which must match `A.Cols()`, is the number
 // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
 // are no other restrictions on shape, though performance is better when `M % 4
 // == 0` or `M <= 4`.
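
The comment above fixes the shape contract for `MatMul`; restated as a standalone check (a sketch using plain sizes rather than gemma.cpp types):

#include <cassert>
#include <cstddef>

// A is M x K (row-major), B is stored transposed as N x K, C is M x N.
void CheckMatMulShapes(size_t A_rows, size_t A_cols, size_t B_rows,
                       size_t B_cols, size_t C_cols) {
  (void)A_rows;              // M: rows of A and of C; unconstrained here
  assert(B_cols == A_cols);  // K = B.Cols() must match A.Cols()
  assert(C_cols == B_rows);  // N: columns of C = rows of the transposed B
  assert(B_rows % 4 == 0);   // N must be a multiple of 4
}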
@@ -1295,11 +1312,11 @@ struct MMImpl {
 template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
-                              const RowPtr<TC>& C) {
+                              CRows<TC> C_rows) {
   const Allocator& allocator = env.ctx.allocator;
-  const size_t M = A.Extents().rows;
-  const size_t K = A.Extents().cols;
-  const size_t N = B.Extents().rows;
+  const size_t M = A.Rows();
+  const size_t K = A.Cols();
+  const size_t N = B.Rows();
   const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
   intptr_t index = MMImpl::IndexOfKey(key, env.keys);
   // First time we see this shape/key.
@@ -1323,7 +1340,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
                     add, env.storage.Partial());
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
+    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best());
     return &per_key;
   }
 
@@ -1332,8 +1349,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   // First call: enumerate all feasible configs.
   if (HWY_UNLIKELY(!tuner.HasCandidates())) {
     // Ensure matrix dimensions match each other.
-    HWY_ASSERT(K == B.Extents().cols);
-    HWY_ASSERT(N == C.Cols());
+    HWY_ASSERT(K == B.Cols());
     HWY_ASSERT(M <= MMStorage::kMaxM);
     HWY_ASSERT(K <= MMStorage::kMaxK);
     HWY_ASSERT(N <= MMStorage::kMaxN);
@@ -1347,7 +1363,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 
     const MMConfig& cfg = tuner.NextConfig();
     const uint64_t t0 = hwy::timer::Start();
-    MMImpl::DoMatMul(A, B, C, args, cfg);
+    MMImpl::DoMatMul(A, B, C_rows, args, cfg);
     const uint64_t t1 =
         env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
     const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
@@ -1376,6 +1392,19 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   return &per_key;
 }
 
+// Adapter that fills the row array. This is the common case, whereas only
+// GemmaAttention::ComputeQKV uses the arbitrary output rows feature.
+template <typename TA, typename TB, typename TC>
+HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
+                              const float* HWY_RESTRICT add, MatMulEnv& env,
+                              const RowPtr<TC>& C) {
+  HWY_DASSERT(B.Rows() == C.Cols());
+  for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) {
+    env.storage.OutRow(row_ac) = reinterpret_cast<uint8_t*>(C.Row(row_ac));
+  }
+  return MatMul(A, B, add, env, CRows<TC>(&env.storage.OutRow(0)));
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
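
A hedged sketch of how a caller with scattered destinations (such as the KV-cache case mentioned in the commit message) could use the new overload; `kv_dest`, one destination pointer per row of A, is hypothetical, while `OutRow` and `CRows` are the facilities added by this commit.

template <typename TA, typename TB, typename TC>
MMPerKey* MatMulToArbitraryRows(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                const float* HWY_RESTRICT add, MatMulEnv& env,
                                TC* const* kv_dest) {
  // Mirror the adapter above, but take the destinations from the caller
  // instead of deriving them from a constant-stride RowPtr.
  for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) {
    env.storage.OutRow(row_ac) = reinterpret_cast<uint8_t*>(kv_dest[row_ac]);
  }
  return MatMul(A, B, add, env, CRows<TC>(&env.storage.OutRow(0)));
}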

ops/matmul.h | 10
@@ -193,10 +193,11 @@ class MMStorage {
   // Internally threaded; must not be called concurrently with the same
   // `ThreadingContext` (used via `parallel`).
   MMStorage(const Allocator& allocator, MMParallel& parallel)
+      : out_rows(hwy::AllocateAligned<uint8_t*>(kMaxM)),
       // Per-worker copies of `partial` would be wasteful. We instead allocate
       // one instance of the maximum matrix extents because threads write at
       // false-sharing-free granularity.
-      : partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
+        partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
                          MatPadding::kOdd),
         // Same stride independent of the actual C.Cols() so we can pre-bind.
         partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
@@ -220,6 +221,8 @@ class MMStorage {
     BindC(partial_storage_, parallel);
   }
 
+  uint8_t*& OutRow(size_t row_idx) { return out_rows[row_idx]; }
+
   // Returns per-package matrix view.
   RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
     HWY_DASSERT(extents.rows <= kMaxM);
@@ -231,6 +234,11 @@ class MMStorage {
   RowPtrD Partial() const { return partial_; }
 
  private:
+  // Enables arbitrary output rows. Most callers pass `RowPtr`, which assumes a
+  // constant stride, but GemmaAttention::ComputeQKV writes to differing KV
+  // positions per query / output row. `kMaxM` elements are too large for the
+  // stack, hence dynamic allocation.
+  hwy::AlignedFreeUniquePtr<uint8_t*[]> out_rows;
   std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
   MatStorageT<double> partial_storage_;
   RowPtrD partial_;