mirror of https://github.com/google/gemma.cpp.git
Rename RowPtr->StridedView, CRows->RowPtrs
PiperOrigin-RevId: 770046362
parent b84149310b
commit bd98b43cea
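In short, the rename is mechanical: `RowPtr<T>` (aliases `RowPtrBF`, `RowPtrD`) becomes `StridedView<T>` (`StridedViewBF`, `StridedViewD`), a non-owning view whose rows are a fixed stride apart, and `CRows<TC>`, the type-safe wrapper over type-erased `uint8_t*` row pointers, moves from ops/matmul-inl.h to util/mat.h as `RowPtrs<T>`. The following simplified, self-contained sketch (not part of the commit) shows the two shapes and how their indexing differs; Highway macros (`HWY_RESTRICT`, `HWY_RCAST_ALIGNED`, `HWY_DASSERT`) and the `uint32_t` field packing are elided, so see the hunks below for the committed definitions.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Formerly RowPtr<T>: lightweight view into row-major storage whose rows are
// `stride` elements apart (stride >= cols when rows are padded).
template <typename T>
class StridedView {
 public:
  StridedView(T* row0, size_t cols, size_t stride)
      : row0_(row0), cols_(cols), stride_(stride) {}
  T* Row(size_t r) const { return row0_ + r * stride_; }
  size_t Cols() const { return cols_; }
  // 2D subrange whose top-left is (r, c) and whose width is `cols`.
  StridedView<T> View(size_t r, size_t c, size_t cols) const {
    return StridedView<T>(Row(r) + c, cols, stride_);
  }

 private:
  T* row0_;
  size_t cols_;
  size_t stride_;
};

// Formerly CRows<TC>: type-safe indexing into the type-erased uint8_t* row
// pointers stored by MatPtr (rows need not be equally spaced).
template <typename T>
class RowPtrs {
 public:
  explicit RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
  T* operator[](size_t row_idx) const {
    return reinterpret_cast<T*>(row_ptrs_[row_idx]);
  }

 private:
  uint8_t** row_ptrs_;
};

int main() {
  // 3x4 matrix stored with stride 8 (rows padded to 8 floats).
  std::vector<float> storage(3 * 8, 0.0f);
  StridedView<float> m(storage.data(), /*cols=*/4, /*stride=*/8);
  m.Row(2)[1] = 5.0f;                        // element (2, 1)
  StridedView<float> sub = m.View(1, 1, 2);  // 2-wide window starting at (1, 1)

  // Row pointers: each row's address, type-erased as uint8_t*.
  uint8_t* rows[3];
  for (size_t r = 0; r < 3; ++r) {
    rows[r] = reinterpret_cast<uint8_t*>(m.Row(r));
  }
  RowPtrs<float> row_ptrs(rows);
  // All three access the same element and print 5.
  std::printf("%g %g %g\n", m.Row(2)[1], sub.Row(1)[0], row_ptrs[2][1]);
  return 0;
}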
ops/matmul-inl.h (111 changed lines)
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -80,20 +80,6 @@ hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
   return hn::DemoteTo(dc, vf);
 }
 
-// Type-safe wrapper over uint8_t row pointers referenced by MatPtrT.
-template <typename TC>
-class CRows {
- public:
-  CRows(uint8_t** C_rows) : C_rows_(C_rows) {}
-
-  TC* HWY_RESTRICT operator[](size_t row_idx) const {
-    return HWY_RCAST_ALIGNED(TC*, C_rows_[row_idx]);
-  }
-
- private:
-  uint8_t** C_rows_;
-};
-
 // Tag classes, passed to `MMKernel::A2C0` to choose between writing one
 // (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
 // first kc result to partial, or accumulating the next kc result into partial
@@ -124,7 +110,7 @@ class MMStoreHorizontalSumsIntoC {
                              VF C20, VF C21, VF C22, VF C23,  //
                              VF C30, VF C31, VF C32, VF C33,  //
                              const size_t row_c, const size_t col_c,
-                             const MMArgs& args, CRows<TC> C_rows) const {
+                             const MMArgs& args, RowPtrs<TC> C_rows) const {
     HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
     const size_t N = hn::Lanes(df);
     // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@@ -202,7 +188,7 @@ class MMStoreHorizontalSumsIntoC {
   template <size_t kRow, /*deduced:*/ class DF4, class VF4 = hn::Vec<DF4>,
             typename TC>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-                                            VF4 vadd, CRows<TC> C_rows,
+                                            VF4 vadd, RowPtrs<TC> C_rows,
                                             const size_t row_c,
                                             const size_t col_c) {
     if constexpr (kRow < kRowsAC) {
@@ -236,7 +222,7 @@ class MMAddHorizontalSumsIntoPartial {
                              VF F20, VF F21, VF F22, VF F23,  //
                              VF F30, VF F31, VF F32, VF F33,  //
                              const size_t row_c, const size_t col_c,
-                             const RowPtrD& partial) const {
+                             const StridedViewD& partial) const {
     // We accumulate in 64-bit to avoid loss of precision.
     static_assert(HWY_HAVE_FLOAT64, "Disable Armv7 NEON: we require fp64");
 
@@ -342,7 +328,8 @@ class MMAddHorizontalSumsIntoPartial {
   }
 
   template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
-  static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum, const RowPtrD& partial,
+  static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum,
+                                       const StridedViewD& partial,
                                        const size_t row_c, const size_t col_c) {
     if constexpr (kRow < kRowsAC) {
       double* HWY_RESTRICT pos = partial.Row(row_c + kRow) + col_c;
@@ -371,10 +358,11 @@ class MMKernel {
   // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
   template <class Tag, typename TC>
-  static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view,
-                              size_t mr, const IndexRange& range_mc,
-                              const size_t row_b, size_t kc, Tag tag,
-                              const MMArgs& args, CRows<TC> C_rows) {
+  static HWY_INLINE void A2C0(const StridedViewBF& A_view,
+                              const StridedViewBF& B_view, size_t mr,
+                              const IndexRange& range_mc, const size_t row_b,
+                              size_t kc, Tag tag, const MMArgs& args,
+                              RowPtrs<TC> C_rows) {
     HWY_DASSERT(1 <= mr && mr <= kMaxMR);
     const size_t row0 = range_mc.begin();
     const size_t mc = range_mc.Num();
@@ -516,10 +504,10 @@ class MMKernel {
   // BF16 so we can load directly without `Decompress2`, which is expensive for
   // NUQ and requires 2x unrolling, which requires more loads.
   template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
-  static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view,
-                                size_t row_ac, size_t imc, size_t col_c,
-                                size_t kc, Tag tag, const MMArgs& args,
-                                CRows<TC> C_rows) {
+  static HWY_INLINE void LoopKC(const StridedViewBF& A_view,
+                                const StridedViewBF& B_view, size_t row_ac,
+                                size_t imc, size_t col_c, size_t kc, Tag tag,
+                                const MMArgs& args, RowPtrs<TC> C_rows) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     const size_t NBF = hn::Lanes(dbf);
@@ -661,7 +649,7 @@ class MMScaleDemoteAdd {
   template <typename TC>
   static HWY_INLINE void FillC(const IndexRange& range_mc,
                                const IndexRange& range_nc, const MMArgs& args,
-                               CRows<TC> C_rows) {
+                               RowPtrs<TC> C_rows) {
     size_t row_c = range_mc.begin();
     if (args.add) {
       constexpr bool kAdd = true;
@@ -690,7 +678,7 @@ class MMScaleDemoteAdd {
   // Unrolled for 4 rows to reduce the number of loads from `add`.
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
-                                 const MMArgs& args, CRows<TC> C_rows) {
+                                 const MMArgs& args, RowPtrs<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -808,7 +796,7 @@ class MMScaleDemoteAdd {
   // Same as above but handles a single row (for remainder rows).
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
-                                const MMArgs& args, CRows<TC> C_rows) {
+                                const MMArgs& args, RowPtrs<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -901,7 +889,7 @@ class MMPerPackage {
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
   template <typename TB, typename TC>
-  HWY_NOINLINE void operator()(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
+  HWY_NOINLINE void operator()(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
         return DoNT(B, C_rows);
@@ -931,7 +919,7 @@ class MMPerPackage {
 
   // Single M and K, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
+  HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -939,7 +927,7 @@ class MMPerPackage {
     const IndexRange& range_M = ranges_mc_.Range(0);
     const IndexRange& range_K = ranges_kc_.Range(0);
     const size_t K = range_K.Num();
-    const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K);
+    const StridedViewBF& A_view = A_.View(range_M.begin(), 0, K);
     const size_t B_stride =
         Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
 
@@ -948,11 +936,12 @@ class MMPerPackage {
         range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
         [&](const IndexRange& range_nc) HWY_ATTR {
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_storage_view(B_storage, K, B_stride);
+          const StridedViewBF B_storage_view(B_storage, K, B_stride);
 
           for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
                row_b += kNR) {
-            RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
+            StridedViewBF B_view =
+                DecompressB(B, row_b, range_K, B_storage_view);
             MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
                            args_, C_rows);
           }
@@ -963,7 +952,7 @@ class MMPerPackage {
 
   // Single M, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
+  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -976,14 +965,15 @@ class MMPerPackage {
                              const IndexRange& range_nc,
                              auto out_tag) HWY_ATTR {
       const size_t kc = range_kc.Num();
-      const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
-      const RowPtrBF B_storage_view(
+      const StridedViewBF& A_view =
+          A_.View(range_mc.begin(), range_kc.begin(), kc);
+      const StridedViewBF B_storage_view(
           B_storage, kc,
           Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
 
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
-        RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
+        StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                        C_rows);
       }
@@ -1023,7 +1013,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
+  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -1037,13 +1027,14 @@ class MMPerPackage {
     args_.env->parallel.ForRangesMC_NC(
         ranges_mc_, ranges_nc_, pkg_idx_,
        [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
-          const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
+          const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_storage_view(B_storage, K, B_stride);
+          const StridedViewBF B_storage_view(B_storage, K, B_stride);
 
           for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
                row_b += kNR) {
-            RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
+            StridedViewBF B_view =
+                DecompressB(B, row_b, range_K, B_storage_view);
             MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
                            args_, C_rows);
           }
@@ -1055,7 +1046,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
+  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
@@ -1065,17 +1056,18 @@ class MMPerPackage {
     // Sequential loop over NC/MC/KC, for when the M/N loops are
     // already parallel. This is B3A2C0 in MOMMS terminology: we read
    // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
-    const auto loop_nc = [&](const RowPtrBF& B_storage_view,
+    const auto loop_nc = [&](const StridedViewBF& B_storage_view,
                              const IndexRange& range_mc,
                              const IndexRange& range_kc,
                              const IndexRange& range_nc,
                              auto out_tag) HWY_ATTR {
       const size_t kc = range_kc.Num();
-      const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
+      const StridedViewBF& A_view =
+          A_.View(range_mc.begin(), range_kc.begin(), kc);
 
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
-        RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
+        StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                        C_rows);
       }
@@ -1084,7 +1076,7 @@ class MMPerPackage {
         ranges_mc_, ranges_nc_, pkg_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_storage_view(B_storage, kc_max, B_stride);
+          const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
 
           // Peel off the first iteration of the kc loop: avoid
           // zero-initializing `partial` by writing into it.
@@ -1166,15 +1158,16 @@ class MMPerPackage {
 
   // Autotuning wrapper for `DoDecompressA`.
   template <typename TA>
-  HWY_INLINE RowPtrBF DecompressA(const MatPtrT<TA>& A) const {
+  HWY_INLINE StridedViewBF DecompressA(const MatPtrT<TA>& A) const {
     MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
     // If already BF16, maybe return a view:
     if constexpr (hwy::IsSame<TA, BF16>()) {
       // Only if vector multiple and padded (see `DoDecompressA`).
       const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
       if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) {
-        // Actually const, but RowPtr is also used for partial which is not.
-        return RowPtrBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
+        // Const, but cast because StridedView is also used for `partial` which
+        // is non-const.
+        return StridedViewBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
       }
     }
 
@@ -1212,12 +1205,12 @@ class MMPerPackage {
   // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
   // thanks to its large table lookups, and less so on other targets.
   template <typename TB>
-  HWY_INLINE RowPtrBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
-                                  const IndexRange& range_kc,
-                                  const RowPtrBF& B_view) const {
+  HWY_INLINE StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+                                       const IndexRange& range_kc,
+                                       const StridedViewBF& B_view) const {
     if constexpr (hwy::IsSame<TB, BF16>()) {
-      return RowPtrBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
-                      range_kc.Num(), B.Stride());
+      return StridedViewBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
+                           range_kc.Num(), B.Stride());
     }
 
     const hn::ScalableTag<BF16> dbf;
@@ -1242,7 +1235,7 @@ class MMPerPackage {
 
   const MMArgs args_;  // copy for locality
   const size_t pkg_idx_;
-  RowPtrBF A_;  // view into A or pkg_A_, both of which are padded.
+  StridedViewBF A_;  // view into A or pkg_A_, both of which are padded.
 
   const IndexRange range_np_;
   // From MMConfig:
@@ -1272,7 +1265,7 @@ struct MMImpl {
   // or with the best config.
   template <typename TA, typename TB, typename TC>
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                                    CRows<TC> C_rows, const MMArgs& args,
+                                    RowPtrs<TC> C_rows, const MMArgs& args,
                                     const MMConfig& config) {
     MMZone matmul_zone;
     matmul_zone.MaybeEnter("MM.DoMatMul", args);
@@ -1314,7 +1307,7 @@ template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               MatPtrT<TC>& C) {
-  CRows<TC> C_rows(C.GetRowPtrs());
+  RowPtrs<TC> C_rows(C.GetRowPtrs());
   if (HWY_UNLIKELY(!C.GetRowPtrs())) {
     if constexpr (HWY_IS_DEBUG_BUILD) {
       fprintf(stderr,
@@ -1326,7 +1319,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     for (size_t r = 0; r < C.Rows(); ++r) {
       env.row_ptrs[0][r] = reinterpret_cast<uint8_t*>(C.Row(r));
     }
-    C_rows = CRows<TC>(env.row_ptrs[0].get());
+    C_rows = RowPtrs<TC>(env.row_ptrs[0].get());
   }
 
   const Allocator& allocator = env.ctx.allocator;
ops/matmul.h (28 changed lines)
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -176,12 +176,12 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
 // C is BF16/float, or double for partial.
 void BindC(MatPtr& C, MMParallel& parallel);
 
-// Lightweight view into `MatStorageT`.
+// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
 #pragma pack(push, 1)  // power of two size
 template <typename T>
-class RowPtr {
+class StridedView {
  public:
-  RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
+  StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
       : row0_(row0),
         cols_(static_cast<uint32_t>(cols)),
         stride_(static_cast<uint32_t>(stride)) {
@@ -198,10 +198,10 @@ class RowPtr {
   }
 
   // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  RowPtr<T> View(size_t r, size_t c, size_t cols) const {
+  StridedView<T> View(size_t r, size_t c, size_t cols) const {
     HWY_DASSERT(c < Cols());
     HWY_DASSERT(cols <= Cols() - c);
-    return RowPtr<T>(Row(r) + c, cols, stride_);
+    return StridedView<T>(Row(r) + c, cols, stride_);
   }
 
  private:
@@ -211,8 +211,8 @@ class RowPtr {
 };
 #pragma pack(pop)
 
-using RowPtrBF = RowPtr<BF16>;
-using RowPtrD = RowPtr<double>;
+using StridedViewBF = StridedView<BF16>;
+using StridedViewD = StridedView<double>;
 
 // Per-package storage for packed A, and one global C-shaped `partial` for
 // accumulating partial dot products (sections of K).
@@ -260,19 +260,19 @@ class MMStorage {
   }
 
   // Returns per-package matrix view.
-  RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
+  StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
     HWY_DASSERT(extents.rows <= kMaxM);
     HWY_DASSERT(extents.cols <= kMaxK);
-    return RowPtrBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)), extents.cols,
-                    pkg_A_[pkg_idx]->Stride());
+    return StridedViewBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
+                         extents.cols, pkg_A_[pkg_idx]->Stride());
   }
 
-  RowPtrD Partial() const { return partial_; }
+  StridedViewD Partial() const { return partial_; }
 
  private:
   std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
   MatStorageT<double> partial_storage_;
-  RowPtrD partial_;
+  StridedViewD partial_;
 };
 
 //------------------------------------------------------------------------------
@@ -673,7 +673,7 @@ struct MatMulEnv {
 // Reduces register pressure compared to individual values/references.
 struct MMArgs {
   MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
-         const float* HWY_RESTRICT add, const RowPtrD& partial)
+         const float* HWY_RESTRICT add, const StridedViewD& partial)
       : env(&env),
         per_key(&per_key),
         scale(scale),
@@ -686,7 +686,7 @@ struct MMArgs {
   double scale;
   const float* HWY_RESTRICT add;
   // Same size as C, threads write at false-sharing-free granularity.
-  RowPtrD partial;
+  StridedViewD partial;
 };
 
 // Wrapper over hwy::Zone that is only enabled when autotuning finished.
util/mat.h (22 changed lines)
--- a/util/mat.h
+++ b/util/mat.h
@@ -33,16 +33,19 @@
 
 namespace gcpp {
 
-// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr.
-template <typename TC>
-class CRows {
+// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr. Used
+// for C, in future also for A.
+template <typename T>
+class RowPtrs {
  public:
-  CRows(TC** C_rows) : C_rows_(C_rows) {}
+  RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
 
-  TC* HWY_RESTRICT operator[](size_t row_idx) const { return C_rows_[row_idx]; }
+  T* HWY_RESTRICT operator[](size_t row_idx) const {
+    return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]);
+  }
 
  private:
-  TC** C_rows_;
+  uint8_t** row_ptrs_;
 };
 
 // Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
@@ -88,7 +91,9 @@ class MatPtr : public IFields {
 
   bool HasPtr() const { return ptr_ != nullptr; }
 
-  // Caller has initialized Rows() pointers in row_ptrs[].
+  // Caller has initialized Rows() pointers in row_ptrs[]. Note that this only
+  // changes `GetRowPtrs`, not `Row()`, because that would require branching
+  // and only a few call sites, in particular MatMul, use row pointers.
   void AttachRowPtrs(uint8_t** row_ptrs) {
     row_ptrs_ = row_ptrs;
     for (size_t r = 0; r < Rows(); ++r) {
@@ -96,6 +101,8 @@ class MatPtr : public IFields {
     }
   }
 
+  // Called by Activations to allocate once, rather than have to fill row
+  // pointers in each call to MatMul.
   void AllocateAndAttachRowPtrs(
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) {
     if (!HasPtr()) return;
@@ -107,6 +114,7 @@ class MatPtr : public IFields {
     AttachRowPtrs(ptrs);
   };
 
+  // If non-null, this array should be used instead of `Row()`.
   uint8_t** GetRowPtrs() const { return row_ptrs_; }
 
   // A single row counts as packed because there is no padding between rows.
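For reference, the row pointers that `AttachRowPtrs` / `AllocateAndAttachRowPtrs` set up above are consumed by the `MatMul` entry point shown in the ops/matmul-inl.h hunks. The sketch below restates that fallback logic in a self-contained form; `RowPtrsForC` is a hypothetical helper name (not part of the commit), and the minimal `RowPtrs<T>` re-declaration stands in for the one in util/mat.h.

#include <cstddef>
#include <cstdint>

// Minimal re-declaration of RowPtrs<T> (see util/mat.h above) so this sketch
// stands alone.
template <typename T>
class RowPtrs {
 public:
  explicit RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
  T* operator[](size_t r) const { return reinterpret_cast<T*>(row_ptrs_[r]); }

 private:
  uint8_t** row_ptrs_;
};

// Hypothetical helper mirroring the fallback in MatMul: prefer row pointers
// already attached to C; otherwise fill env.row_ptrs[0] from C.Row(r) for
// this call only.
template <typename TC, class MatT, class EnvT>
RowPtrs<TC> RowPtrsForC(MatT& C, EnvT& env) {
  if (C.GetRowPtrs() != nullptr) {
    return RowPtrs<TC>(C.GetRowPtrs());  // attached once, e.g. by Activations
  }
  for (size_t r = 0; r < C.Rows(); ++r) {
    env.row_ptrs[0][r] = reinterpret_cast<uint8_t*>(C.Row(r));
  }
  return RowPtrs<TC>(env.row_ptrs[0].get());
}

The preferred path is the one the new comment in util/mat.h describes: Activations calls `AllocateAndAttachRowPtrs` once, so MatMul does not have to refill row pointers on every call.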