mirror of https://github.com/google/gemma.cpp.git
Add support for arbitrary output row pointers
Useful for writing directly to the KV cache.

PiperOrigin-RevId: 765615147
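For orientation, here is a minimal standalone sketch, not gemma.cpp code, of what "arbitrary output row pointers" buys. Every name in it is invented for illustration: a dense output assumes row `r` lives at `base + r * stride`, whereas a table of per-row pointers lets each output row land at an unrelated address, such as a per-query slot inside a KV cache.

```cpp
#include <cstddef>
#include <vector>

// Dense output: row r is implicitly at base + r * stride.
void StoreDense(float* base, size_t stride, size_t r, size_t c, float v) {
  base[r * stride + c] = v;
}

// Scattered output: row r is wherever row_ptrs[r] points; no common stride.
void StoreScattered(float* const* row_ptrs, size_t r, size_t c, float v) {
  row_ptrs[r][c] = v;
}

int main() {
  std::vector<float> dense(4 * 8, 0.0f);
  StoreDense(dense.data(), /*stride=*/8, /*r=*/2, /*c=*/3, 1.0f);

  // Hypothetical KV cache: each query's output row lives in its own slot.
  std::vector<float> kv(1024, 0.0f);
  float* rows[4] = {&kv[0], &kv[256], &kv[512], &kv[768]};
  StoreScattered(rows, /*r=*/2, /*c=*/3, 1.0f);
  return 0;
}
```

The commit threads exactly this kind of pointer table (the `CRows` view below) through the MatMul call chain.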
parent 9c3e089b09, commit 0023ff8770

ops/matmul-inl.h (151 lines changed)
@@ -80,6 +80,19 @@ hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
   return hn::DemoteTo(dc, vf);
 }
 
+template <typename TC>
+class CRows {
+ public:
+  CRows(uint8_t** C_rows) : C_rows_(C_rows) {}
+
+  TC* HWY_RESTRICT operator[](size_t row_idx) const {
+    return HWY_RCAST_ALIGNED(TC*, C_rows_[row_idx]);
+  }
+
+ private:
+  uint8_t** C_rows_;
+};
+
 // Tag classes, passed to `MMKernel::A2C0` to choose between writing one
 // (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
 // first kc result to partial, or accumulating the next kc result into partial
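A minimal usage sketch of the `CRows` view added above, assuming `ops/matmul-inl.h` is included so `CRows` (and the `HWY_RCAST_ALIGNED` it uses) are in scope; the buffer names are hypothetical. The table is type-erased as `uint8_t*` so one array can serve any output type `TC`, and `operator[]` casts back on access:

```cpp
// Hypothetical destinations: two output rows that do not share a stride.
alignas(64) static float q_row[256];
alignas(64) static float kv_slot[256];

void ScatterTwoRowsExample() {
  uint8_t* row_table[2] = {reinterpret_cast<uint8_t*>(q_row),
                           reinterpret_cast<uint8_t*>(kv_slot)};
  CRows<float> C_rows(row_table);
  C_rows[0][0] = 1.0f;  // writes q_row[0]
  C_rows[1][0] = 2.0f;  // writes kv_slot[0]
}
```

Note that `CRows` is a non-owning view: the `uint8_t**` must outlive the call, which is why `MMStorage` gains a persistent `out_rows` array (see `ops/matmul.h` below).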
@@ -110,7 +123,7 @@ class MMStoreHorizontalSumsIntoC {
                   VF C20, VF C21, VF C22, VF C23,  //
                   VF C30, VF C31, VF C32, VF C33,  //
                   const size_t row_c, const size_t col_c,
-                  const MMArgs& args, const RowPtr<TC>& C) const {
+                  const MMArgs& args, CRows<TC> C_rows) const {
     HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
     const size_t N = hn::Lanes(df);
     // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@@ -146,10 +159,10 @@ class MMStoreHorizontalSumsIntoC {
     if constexpr (kAdd) {
       vadd = hn::Load(d4, args.add + col_c);
     }
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C, row_c, col_c);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C_rows, row_c, col_c);
   }
 
  private:
@@ -185,13 +198,14 @@ class MMStoreHorizontalSumsIntoC {
     }
   }
 
-  template <size_t kRow, typename TC, class DF4, class VF4 = hn::Vec<DF4>>
+  template <size_t kRow, /*deduced:*/ class DF4, class VF4 = hn::Vec<DF4>,
+            typename TC>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-                                            VF4 vadd, const RowPtr<TC>& C,
+                                            VF4 vadd, CRows<TC> C_rows,
                                             const size_t row_c,
                                             const size_t col_c) {
     if constexpr (kRow < kRowsAC) {
-      TC* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c;
+      TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
       const hn::Rebind<TC, DF4> dc4;
       const VF4 out = hn::MulAdd(sum, vscale, vadd);
       hn::Store(TCFromF32(dc4, out), dc4, pos);
@@ -359,7 +373,7 @@ class MMKernel {
   static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view,
                               size_t mr, const IndexRange& range_mc,
                               const size_t row_b, size_t kc, Tag tag,
-                              const MMArgs& args, const RowPtr<TC>& C) {
+                              const MMArgs& args, CRows<TC> C_rows) {
     HWY_DASSERT(1 <= mr && mr <= kMaxMR);
     const size_t row0 = range_mc.begin();
     const size_t mc = range_mc.Num();
@@ -368,7 +382,8 @@ class MMKernel {
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
       for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -377,11 +392,13 @@ class MMKernel {
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                    C_rows);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -389,17 +406,18 @@ class MMKernel {
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
       imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
       imc += 1;
     }
     HWY_DASSERT(imc == mc);
@@ -496,11 +514,11 @@ class MMKernel {
   // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be
   // BF16 so we can load directly without `Decompress2`, which is expensive for
   // NUQ and requires 2x unrolling, which requires more loads.
-  template <size_t kRowsAC, class Tag, typename TC>
+  template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
   static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view,
                                 size_t row_ac, size_t imc, size_t col_c,
                                 size_t kc, Tag tag, const MMArgs& args,
-                                const RowPtr<TC>& C) {
+                                CRows<TC> C_rows) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     const size_t NBF = hn::Lanes(dbf);
@@ -614,11 +632,11 @@ class MMKernel {
       if (args.add) {
         MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
             df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-            C31, C32, C33, row_ac, col_c, args, C);
+            C31, C32, C33, row_ac, col_c, args, C_rows);
       } else {
         MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
             df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-            C31, C32, C33, row_ac, col_c, args, C);
+            C31, C32, C33, row_ac, col_c, args, C_rows);
       }
     } else {
       MMAddHorizontalSumsIntoPartial<kRowsAC, Tag>()(
@@ -642,27 +660,27 @@ class MMScaleDemoteAdd {
   template <typename TC>
   static HWY_INLINE void FillC(const IndexRange& range_mc,
                                const IndexRange& range_nc, const MMArgs& args,
-                               const RowPtr<TC>& C) {
+                               CRows<TC> C_rows) {
     size_t row_c = range_mc.begin();
     if (args.add) {
       constexpr bool kAdd = true;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args, C);
+          Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args, C);
+        Do1Row<kAdd>(row_c, range_nc, args, C_rows);
       }
     } else {
       constexpr bool kAdd = false;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args, C);
+          Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args, C);
+        Do1Row<kAdd>(row_c, range_nc, args, C_rows);
       }
     }
   }
@@ -671,7 +689,7 @@ class MMScaleDemoteAdd {
   // Unrolled for 4 rows to reduce the number of loads from `add`.
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
-                                 const MMArgs& args, const RowPtr<TC>& C) {
+                                 const MMArgs& args, CRows<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -685,10 +703,10 @@ class MMScaleDemoteAdd {
     const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2);
     const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3);
 
-    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
-    TC* HWY_RESTRICT cr1 = C.Row(row_c + 1);
-    TC* HWY_RESTRICT cr2 = C.Row(row_c + 2);
-    TC* HWY_RESTRICT cr3 = C.Row(row_c + 3);
+    TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
+    TC* HWY_RESTRICT cr1 = C_rows[row_c + 1];
+    TC* HWY_RESTRICT cr2 = C_rows[row_c + 2];
+    TC* HWY_RESTRICT cr3 = C_rows[row_c + 3];
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
@@ -789,7 +807,7 @@ class MMScaleDemoteAdd {
   // Same as above but handles a single row (for remainder rows).
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
-                                const MMArgs& args, const RowPtr<TC>& C) {
+                                const MMArgs& args, CRows<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -798,7 +816,7 @@ class MMScaleDemoteAdd {
     const size_t ND = hn::Lanes(dd);
     const VD vscale = hn::Set(dd, args.scale);
     const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
-    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
+    TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
@@ -867,8 +885,8 @@ class MMPerPackage {
         A_(args_.env->storage.A(pkg_idx, A.Extents())),
         range_np_(range_np),
         mr_(config.MR()),
-        ranges_mc_(config.RangesOfMC(A.Extents().rows)),
-        ranges_kc_(config.RangesOfKC(A.Extents().cols)),
+        ranges_mc_(config.RangesOfMC(A.Rows())),
+        ranges_kc_(config.RangesOfKC(A.Cols())),
         ranges_nc_(config.RangesOfNC(range_np)),
         order_(config.Order()),
         inner_tasks_(config.InnerTasks()),
@@ -882,17 +900,16 @@ class MMPerPackage {
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
   template <typename TB, typename TC>
-  HWY_NOINLINE void operator()(const MatPtrT<TB>& B,
-                               const RowPtr<TC>& C) const {
+  HWY_NOINLINE void operator()(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(B, C);
+        return DoNT(B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K(B, C);
+        return DoNT_K(B, C_rows);
       case MMOrder::kNT_MT:
-        return DoNT_MT(B, C);
+        return DoNT_MT(B, C_rows);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(B, C);
+        return DoNT_MT_K(B, C_rows);
       default:
         HWY_UNREACHABLE;
     }
@@ -913,7 +930,7 @@ class MMPerPackage {
 
   // Single M and K, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -940,7 +957,7 @@ class MMPerPackage {
             DecompressB(B, row_b, range_K, B_view);
           }
           MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
-                         args_, C);
+                         args_, C_rows);
         }
       });
 
@@ -949,7 +966,7 @@ class MMPerPackage {
 
   // Single M, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -975,7 +992,7 @@ class MMPerPackage {
           DecompressB(B, row_b, range_kc, B_view);
         }
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                       C);
+                       C_rows);
       }
     };
 
@@ -997,13 +1014,13 @@ class MMPerPackage {
     MMZone fill_zone;
     if (out_ == MMOut::kCopy) {
       fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
-      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C);
+      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
     } else if (out_ == MMOut::kParM) {
       fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
       args_.env->parallel.ForRangeMC(
           range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
             MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
-                                    args_, C);
+                                    args_, C_rows);
           });
     } else {
       HWY_UNREACHABLE;  // kDirect is only used with kNT.
@@ -1013,7 +1030,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -1039,7 +1056,7 @@ class MMPerPackage {
             DecompressB(B, row_b, range_K, B_view);
           }
           MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
-                         args_, C);
+                         args_, C_rows);
         }
       });
 
@@ -1049,7 +1066,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
@@ -1074,7 +1091,7 @@ class MMPerPackage {
           DecompressB(B, row_b, range_kc, B_view);
         }
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                       C);
+                       C_rows);
       }
     };  // loop_nc
     args_.env->parallel.ForRangesMC_NC(
@@ -1097,7 +1114,7 @@ class MMPerPackage {
           HWY_DASSERT(out_ == MMOut::kCopy);
           MMZone fill_zone;
           fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
-          MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C);
+          MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
         });
   }
 
@@ -1258,7 +1275,7 @@ struct MMImpl {
   // or with the best config.
   template <typename TA, typename TB, typename TC>
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                                    const RowPtr<TC>& C, const MMArgs& args,
+                                    CRows<TC> C_rows, const MMArgs& args,
                                     const MMConfig& config) {
     MMZone matmul_zone;
     matmul_zone.MaybeEnter("MM.DoMatMul", args);
@@ -1267,7 +1284,7 @@ struct MMImpl {
     args.env->parallel.ForPkg(
         args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
          const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-          MMPerPackage(A, args, config, pkg_idx, range_np)(B, C);
+          MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
        });
   }
 };
@@ -1275,7 +1292,7 @@ struct MMImpl {
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
 //
 // `A` is a row-major matrix with `M` rows and `B` is transposed. The latter's
-// `K = B.Extents().cols`, which must match `A.Extents().cols`, is the number
+// `K = B.Cols()`, which must match `A.Cols()`, is the number
 // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
 // are no other restrictions on shape, though performance is better when `M % 4
 // == 0` or `M <= 4`.
@@ -1295,11 +1312,11 @@
 template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
-                              const RowPtr<TC>& C) {
+                              CRows<TC> C_rows) {
   const Allocator& allocator = env.ctx.allocator;
-  const size_t M = A.Extents().rows;
-  const size_t K = A.Extents().cols;
-  const size_t N = B.Extents().rows;
+  const size_t M = A.Rows();
+  const size_t K = A.Cols();
+  const size_t N = B.Rows();
   const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
   intptr_t index = MMImpl::IndexOfKey(key, env.keys);
   // First time we see this shape/key.
@@ -1323,7 +1340,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
                     add, env.storage.Partial());
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
+    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best());
     return &per_key;
   }
 
@@ -1332,8 +1349,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   // First call: enumerate all feasible configs.
   if (HWY_UNLIKELY(!tuner.HasCandidates())) {
     // Ensure matrix dimensions match each other.
-    HWY_ASSERT(K == B.Extents().cols);
-    HWY_ASSERT(N == C.Cols());
+    HWY_ASSERT(K == B.Cols());
     HWY_ASSERT(M <= MMStorage::kMaxM);
     HWY_ASSERT(K <= MMStorage::kMaxK);
     HWY_ASSERT(N <= MMStorage::kMaxN);
@@ -1347,7 +1363,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 
   const MMConfig& cfg = tuner.NextConfig();
   const uint64_t t0 = hwy::timer::Start();
-  MMImpl::DoMatMul(A, B, C, args, cfg);
+  MMImpl::DoMatMul(A, B, C_rows, args, cfg);
   const uint64_t t1 =
       env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
   const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
@@ -1376,6 +1392,19 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   return &per_key;
 }
 
+// Adapter that fills the row array. This is the common case, whereas only
+// GemmaAttention::ComputeQKV uses the arbitrary output rows feature.
+template <typename TA, typename TB, typename TC>
+HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
+                              const float* HWY_RESTRICT add, MatMulEnv& env,
+                              const RowPtr<TC>& C) {
+  HWY_DASSERT(B.Rows() == C.Cols());
+  for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) {
+    env.storage.OutRow(row_ac) = reinterpret_cast<uint8_t*>(C.Row(row_ac));
+  }
+  return MatMul(A, B, add, env, CRows<TC>(&env.storage.OutRow(0)));
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
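The new adapter overload at the end of this file keeps the existing RowPtr-based API working: it records `C.Row(i)` in `env.storage.OutRow(i)` for every row and forwards to the CRows-based MatMul. A caller that wants scattered rows (the commit names `GemmaAttention::ComputeQKV` as the only such user) would fill the same table itself. Below is a hedged sketch of such a caller; `MatMulIntoRows`, `dst_rows`, and the choice of `TC = float` are invented for illustration, and only `OutRow`, `CRows`, and the new MatMul overload come from this commit. It assumes `ops/matmul-inl.h` is included so `MatPtrT`, `MatMulEnv`, and `MMPerKey` are in scope.

```cpp
#include <vector>

// Computes A * B and writes output row r to dst_rows[r], e.g. a KV-cache slot.
template <typename TB>
MMPerKey* MatMulIntoRows(const MatPtrT<float>& A, const MatPtrT<TB>& B,
                         MatMulEnv& env, const std::vector<float*>& dst_rows) {
  HWY_DASSERT(dst_rows.size() == A.Rows());
  for (size_t row = 0; row < A.Rows(); ++row) {
    // Each output row goes to its own destination; no common stride required.
    env.storage.OutRow(row) = reinterpret_cast<uint8_t*>(dst_rows[row]);
  }
  // No bias here; pass `add` through instead of nullptr if one is needed.
  return MatMul(A, B, /*add=*/nullptr, env,
                CRows<float>(&env.storage.OutRow(0)));
}
```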
ops/matmul.h (10 lines changed)
@@ -193,10 +193,11 @@ class MMStorage {
   // Internally threaded; must not be called concurrently with the same
   // `ThreadingContext` (used via `parallel`).
   MMStorage(const Allocator& allocator, MMParallel& parallel)
+      : out_rows(hwy::AllocateAligned<uint8_t*>(kMaxM)),
         // Per-worker copies of `partial` would be wasteful. We instead allocate
         // one instance of the maximum matrix extents because threads write at
         // false-sharing-free granularity.
-      : partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
+        partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
                          MatPadding::kOdd),
         // Same stride independent of the actual C.Cols() so we can pre-bind.
         partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
@@ -220,6 +221,8 @@ class MMStorage {
     BindC(partial_storage_, parallel);
   }
 
+  uint8_t*& OutRow(size_t row_idx) { return out_rows[row_idx]; }
+
   // Returns per-package matrix view.
   RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
     HWY_DASSERT(extents.rows <= kMaxM);
@@ -231,6 +234,11 @@ class MMStorage {
   RowPtrD Partial() const { return partial_; }
 
  private:
+  // Enables arbitrary output rows. Most callers pass `RowPtr`, which assumes a
+  // constant stride, but GemmaAttention::ComputeQKV writes to differing KV
+  // positions per query / output row. `kMaxM` elements are too large for the
+  // stack, hence dynamic allocation.
+  hwy::AlignedFreeUniquePtr<uint8_t*[]> out_rows;
   std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
   MatStorageT<double> partial_storage_;
   RowPtrD partial_;
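The comment on `out_rows` explains why the pointer table is heap-allocated: `kMaxM` pointers are too large for the stack, so the table is allocated once in the constructor and reused across MatMul calls. A standalone sketch of that allocation pattern, assuming only Highway's aligned allocator; `kMaxMExample` is an invented stand-in for the real `kMaxM` defined in ops/matmul.h:

```cpp
#include <cstddef>

#include "hwy/aligned_allocator.h"

constexpr size_t kMaxMExample = 4096;  // illustrative; not the real kMaxM

// One pointer per potential output row, owned by an AlignedFreeUniquePtr so it
// is freed automatically when the storage object is destroyed.
hwy::AlignedFreeUniquePtr<uint8_t*[]> MakeRowTable() {
  return hwy::AllocateAligned<uint8_t*>(kMaxMExample);
}
```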