Rename RowPtr->StridedView, CRows->RowPtrs

PiperOrigin-RevId: 770046362
Author: Jan Wassenberg, 2025-06-11 02:29:41 -07:00 (committed by Copybara-Service)
Parent: b84149310b
Commit: bd98b43cea
3 changed files with 81 additions and 80 deletions


@ -80,20 +80,6 @@ hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
return hn::DemoteTo(dc, vf);
}
// Type-safe wrapper over uint8_t row pointers referenced by MatPtrT.
template <typename TC>
class CRows {
public:
CRows(uint8_t** C_rows) : C_rows_(C_rows) {}
TC* HWY_RESTRICT operator[](size_t row_idx) const {
return HWY_RCAST_ALIGNED(TC*, C_rows_[row_idx]);
}
private:
uint8_t** C_rows_;
};
// Tag classes, passed to `MMKernel::A2C0` to choose between writing one
// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
// first kc result to partial, or accumulating the next kc result into partial
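
A minimal sketch (not part of the commit) of the tag-dispatch pattern this comment describes. `MMSetC` does appear in the diff below; `MMSetPartial` and `MMAddPartial` are hypothetical names standing in for the partial-output tags, and `hwy::IsSame` is the same helper this file already uses for type tests.

struct MMSetC {};        // all-K result goes straight into C
struct MMSetPartial {};  // hypothetical: first kc result initializes `partial`
struct MMAddPartial {};  // hypothetical: later kc results accumulate into `partial`

template <class Tag>
HWY_INLINE void StoreOrAccumulate(Tag /*tag*/) {
  if constexpr (hwy::IsSame<Tag, MMSetC>()) {
    // MMStoreHorizontalSumsIntoC: scale, optionally add, and write to C.
  } else if constexpr (hwy::IsSame<Tag, MMSetPartial>()) {
    // Overwrite `partial` so it need not be zero-initialized.
  } else {
    // Accumulate into `partial` (MMAddHorizontalSumsIntoPartial).
  }
}
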
@ -124,7 +110,7 @@ class MMStoreHorizontalSumsIntoC {
VF C20, VF C21, VF C22, VF C23, //
VF C30, VF C31, VF C32, VF C33, //
const size_t row_c, const size_t col_c,
const MMArgs& args, CRows<TC> C_rows) const {
const MMArgs& args, RowPtrs<TC> C_rows) const {
HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
const size_t N = hn::Lanes(df);
// Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@ -202,7 +188,7 @@ class MMStoreHorizontalSumsIntoC {
template <size_t kRow, /*deduced:*/ class DF4, class VF4 = hn::Vec<DF4>,
typename TC>
static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
VF4 vadd, CRows<TC> C_rows,
VF4 vadd, RowPtrs<TC> C_rows,
const size_t row_c,
const size_t col_c) {
if constexpr (kRow < kRowsAC) {
@ -236,7 +222,7 @@ class MMAddHorizontalSumsIntoPartial {
VF F20, VF F21, VF F22, VF F23, //
VF F30, VF F31, VF F32, VF F33, //
const size_t row_c, const size_t col_c,
const RowPtrD& partial) const {
const StridedViewD& partial) const {
// We accumulate in 64-bit to avoid loss of precision.
static_assert(HWY_HAVE_FLOAT64, "Disable Armv7 NEON: we require fp64");
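
A tiny illustration (not from the commit) of why the accumulator is widened to fp64: float has a 24-bit significand, so once a running sum reaches 2^24 a further +1 contribution is rounded away, whereas double keeps it exactly.

#include <cassert>

inline void WhyAccumulateInFp64() {
  const float big_f = 16777216.0f;    // 2^24, the edge of float's 24-bit significand
  assert(big_f + 1.0f == big_f);      // the +1 is lost to rounding in float
  const double big_d = 16777216.0;
  assert(big_d + 1.0 == 16777217.0);  // preserved exactly in double
}
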
@ -342,7 +328,8 @@ class MMAddHorizontalSumsIntoPartial {
}
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum, const RowPtrD& partial,
static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum,
const StridedViewD& partial,
const size_t row_c, const size_t col_c) {
if constexpr (kRow < kRowsAC) {
double* HWY_RESTRICT pos = partial.Row(row_c + kRow) + col_c;
@ -371,10 +358,11 @@ class MMKernel {
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
template <class Tag, typename TC>
static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view,
size_t mr, const IndexRange& range_mc,
const size_t row_b, size_t kc, Tag tag,
const MMArgs& args, CRows<TC> C_rows) {
static HWY_INLINE void A2C0(const StridedViewBF& A_view,
const StridedViewBF& B_view, size_t mr,
const IndexRange& range_mc, const size_t row_b,
size_t kc, Tag tag, const MMArgs& args,
RowPtrs<TC> C_rows) {
HWY_DASSERT(1 <= mr && mr <= kMaxMR);
const size_t row0 = range_mc.begin();
const size_t mc = range_mc.Num();
@ -516,10 +504,10 @@ class MMKernel {
// BF16 so we can load directly without `Decompress2`, which is expensive for
// NUQ and requires 2x unrolling, which requires more loads.
template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view,
size_t row_ac, size_t imc, size_t col_c,
size_t kc, Tag tag, const MMArgs& args,
CRows<TC> C_rows) {
static HWY_INLINE void LoopKC(const StridedViewBF& A_view,
const StridedViewBF& B_view, size_t row_ac,
size_t imc, size_t col_c, size_t kc, Tag tag,
const MMArgs& args, RowPtrs<TC> C_rows) {
const hn::ScalableTag<BF16> dbf;
using VBF = hn::Vec<decltype(dbf)>;
const size_t NBF = hn::Lanes(dbf);
@ -661,7 +649,7 @@ class MMScaleDemoteAdd {
template <typename TC>
static HWY_INLINE void FillC(const IndexRange& range_mc,
const IndexRange& range_nc, const MMArgs& args,
CRows<TC> C_rows) {
RowPtrs<TC> C_rows) {
size_t row_c = range_mc.begin();
if (args.add) {
constexpr bool kAdd = true;
@ -690,7 +678,7 @@ class MMScaleDemoteAdd {
// Unrolled for 4 rows to reduce the number of loads from `add`.
template <bool kAdd, typename TC>
static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
const MMArgs& args, CRows<TC> C_rows) {
const MMArgs& args, RowPtrs<TC> C_rows) {
const hn::ScalableTag<double> dd;
const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo
const hn::Rebind<TC, decltype(dd)> dc;
@ -808,7 +796,7 @@ class MMScaleDemoteAdd {
// Same as above but handles a single row (for remainder rows).
template <bool kAdd, typename TC>
static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
const MMArgs& args, CRows<TC> C_rows) {
const MMArgs& args, RowPtrs<TC> C_rows) {
const hn::ScalableTag<double> dd;
const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo
const hn::Rebind<TC, decltype(dd)> dc;
@ -901,7 +889,7 @@ class MMPerPackage {
// B is decompressed several call layers lower, but not all member functions
// depend on TB, so pass it as an argument instead of templating the class.
template <typename TB, typename TC>
HWY_NOINLINE void operator()(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
HWY_NOINLINE void operator()(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
switch (order_) {
case MMOrder::kNT:
return DoNT(B, C_rows);
@ -931,7 +919,7 @@ class MMPerPackage {
// Single M and K, parallel N. Fills all of C directly.
template <typename TB, typename TC>
HWY_INLINE void DoNT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
MMZone zone;
zone.MaybeEnter("MM.NT", args_);
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@ -939,7 +927,7 @@ class MMPerPackage {
const IndexRange& range_M = ranges_mc_.Range(0);
const IndexRange& range_K = ranges_kc_.Range(0);
const size_t K = range_K.Num();
const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K);
const StridedViewBF& A_view = A_.View(range_M.begin(), 0, K);
const size_t B_stride =
Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
@ -948,11 +936,12 @@ class MMPerPackage {
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_storage_view(B_storage, K, B_stride);
const StridedViewBF B_storage_view(B_storage, K, B_stride);
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
StridedViewBF B_view =
DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
args_, C_rows);
}
@ -963,7 +952,7 @@ class MMPerPackage {
// Single M, parallel N, sequential K. Fills all of partial.
template <typename TB, typename TC>
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
MMZone zone;
zone.MaybeEnter("MM.NT_K", args_);
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@ -976,14 +965,15 @@ class MMPerPackage {
const IndexRange& range_nc,
auto out_tag) HWY_ATTR {
const size_t kc = range_kc.Num();
const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
const RowPtrBF B_storage_view(
const StridedViewBF& A_view =
A_.View(range_mc.begin(), range_kc.begin(), kc);
const StridedViewBF B_storage_view(
B_storage, kc,
Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C_rows);
}
@ -1023,7 +1013,7 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, single K.
// Fills `mc x nc` sections of C directly, in parallel.
template <typename TB, typename TC>
HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
MMZone zone;
zone.MaybeEnter("MM.NT_MT", args_);
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@ -1037,13 +1027,14 @@ class MMPerPackage {
args_.env->parallel.ForRangesMC_NC(
ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_storage_view(B_storage, K, B_stride);
const StridedViewBF B_storage_view(B_storage, K, B_stride);
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
StridedViewBF B_view =
DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
args_, C_rows);
}
@ -1055,7 +1046,7 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, sequential K.
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
template <typename TB, typename TC>
HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
MMZone zone;
zone.MaybeEnter("MM.NT_MT_K", args_);
const size_t kc_max = ranges_kc_.TaskSize();
@ -1065,17 +1056,18 @@ class MMPerPackage {
// Sequential loop over NC/MC/KC, for when the M/N loops are
// already parallel. This is B3A2C0 in MOMMS terminology: we read
// `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
const auto loop_nc = [&](const RowPtrBF& B_storage_view,
const auto loop_nc = [&](const StridedViewBF& B_storage_view,
const IndexRange& range_mc,
const IndexRange& range_kc,
const IndexRange& range_nc,
auto out_tag) HWY_ATTR {
const size_t kc = range_kc.Num();
const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
const StridedViewBF& A_view =
A_.View(range_mc.begin(), range_kc.begin(), kc);
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C_rows);
}
@ -1084,7 +1076,7 @@ class MMPerPackage {
ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_storage_view(B_storage, kc_max, B_stride);
const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
// Peel off the first iteration of the kc loop: avoid
// zero-initializing `partial` by writing into it.
@ -1166,15 +1158,16 @@ class MMPerPackage {
// Autotuning wrapper for `DoDecompressA`.
template <typename TA>
HWY_INLINE RowPtrBF DecompressA(const MatPtrT<TA>& A) const {
HWY_INLINE StridedViewBF DecompressA(const MatPtrT<TA>& A) const {
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
// If already BF16, maybe return a view:
if constexpr (hwy::IsSame<TA, BF16>()) {
// Only if vector multiple and padded (see `DoDecompressA`).
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) {
// Actually const, but RowPtr is also used for partial which is not.
return RowPtrBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
// Const, but cast because StridedView is also used for `partial` which
// is non-const.
return StridedViewBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
}
}
@ -1212,12 +1205,12 @@ class MMPerPackage {
// col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
// thanks to its large table lookups, and less so on other targets.
template <typename TB>
HWY_INLINE RowPtrBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
const IndexRange& range_kc,
const RowPtrBF& B_view) const {
HWY_INLINE StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
const IndexRange& range_kc,
const StridedViewBF& B_view) const {
if constexpr (hwy::IsSame<TB, BF16>()) {
return RowPtrBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
range_kc.Num(), B.Stride());
return StridedViewBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
range_kc.Num(), B.Stride());
}
const hn::ScalableTag<BF16> dbf;
@ -1242,7 +1235,7 @@ class MMPerPackage {
const MMArgs args_; // copy for locality
const size_t pkg_idx_;
RowPtrBF A_; // view into A or pkg_A_, both of which are padded.
StridedViewBF A_; // view into A or pkg_A_, both of which are padded.
const IndexRange range_np_;
// From MMConfig:
@ -1272,7 +1265,7 @@ struct MMImpl {
// or with the best config.
template <typename TA, typename TB, typename TC>
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
CRows<TC> C_rows, const MMArgs& args,
RowPtrs<TC> C_rows, const MMArgs& args,
const MMConfig& config) {
MMZone matmul_zone;
matmul_zone.MaybeEnter("MM.DoMatMul", args);
@ -1314,7 +1307,7 @@ template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
MatPtrT<TC>& C) {
CRows<TC> C_rows(C.GetRowPtrs());
RowPtrs<TC> C_rows(C.GetRowPtrs());
if (HWY_UNLIKELY(!C.GetRowPtrs())) {
if constexpr (HWY_IS_DEBUG_BUILD) {
fprintf(stderr,
@ -1326,7 +1319,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
for (size_t r = 0; r < C.Rows(); ++r) {
env.row_ptrs[0][r] = reinterpret_cast<uint8_t*>(C.Row(r));
}
C_rows = CRows<TC>(env.row_ptrs[0].get());
C_rows = RowPtrs<TC>(env.row_ptrs[0].get());
}
const Allocator& allocator = env.ctx.allocator;


@ -176,12 +176,12 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
// C is BF16/float, or double for partial.
void BindC(MatPtr& C, MMParallel& parallel);
// Lightweight view into `MatStorageT`.
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
#pragma pack(push, 1) // power of two size
template <typename T>
class RowPtr {
class StridedView {
public:
RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
: row0_(row0),
cols_(static_cast<uint32_t>(cols)),
stride_(static_cast<uint32_t>(stride)) {
@ -198,10 +198,10 @@ class RowPtr {
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
RowPtr<T> View(size_t r, size_t c, size_t cols) const {
StridedView<T> View(size_t r, size_t c, size_t cols) const {
HWY_DASSERT(c < Cols());
HWY_DASSERT(cols <= Cols() - c);
return RowPtr<T>(Row(r) + c, cols, stride_);
return StridedView<T>(Row(r) + c, cols, stride_);
}
private:
@ -211,8 +211,8 @@ class RowPtr {
};
#pragma pack(pop)
using RowPtrBF = RowPtr<BF16>;
using RowPtrD = RowPtr<double>;
using StridedViewBF = StridedView<BF16>;
using StridedViewD = StridedView<double>;
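
A hypothetical usage sketch of the renamed view (not part of the commit), assuming only the constructor, `Row()` and `View()` shown above: `stride` is the element count between row starts, and a sub-view keeps the parent's stride.

HWY_ALIGN BF16 buf[8 * 64];                     // 8 rows, stride of 64 elements
const StridedViewBF all(buf, /*cols=*/48, /*stride=*/64);
// Subrange whose top-left is (row 2, col 16), 32 columns wide, same stride.
const StridedViewBF tile = all.View(/*r=*/2, /*c=*/16, /*cols=*/32);
BF16* HWY_RESTRICT row1_of_tile = tile.Row(1);  // == buf + 3 * 64 + 16
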
// Per-package storage for packed A, and one global C-shaped `partial` for
// accumulating partial dot products (sections of K).
@ -260,19 +260,19 @@ class MMStorage {
}
// Returns per-package matrix view.
RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxM);
HWY_DASSERT(extents.cols <= kMaxK);
return RowPtrBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)), extents.cols,
pkg_A_[pkg_idx]->Stride());
return StridedViewBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
extents.cols, pkg_A_[pkg_idx]->Stride());
}
RowPtrD Partial() const { return partial_; }
StridedViewD Partial() const { return partial_; }
private:
std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
MatStorageT<double> partial_storage_;
RowPtrD partial_;
StridedViewD partial_;
};
//------------------------------------------------------------------------------
@ -673,7 +673,7 @@ struct MatMulEnv {
// Reduces register pressure compared to individual values/references.
struct MMArgs {
MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
const float* HWY_RESTRICT add, const RowPtrD& partial)
const float* HWY_RESTRICT add, const StridedViewD& partial)
: env(&env),
per_key(&per_key),
scale(scale),
@ -686,7 +686,7 @@ struct MMArgs {
double scale;
const float* HWY_RESTRICT add;
// Same size as C, threads write at false-sharing-free granularity.
RowPtrD partial;
StridedViewD partial;
};
// Wrapper over hwy::Zone that is only enabled when autotuning finished.


@ -33,16 +33,19 @@
namespace gcpp {
// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr.
template <typename TC>
class CRows {
// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr. Used
// for C, in future also for A.
template <typename T>
class RowPtrs {
public:
CRows(TC** C_rows) : C_rows_(C_rows) {}
RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
TC* HWY_RESTRICT operator[](size_t row_idx) const { return C_rows_[row_idx]; }
T* HWY_RESTRICT operator[](size_t row_idx) const {
return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]);
}
private:
TC** C_rows_;
uint8_t** row_ptrs_;
};
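
A hypothetical sketch (not from the commit): `RowPtrs` merely reinterprets the type-erased per-row pointers, so indexing yields a typed, aligned row pointer.

HWY_ALIGN float storage[4][32];  // 4 rows of 32 floats; each row stays aligned
uint8_t* rows[4];
for (size_t r = 0; r < 4; ++r) {
  rows[r] = reinterpret_cast<uint8_t*>(storage[r]);
}
RowPtrs<float> typed(rows);
typed[2][5] = 1.0f;              // writes the same element as storage[2][5]
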
// Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
@ -88,7 +91,9 @@ class MatPtr : public IFields {
bool HasPtr() const { return ptr_ != nullptr; }
// Caller has initialized Rows() pointers in row_ptrs[].
// Caller has initialized Rows() pointers in row_ptrs[]. Note that this only
// changes `GetRowPtrs()`, not `Row()`: routing `Row()` through the attached
// pointers would require branching, and only a few call sites (notably
// MatMul) use row pointers.
void AttachRowPtrs(uint8_t** row_ptrs) {
row_ptrs_ = row_ptrs;
for (size_t r = 0; r < Rows(); ++r) {
@ -96,6 +101,8 @@ class MatPtr : public IFields {
}
}
// Called by Activations to allocate once, rather than filling row pointers on
// each call to MatMul.
void AllocateAndAttachRowPtrs(
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) {
if (!HasPtr()) return;
@ -107,6 +114,7 @@ class MatPtr : public IFields {
AttachRowPtrs(ptrs);
};
// If non-null, this array should be used instead of `Row()`.
uint8_t** GetRowPtrs() const { return row_ptrs_; }
// A single row counts as packed because there is no padding between rows.
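
A hypothetical end-to-end sketch (not from the commit), assuming a `MatPtrT<float> C` with allocated storage: the caller fills and attaches the row pointers once, then MatMul-style code wraps `GetRowPtrs()` in the typed `RowPtrs` view.

hwy::AlignedFreeUniquePtr<uint8_t*[]> ptrs =
    hwy::AllocateAligned<uint8_t*>(C.Rows());
for (size_t r = 0; r < C.Rows(); ++r) {
  ptrs[r] = reinterpret_cast<uint8_t*>(C.Row(r));  // caller initializes, per the comment above
}
C.AttachRowPtrs(ptrs.get());                       // later reads go through GetRowPtrs()
RowPtrs<float> c_rows(C.GetRowPtrs());             // typed view over the same pointers
float* HWY_RESTRICT c_row0 = c_rows[0];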