mirror of https://github.com/google/gemma.cpp.git

Refactor MatMul to accept views in the kernel functions

Make arg order consistent. Move StridedView into mat.h. Add view support to RowPtrs.
PiperOrigin-RevId: 805197381

parent f10ac41a20
commit 9457258330
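For orientation before the diff: the kernels now receive pre-offset views of C (and of `add`) instead of absolute `row_c`/`col_c` coordinates plus the whole output matrix. Below is a minimal standalone sketch of that convention; `ToyRowPtrs` is a made-up stand-in, not the repository's `RowPtrs` or `StridedView`.

#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for RowPtrs: an array of row pointers plus a baked-in row/col origin.
struct ToyRowPtrs {
  float** rows;
  size_t r0 = 0, c0 = 0;

  // Equivalent of the new RowPtrs::View: bake the tile origin into the view.
  ToyRowPtrs View(size_t r, size_t c) const { return {rows, r0 + r, c0 + c}; }

  // Row(i) is relative to the view's origin.
  float* Row(size_t i) const { return rows[r0 + i] + c0; }
};

int main() {
  constexpr size_t kRows = 4, kCols = 8;
  std::vector<float> storage(kRows * kCols, 0.0f);
  std::vector<float*> row_ptrs(kRows);
  for (size_t r = 0; r < kRows; ++r) row_ptrs[r] = storage.data() + r * kCols;

  ToyRowPtrs C{row_ptrs.data()};

  // Old style: the kernel received C plus absolute (row_c, col_c).
  const size_t row_c = 1, col_c = 4;
  C.Row(row_c)[col_c] = 1.0f;

  // New style: the caller makes a view at the tile origin; the kernel then
  // only uses tile-relative indices (imc).
  ToyRowPtrs C_view = C.View(row_c, col_c);
  const size_t imc = 0;
  C_view.Row(imc)[0] = 2.0f;

  std::printf("%.0f\n", C.Row(1)[4]);  // prints 2: both writes hit the same cell
  return 0;
}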

ops/matmul-inl.h — 193 changed lines

@@ -148,21 +148,21 @@ class MMStoreHorizontalSumsIntoC {
 }
 }

-// Scales the dot-product terms and adds bias (if present) and stores the
-// four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is
-// `MMSetC`, the vectors are written as-is (first call, or small K).
-// Otherwise, they are partial sums and are accumulated into C.
-template <class D4, class V4 = hn::Vec<D4>, class Tag, class CRows>
-HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag,
-const size_t row_c, const size_t col_c,
-const MMArgs& args, CRows C_rows) const {
-const V4 vscale = hn::Set(d4, args.scale);
+// Scales the dot-product terms plus `add` (if non-null) and stores the four
+// 4-wide vectors to `C` starting at row 0, column 0. If `tag` is `MMSetC`,
+// the vectors are written as-is (first call, or small K). Otherwise, they
+// are partial sums and are accumulated into C.
+template <class D4, class V4 = hn::Vec<D4>, class Tag, class CView>
+HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3,
+const float scale, const float* HWY_RESTRICT add,
+const size_t imc, Tag tag, CView C_rows) const {
+const V4 vscale = hn::Set(d4, scale);
 HWY_ALIGN static constexpr float kZero[4] = {};
-const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
-MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, C_rows, row_c, col_c);
-MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, C_rows, row_c, col_c);
-MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, C_rows, row_c, col_c);
-MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, C_rows, row_c, col_c);
+const V4 vadd = hn::Load(d4, add ? add : kZero);
+MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_rows);
+MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_rows);
+MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_rows);
+MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_rows);
 }

 private:
@@ -199,13 +199,13 @@ class MMStoreHorizontalSumsIntoC {
 }

 template <size_t kRow, /*deduced:*/ class DF4, class VF4 = hn::Vec<DF4>,
-class Tag, typename TC>
+class Tag, class CView>
 static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-VF4 vadd, Tag, RowPtrs<TC> C_rows,
-const size_t row_c,
-const size_t col_c) {
+VF4 vadd, Tag, const size_t imc,
+CView C_view) {
 if constexpr (kRow < kRowsAC) {
-TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
+using TC = hwy::RemoveCvRef<decltype(C_view.Row(0)[0])>;
+TC* HWY_RESTRICT pos = C_view.Row(imc + kRow);
 const hn::Rebind<TC, DF4> dc4;
 if constexpr (hwy::IsSame<Tag, MMAddC>()) {
 vadd = F32FromTC(dc4, hn::Load(dc4, pos)); // load prior value
@@ -234,7 +234,7 @@ class MMDecompress {

 // Neither A nor B require padding because `LoopKC` handles remainders.
 if constexpr (hwy::IsSame<TB, BF16>()) {
-return View(B, row_b, range_kc.begin(), range_kc.Num());
+return StridedViewBF(B, row_b, range_kc.begin(), range_kc.Num());
 }

 const PackedSpan<const TB> B_span = B.PaddedSpan();
@@ -264,7 +264,7 @@ class MMDecompress {
 if constexpr (IsBF16<TA>()) {
 // We can use a view, regardless of columns/padding, because
 // `MMKernel::LoopKC` supports non-vector multiples.
-return View(A, 0, 0, A.Cols());
+return StridedViewBF(A, 0, 0, A.Cols());
 } else {
 // Always decompress. To reduce code size/compile time, we no longer
 // support a separate F32 kernel; most A are already BF16. We also only
@@ -277,15 +277,6 @@ class MMDecompress {
 }

 private:
-// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-template <typename T>
-static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
-size_t cols) {
-HWY_DASSERT(c < AB.Cols());
-HWY_DASSERT(cols <= AB.Cols() - c);
-return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
-}
-
 // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
 static HWY_NOINLINE void DecompressA(const MatPtrT<float>& A,
 const StridedViewBF A_view,
@@ -402,26 +393,26 @@ class MMKernel {
 kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);

 public:
-// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
-// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
 // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
+// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is
+// `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start
+// at row/col 0. `CView` is either `RowPtrs<TC>` or `StridedView<TC>`.
 // Called by B3A2C0 and by callers that hoist `A_view`.
-template <class Tag, class CRows>
+template <class Tag, class CView>
 static HWY_INLINE void A2C0(const StridedViewBF A_view,
 const StridedViewBF B_view, size_t mr,
-const IndexRange& range_mc, const size_t row_b,
-size_t kc, Tag tag, const MMArgs& args,
-CRows C_rows) {
+const IndexRange& range_mc, size_t kc,
+const float scale, const float* HWY_RESTRICT add,
+Tag tag, CView C_view) {
 HWY_DASSERT(1 <= mr && mr <= kMaxMR);
-const size_t row0 = range_mc.begin();
 const size_t mc = range_mc.Num();
 size_t imc = 0;

 // M == 1, or x86 with 8 SIMD registers:
 if (HWY_UNLIKELY(mr == 1)) {
 for (; imc < mc; ++imc) {
-LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-C_rows);
+LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
 }
 return;
 }
@@ -430,13 +421,11 @@ class MMKernel {
 if (HWY_UNLIKELY(mr == 2)) {
 if (HWY_LIKELY(mc >= 2)) {
 for (; imc <= mc - 2; imc += 2) {
-LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-C_rows);
+LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
 }
 }
 if (HWY_UNLIKELY(imc != mc)) {
-LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-C_rows);
+LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
 }
 return;
 }
@@ -444,18 +433,17 @@ class MMKernel {
 HWY_DASSERT(mr == 4);
 if (HWY_LIKELY(mc >= 4)) {
 for (; imc <= mc - 4; imc += 4) {
-LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-C_rows);
+LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_view);
 }
 }
 const size_t remainder_mc = mc - imc;
 HWY_DASSERT(remainder_mc < 4);
 if (HWY_UNLIKELY(remainder_mc & 2)) {
-LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
+LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
 imc += 2;
 }
 if (HWY_UNLIKELY(remainder_mc & 1)) {
-LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
+LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
 imc += 1;
 }
 HWY_DASSERT(imc == mc);
@@ -466,11 +454,11 @@ class MMKernel {
 // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
 // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
 // `ForeachKC` and when there is only a single KC task.
-template <typename TB, typename Tag, class CRows>
+template <typename TB, typename TC, typename Tag>
 static void B3A2C0(const StridedViewBF A, const MatPtrT<TB>& B,
-const MMArgs& args, const IndexRange& range_mc,
-const IndexRange& range_kc, const IndexRange& range_nc,
-size_t mr, Tag out_tag, CRows C_rows) {
+const IndexRange& range_mc, const IndexRange& range_kc,
+const IndexRange& range_nc, const MMArgs& args,
+Tag out_tag, RowPtrs<TC> C) {
 HWY_ALIGN BF16 B_storage[B_storage_max];

 const size_t kc = range_kc.Num();
@@ -482,24 +470,28 @@ class MMKernel {

 for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
 row_b += kNR) {
-StridedViewBF B_view =
+const StridedViewBF B_view =
 MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
-A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows);
+const RowPtrs<TC> C_view = C.View(range_mc.begin(), row_b);
+const float* HWY_RESTRICT add = args.add ? args.add + row_b : nullptr;
+A2C0(A_view, B_view, args.mr, range_mc, kc, args.scale, add, out_tag,
+C_view);
 }
 }

-template <typename TB, class CRows>
+template <typename TB, typename TC>
 static void ForeachKC(const StridedViewBF A, const MatPtrT<TB>& B,
-const MMArgs& args, const IndexRange& range_mc,
+const IndexRange& range_mc,
 const IndexRangePartition& ranges_kc,
-const IndexRange& range_nc, size_t mr, CRows C_rows) {
+const IndexRange& range_nc, const MMArgs& args,
+RowPtrs<TC> C) {
 // Peel off the first iteration of the kc loop: avoid zero-initializing `C`
 // by writing directly into it, and later accumulating into it.
 ranges_kc.VisitFirst([&](const IndexRange& range_kc) {
-B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMSetC(), C_rows);
+B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C);
 });
 ranges_kc.VisitRemaining([&](const IndexRange& range_kc) {
-B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMAddC(), C_rows);
+B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C);
 });
 }

@@ -593,19 +585,20 @@ class MMKernel {
 // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
 // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
 // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
-// Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`.
-// `A` and `B` are always BF16, `C` can be F32 or BF16.
-template <size_t kRowsAC, /*deduced:*/ class Tag, class CRows>
+// Updates a `kRowsAC x kNR` tile in `C_view` starting at row `imc`, column 0.
+// `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also
+// relative to the C column.
+template <size_t kRowsAC, /*deduced:*/ class Tag, class CView>
 static HWY_INLINE void LoopKC(const StridedViewBF A_view,
-const StridedViewBF B_view, size_t row_ac,
-size_t imc, size_t col_c, size_t kc, Tag tag,
-const MMArgs& args, CRows C_rows) {
+const StridedViewBF B_view, size_t imc,
+size_t kc, const float scale,
+const float* HWY_RESTRICT add, Tag tag,
+CView C_view) {
 const hn::ScalableTag<BF16> dbf;
 using VBF = hn::Vec<decltype(dbf)>;
 HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);

 HWY_DASSERT(kRowsAC <= kMaxMR);
-HWY_DASSERT(col_c % kNR == 0);
 // Rows are aligned to `kMaxMR`, except for the last tile of A.

 // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B.
@@ -784,7 +777,7 @@ class MMKernel {
 hn::Vec<decltype(d4)> sum0, sum1, sum2, sum3;
 horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22,
 C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3);
-horz.Store(d4, sum0, sum1, sum2, sum3, tag, row_ac, col_c, args, C_rows);
+horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_view);
 }
 };

@@ -884,7 +877,7 @@ class MMLoops {
 // or with the best config.
 template <typename TB, typename TC>
 static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
-RowPtrs<TC> C_rows, const MMArgs& args) {
+RowPtrs<TC> C, const MMArgs& args) {
 static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
 PROFILER_ZONE3(args.env.ctx.profiler,
 args.env.ctx.Worker(args.options.cluster_idx), zone);
@@ -892,7 +885,7 @@ class MMLoops {
 DispatchParallelism(
 args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
 DispatchOrder(args.order, [&](const auto& order) HWY_ATTR {
-Loop(order, parallel, A, B, C_rows, args);
+Loop(order, parallel, A, B, C, args);
 });
 });
 }
@@ -904,11 +897,11 @@ class MMLoops {
 return HWY_MAX(kNR, line_bytes / sizeof_TC);
 }

-// Single M and K ranges, parallel N. Fills all of C directly.
+// Single M and K ranges, parallel N.
 template <typename TB, typename TC, class Parallel>
 static HWY_INLINE void Loop(MMOrderNT, Parallel parallel,
 const StridedViewBF A, const MatPtrT<TB>& B,
-RowPtrs<TC> C_rows, const MMArgs& args) {
+RowPtrs<TC> C, const MMArgs& args) {
 static const auto zone = args.env.ctx.profiler.AddZone("MM.NT");
 HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
 HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
@@ -932,10 +925,21 @@ class MMLoops {

 for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
 row_b += kNR) {
-StridedViewBF B_view =
+const StridedViewBF B_view =
 MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
-MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(),
-args, C_rows);
+const RowPtrs<TC> C_view = C.View(range_M.begin(), row_b);
+const float* HWY_RESTRICT add =
+args.add ? args.add + row_b : nullptr;
+
+MMKernel::A2C0(A_view, B_view, args.mr, range_M, K, args.scale, add,
+MMSetC(), C_view);
+}
+
+if constexpr (IsBF16<TC>()) {
+if (args.options.fused) {
+StridedViewBF C2(nullptr, 0, 0);
+args.options.fused(C, range_M, range_nc, C2, worker);
+}
 }
 });
 }
@@ -944,7 +948,7 @@ class MMLoops {
 template <typename TB, typename TC, class Parallel>
 static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel,
 const StridedViewBF A, const MatPtrT<TB>& B,
-RowPtrs<TC> C_rows, const MMArgs& args) {
+RowPtrs<TC> C, const MMArgs& args) {
 static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K");
 HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
 const IndexRange& range_mc = args.ranges_mc.Range(0);
@@ -955,17 +959,24 @@ class MMLoops {
 [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
 MMZone mm_zone;
 mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc,
-range_nc, args.mr, C_rows);
+MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc,
+range_nc, args, C);
+
+if constexpr (IsBF16<TC>()) {
+if (args.options.fused) {
+StridedViewBF C2(nullptr, 0, 0);
+args.options.fused(C, range_mc, range_nc, C2, worker);
+}
+}
 });
 }

 // Parallel loops over mc/nc blocks of M/range_n, single K.
-// Fills `mc x nc` sections of C directly, in parallel.
+// Fills `mc x nc` sections of C.
 template <typename TB, typename TC, class Parallel>
 static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel,
 const StridedViewBF A, const MatPtrT<TB>& B,
-RowPtrs<TC> C_rows, const MMArgs& args) {
+RowPtrs<TC> C, const MMArgs& args) {
 static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT");
 HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
 const IndexRange& range_K = args.ranges_kc.Range(0);
@@ -976,17 +987,24 @@ class MMLoops {
 size_t worker) HWY_ATTR {
 MMZone mm_zone;
 mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-MMKernel::B3A2C0(A, B, args, range_mc, range_K, range_nc, args.mr,
-MMSetC(), C_rows);
+MMKernel::B3A2C0(A, B, range_mc, range_K, range_nc, args, MMSetC(),
+C);
+
+if constexpr (IsBF16<TC>()) {
+if (args.options.fused) {
+StridedViewBF C2(nullptr, 0, 0);
+args.options.fused(C, range_mc, range_nc, C2, worker);
+}
+}
 });
 }

-// Parallel loops over mc/nc blocks of M/range_np, sequential K.
+// Parallel loops over mc/nc blocks of M/range_n, sequential K.
 // Accumulates into `mc x nc` sections of `C`.
 template <typename TB, typename TC, class Parallel>
 static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel,
 const StridedViewBF A, const MatPtrT<TB>& B,
-RowPtrs<TC> C_rows, const MMArgs& args) {
+RowPtrs<TC> C, const MMArgs& args) {
 static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K");

 parallel.ForRangesMC_NC(
@@ -995,8 +1013,15 @@ class MMLoops {
 size_t worker) HWY_ATTR {
 MMZone mm_zone;
 mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc, range_nc,
-args.mr, C_rows);
+MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, range_nc, args,
+C);
+
+if constexpr (IsBF16<TC>()) {
+if (args.options.fused) {
+StridedViewBF C2(nullptr, 0, 0);
+args.options.fused(C, range_mc, range_nc, C2, worker);
+}
+}
 });
 }
 }; // MMLoops

ops/matmul.h — 61 changed lines

@@ -60,54 +60,6 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // TODO: shrink?
 // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
 HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;

-// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
-// Also used to decompress B, hence non-const.
-#pragma pack(push, 1) // power of two size
-template <typename T>
-class StridedView {
-public:
-StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
-: row0_(row0),
-cols_(static_cast<uint32_t>(cols)),
-stride_(static_cast<uint32_t>(stride)) {
-HWY_DASSERT(stride >= cols);
-}
-
-T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
-size_t Cols() const { return static_cast<size_t>(cols_); }
-
-size_t Stride() const { return static_cast<size_t>(stride_); }
-void SetStride(size_t stride) {
-HWY_DASSERT(stride >= Cols());
-stride_ = stride;
-}
-
-// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-StridedView<T> View(size_t r, size_t c, size_t cols) const {
-HWY_DASSERT(c < Cols());
-HWY_DASSERT(cols <= Cols() - c);
-return StridedView<T>(Row(r) + c, cols, stride_);
-}
-
-private:
-T* HWY_RESTRICT row0_;
-uint32_t cols_;
-uint32_t stride_;
-};
-#pragma pack(pop)
-
-using StridedViewBF = StridedView<BF16>;
-using StridedViewD = StridedView<double>;
-
-using MMFused = std::function<void(StridedViewBF, size_t, size_t)>;
-
-struct MMOptions {
-uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
-ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
-
-MMFused fused;
-};
-
 // Policy classes for parallelism, implementing some of `ParallelismStrategy`.

 struct MMParallelNone {
@@ -735,6 +687,19 @@ struct MatMulEnv {
 std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
 };

+// Called with the entire C matrix, the sub-ranges of M (rows) and N (cols)
+// that this thread has just filled, a view into a second tile (only for the
+// upcoming `GatedMatmul`), and the worker thread index (see `ParallelFor`).
+using MMFused = std::function<void(RowPtrsBF, IndexRange, IndexRange,
+StridedViewBF, size_t)>;
+
+struct MMOptions {
+uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
+ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
+
+MMFused fused; // called if non-null and `TC` is BF16.
+};
+
 // Arguments to MatMul() that are independent of the A/B/C types. Reduces
 // register pressure compared to individual values/references. Also used for
 // passing through `DispatchOrder`.
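The new `MMFused` signature above hands a fusion hook the whole C matrix plus the just-filled M/N sub-ranges, so post-processing can run per tile while it is still warm in cache. A rough standalone sketch of that callback pattern follows; `Range`, `Tile`, `Fused`, and `FillTile` are toy stand-ins, not the repository's `RowPtrsBF`/`IndexRange`/`StridedViewBF` types.

#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

struct Range { size_t begin, end; };                  // stand-in for IndexRange
using Tile = std::vector<std::vector<float>>;         // toy "C" matrix
// Mirrors the shape of MMFused: (C, filled rows, filled cols, second tile, worker).
using Fused = std::function<void(Tile&, Range, Range, int, size_t)>;

void FillTile(Tile& C, Range rows, Range cols, size_t worker, const Fused& fused) {
  for (size_t r = rows.begin; r < rows.end; ++r)
    for (size_t c = cols.begin; c < cols.end; ++c) C[r][c] = 1.0f;
  // Invoke the hook, if any, right after this tile is complete.
  if (fused) fused(C, rows, cols, /*second tile*/ 0, worker);
}

int main() {
  Tile C(4, std::vector<float>(4, 0.0f));
  Fused fused = [](Tile& C, Range rows, Range cols, int, size_t worker) {
    for (size_t r = rows.begin; r < rows.end; ++r)
      for (size_t c = cols.begin; c < cols.end; ++c) C[r][c] += 0.5f;  // e.g. activation
    std::printf("worker %zu fused rows [%zu,%zu)\n", worker, rows.begin, rows.end);
  };
  FillTile(C, {0, 2}, {0, 4}, /*worker=*/0, fused);
  std::printf("C[0][0] = %.1f\n", C[0][0]);  // 1.5
  return 0;
}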

util/mat.h — 84 changed lines

@@ -38,17 +38,27 @@ namespace gcpp {
 template <typename T>
 class RowPtrs {
 public:
-RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
+RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}

+RowPtrs View(size_t r, size_t c) {
+RowPtrs<T> view(row_ptrs_);
+view.r0_ = static_cast<uint32_t>(r);
+view.c0_ = static_cast<uint32_t>(c);
+return view;
+}
+
 T* HWY_RESTRICT Row(size_t row_idx) const {
-return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]);
+return HWY_RCAST_ALIGNED(T*, row_ptrs_[r0_ + row_idx]) + c0_;
 }
-T* HWY_RESTRICT operator[](size_t row_idx) const { return Row(row_idx); }

 private:
 uint8_t** row_ptrs_;
+uint32_t r0_;
+uint32_t c0_;
 };

+using RowPtrsBF = RowPtrs<BF16>;
+
 // Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
 // or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
 // to store hetereogeneous tensor references in a vector.
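With this change, `View(r, c)` bakes an origin into the `RowPtrs`, and every later `Row(i)` resolves to `row_ptrs_[r0_ + i] + c0_`. A quick standalone check of that arithmetic with a simplified copy of the class (plain `float**` rows instead of the type-erased `uint8_t**` and `HWY_RCAST_ALIGNED`); `MiniRowPtrs` is a made-up name.

#include <cassert>
#include <cstddef>
#include <vector>

// Simplified RowPtrs: the real one stores uint8_t** and reinterpret-casts per row.
template <typename T>
class MiniRowPtrs {
 public:
  explicit MiniRowPtrs(T** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}

  MiniRowPtrs View(size_t r, size_t c) const {
    MiniRowPtrs view(row_ptrs_);
    view.r0_ = r;
    view.c0_ = c;
    return view;
  }

  T* Row(size_t row_idx) const { return row_ptrs_[r0_ + row_idx] + c0_; }

 private:
  T** row_ptrs_;
  size_t r0_, c0_;
};

int main() {
  std::vector<float> storage(6 * 10);
  std::vector<float*> rows(6);
  for (size_t r = 0; r < 6; ++r) rows[r] = storage.data() + r * 10;

  MiniRowPtrs<float> C(rows.data());
  MiniRowPtrs<float> tile = C.View(2, 3);
  // A view's Row(i) is the full matrix's Row(2 + i), shifted right by 3 columns.
  assert(tile.Row(1) == C.Row(3) + 3);
  return 0;
}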
@@ -349,12 +359,12 @@ RowPtrs<T> GetOrSetTempRowPtrs(
 template <class Func, typename... Args>
 decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
 Args&&... args) {
-#if GEMMA_ENABLE_NUQ
+if constexpr (GEMMA_ENABLE_NUQ) {
 if (base->GetType() == Type::kNUQ) {
 const MatPtrT<NuqStream> mat(*base);
 return func(&mat, std::forward<Args>(args)...);
+}
 }
-#endif // GEMMA_ENABLE_NUQ

 if (base->GetType() == Type::kF32) {
 const MatPtrT<float> mat(*base);
@@ -376,13 +386,13 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
 const Func& func, Args&&... args) {
 HWY_DASSERT(base1->GetType() == base2->GetType());

-#if GEMMA_ENABLE_NUQ
+if constexpr (GEMMA_ENABLE_NUQ) {
 if (base1->GetType() == Type::kNUQ) {
 const MatPtrT<NuqStream> mat1(*base1);
 const MatPtrT<NuqStream> mat2(*base2);
 return func(&mat1, &mat2, std::forward<Args>(args)...);
+}
 }
-#endif // GEMMA_ENABLE_NUQ

 if (base1->GetType() == Type::kF32) {
 const MatPtrT<float> mat1(*base1);
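Both hunks above swap a preprocessor `#if GEMMA_ENABLE_NUQ` block for `if constexpr (GEMMA_ENABLE_NUQ)`: the disabled branch is still parsed instead of being stripped by the preprocessor, yet it generates no code when the flag is 0. A minimal illustration of the pattern, assuming the macro is always defined to 0 or 1 (the `Dispatch` helper and the value 42 are invented for the example):

#include <cstdio>

// In gemma.cpp this comes from the build system; here we just pick a default.
#ifndef GEMMA_ENABLE_NUQ
#define GEMMA_ENABLE_NUQ 0
#endif

int Dispatch(int type) {
  // Unlike "#if", the body below must still parse and (outside templates)
  // type-check even when the feature is off, which catches bit-rot in
  // rarely-built configurations.
  if constexpr (GEMMA_ENABLE_NUQ) {
    if (type == 2) return 42;  // stand-in for the NUQ-specific path
  }
  return type;  // common path
}

int main() {
  std::printf("%d\n", Dispatch(2));  // 2 when NUQ is disabled, 42 when enabled
  return 0;
}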
@@ -508,5 +518,51 @@ class MatFactory {
 MatPadding padding_;
 };

+// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
+// Also used to decompress B, hence non-const.
+#pragma pack(push, 1) // power of two size
+template <typename T>
+class StridedView {
+public:
+StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
+: row0_(row0),
+cols_(static_cast<uint32_t>(cols)),
+stride_(static_cast<uint32_t>(stride)) {
+HWY_DASSERT(stride >= cols);
+}
+
+// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+StridedView(const MatPtrT<T>& mat, size_t r, size_t c, size_t cols)
+: StridedView(const_cast<T*>(mat.Row(r)) + c, cols, mat.Stride()) {
+HWY_DASSERT(c < mat.Cols());
+HWY_DASSERT(cols <= mat.Cols() - c);
+}
+
+// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+StridedView<T> View(size_t r, size_t c, size_t cols) const {
+HWY_DASSERT(c < Cols());
+HWY_DASSERT(cols <= Cols() - c);
+return StridedView<T>(Row(r) + c, cols, stride_);
+}
+
+T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
+size_t Cols() const { return static_cast<size_t>(cols_); }
+
+size_t Stride() const { return static_cast<size_t>(stride_); }
+void SetStride(size_t stride) {
+HWY_DASSERT(stride >= Cols());
+stride_ = stride;
+}
+
+private:
+T* HWY_RESTRICT row0_;
+uint32_t cols_;
+uint32_t stride_;
+};
+#pragma pack(pop)
+
+using StridedViewBF = StridedView<BF16>;
+using StridedViewD = StridedView<double>;
+
 } // namespace gcpp
 #endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
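Since `StridedView` is now a public helper in util/mat.h, here is a hedged usage sketch with a trimmed-down copy of the class (only `Row`, `Cols`, and the sub-view method; the real class also has the `MatPtrT` constructor, `Stride`, `SetStride`, and packing pragmas). `MiniStridedView` is a made-up name for the sketch.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Trimmed copy of StridedView: a non-owning window into row-major storage
// whose rows are `stride` elements apart (stride >= cols allows padding).
template <typename T>
class MiniStridedView {
 public:
  MiniStridedView(T* row0, size_t cols, size_t stride)
      : row0_(row0), cols_(static_cast<uint32_t>(cols)),
        stride_(static_cast<uint32_t>(stride)) {
    assert(stride >= cols);
  }

  // 2D subrange whose top-left is (r, c) and width is `cols`.
  MiniStridedView View(size_t r, size_t c, size_t cols) const {
    assert(c < Cols() && cols <= Cols() - c);
    return MiniStridedView(Row(r) + c, cols, stride_);
  }

  T* Row(size_t r) const { return row0_ + stride_ * r; }
  size_t Cols() const { return cols_; }

 private:
  T* row0_;
  uint32_t cols_, stride_;
};

int main() {
  // 4 rows of 6 valid columns, padded to a stride of 8.
  std::vector<float> storage(4 * 8, 0.0f);
  MiniStridedView<float> view(storage.data(), /*cols=*/6, /*stride=*/8);
  view.Row(2)[5] = 3.0f;

  // Sub-view starting at row 1, column 4, two columns wide.
  MiniStridedView<float> sub = view.View(1, 4, 2);
  assert(sub.Row(1) + 1 == view.Row(2) + 5);
  assert(*(sub.Row(1) + 1) == 3.0f);
  return 0;
}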