mirror of https://github.com/google/gemma.cpp.git
Refactor: move Worker to ThreadingContext, factor out MMDecompress
PiperOrigin-RevId: 804909921
This commit is contained in:
parent 461a9c7d1b
commit 24b1760f03
ops/matmul-inl.h (361 changed lines)
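The gist of the refactor, as a hedged sketch in C++ (PoolsStub and ThreadingContextSketch are hypothetical stand-ins; the real declarations appear in the hunks below):

// Sketch only: a minimal stand-in for the real ThreadingContext, which gains a
// Worker() helper so call sites no longer compute the index inline (and
// MMImpl::Worker can be deleted).
#include <cstddef>

struct PoolsStub {  // hypothetical stand-in for the real pools member
  size_t MaxWorkersPerCluster() const { return 8; }  // value is illustrative
};

struct ThreadingContextSketch {
  // Worker index compatible with those from ParallelFor, assuming one thread
  // per cluster (ParallelismStrategy::kAcrossClusters), as in the real header.
  size_t Worker(size_t cluster_idx) const {
    return cluster_idx * pools.MaxWorkersPerCluster();
  }
  PoolsStub pools;
};

// Call sites change from
//   const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
// to
//   const size_t worker = ctx.Worker(cluster_idx);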
@@ -219,6 +219,181 @@ class MMStoreHorizontalSumsIntoC {
     }
   }
 };  // MMStoreHorizontalSumsIntoC
 
+// Stateless, wraps member functions.
+class MMDecompress {
+ public:
+  // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
+  // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
+  // thanks to its large table lookups, and less so on other targets.
+  template <typename TB>
+  static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+                                   const IndexRange& range_kc,
+                                   const StridedViewBF B_view) {
+    const hn::ScalableTag<BF16> dbf;
+    HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
+
+    // Neither A nor B require padding because `LoopKC` handles remainders.
+    if constexpr (hwy::IsSame<TB, BF16>()) {
+      return View(B, row_b, range_kc.begin(), range_kc.Num());
+    }
+
+    const PackedSpan<const TB> B_span = B.PaddedSpan();
+
+    const size_t kc = range_kc.Num();
+    const size_t col0 = range_kc.begin();
+
+    for (size_t r = 0; r < kNR; ++r) {
+      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
+      BF16* HWY_RESTRICT to = B_view.Row(r);
+      DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
+      // Verify that we zero-padded.
+      if constexpr (HWY_IS_DEBUG_BUILD) {
+        for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
+          HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
+        }
+      }
+    }
+    return B_view;
+  }
+
+  template <typename TA>
+  static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
+                                                   MMAutoTune<MMParA>& autotune,
+                                                   const MatMulEnv& env,
+                                                   MMOptions options) {
+    if constexpr (IsBF16<TA>()) {
+      // We can use a view, regardless of columns/padding, because
+      // `MMKernel::LoopKC` supports non-vector multiples.
+      return View(A, 0, 0, A.Cols());
+    } else {
+      // Always decompress. To reduce code size/compile time, we no longer
+      // support a separate F32 kernel; most A are already BF16. We also only
+      // have a single MMStorage.
+      HWY_ASSERT(options.cluster_idx == 0);
+      const StridedViewBF A_view = env.storage.A(A.Extents());
+      AutotuneDecompressA(A, A_view, autotune, env, options);
+      return A_view;
+    }
+  }
+
+ private:
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  template <typename T>
+  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
+                             size_t cols) {
+    HWY_DASSERT(c < AB.Cols());
+    HWY_DASSERT(cols <= AB.Cols() - c);
+    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
+  }
+
+  // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
+  static HWY_NOINLINE void DecompressA(const MatPtrT<float>& A,
+                                       const StridedViewBF A_view,
+                                       MMAutoTune<MMParA>& autotune,
+                                       MMParA par_a, const MatMulEnv& env,
+                                       const MMOptions& options) {
+    const IndexRange all_M(0, A.Rows());
+    const IndexRange all_K(0, A.Cols());
+    HWY_DASSERT(all_K.Num() == A_view.Cols());
+
+    const hn::ScalableTag<BF16> dbf;
+    const size_t NBF = hn::Lanes(dbf);
+
+    static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
+
+    const auto do_range =
+        [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
+            HWY_ATTR {
+      MMZone mm_zone;
+      mm_zone.MaybeEnter(worker, zone, env, &autotune);
+
+      const size_t col0 = range_K.begin();
+      const size_t cols = range_K.Num();
+      // Must be a vector multiple, or the last range before row
+      // padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
+      HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
+      for (size_t row_a : range_M) {
+        const PackedSpan<const float> from =
+            MakeSpan(A.Row(row_a) + col0, cols);
+        BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
+        DecompressAndZeroPad(dbf, from, 0, to, cols);
+        // Verify that we zero-padded.
+        if constexpr (HWY_IS_DEBUG_BUILD) {
+          for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
+            HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
+          }
+        }
+      }
+    };
+
+    switch (par_a) {
+      case MMParA::kNone:
+        do_range(all_M, all_K, env.ctx.Worker(options.cluster_idx));
+        break;
+
+      case MMParA::kK1:
+      case MMParA::kK2:
+      case MMParA::kK4: {
+        const size_t inner_tasks = static_cast<size_t>(par_a);
+        // At least one vector, otherwise DecompressAndZeroPad will add
+        // padding, which might overwrite neighboring tasks. Also a whole cache
+        // line to avoid false sharing.
+        const size_t multiple_K =
+            HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
+
+        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
+          parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
+                        options.cluster_idx,
+                        [&](const IndexRange& range_K, size_t worker) {
+                          do_range(all_M, range_K, worker);
+                        });
+        });
+        break;
+      }
+      case MMParA::kM:
+        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
+          parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
+                              [&](size_t row_a, size_t worker) {
+                                do_range(IndexRange(row_a, row_a + 1), all_K,
+                                         worker);
+                              });
+        });
+        break;
+    }
+  }
+
+  // Autotuning wrapper for `DecompressA`.
+  static HWY_INLINE void AutotuneDecompressA(const MatPtrT<float>& A,
+                                             const StridedViewBF A_view,
+                                             MMAutoTune<MMParA>& autotune,
+                                             const MatMulEnv& env,
+                                             const MMOptions& options) {
+    if (HWY_LIKELY(autotune.Best())) {
+      return DecompressA(A, A_view, autotune, *autotune.Best(), env, options);
+    }
+
+    // First call: generate candidates.
+    if (HWY_UNLIKELY(!autotune.HasCandidates())) {
+      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
+      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
+                                        other};
+      autotune.SetCandidates(candidates);
+    }
+
+    const MMParA& par_a = autotune.NextConfig();
+    const uint64_t t0 = hwy::timer::Start();
+    DecompressA(A, A_view, autotune, par_a, env, options);
+    const uint64_t t1 =
+        env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
+    const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
+    if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
+      fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
+              static_cast<double>(min_elapsed) /
+                  hwy::platform::InvariantTicksPerSecond() * 1E6);
+    }
+  }
+};  // MMDecompress
+
 // Stateless, wraps member functions. Contains the innermost 2-4 loops.
 class MMKernel {
   // Compute size of per-worker storage for `kNR` row ranges of B. Stack
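For orientation, a usage excerpt (not new API): the two public entry points of MMDecompress are called as follows elsewhere in this commit, with both call sites appearing verbatim in later hunks.

// A is viewed or decompressed once per MatMul call (MatMul hunk below):
const StridedViewBF A_view =
    MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options);

// B is decompressed kNR rows at a time inside the NC/MC/KC loop
// (MMKernel and MMLoops hunks below):
StridedViewBF B_view =
    MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);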
@@ -288,49 +463,6 @@ class MMKernel {
 
   static constexpr size_t B_storage_max = kNR * B_stride_max;
 
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  template <typename T>
-  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
-                             size_t cols) {
-    HWY_DASSERT(c < AB.Cols());
-    HWY_DASSERT(cols <= AB.Cols() - c);
-    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
-  }
-
-  // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
-  // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
-  // thanks to its large table lookups, and less so on other targets.
-  template <typename TB>
-  static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
-                                   const IndexRange& range_kc,
-                                   const StridedViewBF B_view) {
-    const hn::ScalableTag<BF16> dbf;
-    HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
-
-    // Neither A nor B require padding because `LoopKC` handles remainders.
-    if constexpr (hwy::IsSame<TB, BF16>()) {
-      return View(B, row_b, range_kc.begin(), range_kc.Num());
-    }
-
-    const PackedSpan<const TB> B_span = B.PaddedSpan();
-
-    const size_t kc = range_kc.Num();
-    const size_t col0 = range_kc.begin();
-
-    for (size_t r = 0; r < kNR; ++r) {
-      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
-      BF16* HWY_RESTRICT to = B_view.Row(r);
-      DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
-      // Verify that we zero-padded.
-      if constexpr (HWY_IS_DEBUG_BUILD) {
-        for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
-          HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
-        }
-      }
-    }
-    return B_view;
-  }
-
   // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
   // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
   // `ForeachKC` and when there is only a single KC task.
@@ -350,7 +482,8 @@ class MMKernel {
 
     for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
          row_b += kNR) {
-      StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
+      StridedViewBF B_view =
+          MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
       A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows);
     }
   }
@@ -742,137 +875,6 @@ class MMImpl {
       HWY_ASSERT(hwy::IsAligned(A.RowBytes(1), vector_bytes));
     }
   }
 
-  static size_t Worker(const MatMulEnv& env, size_t cluster_idx) {
-    return cluster_idx * env.ctx.pools.MaxWorkersPerCluster();
-  }
-
-  // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
-  static HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
-                                         const StridedViewBF A_view,
-                                         MMAutoTune<MMParA>& autotune,
-                                         MMParA par_a, const MatMulEnv& env,
-                                         const MMOptions& options) {
-    const IndexRange all_M(0, A.Rows());
-    const IndexRange all_K(0, A.Cols());
-    HWY_DASSERT(all_K.Num() == A_view.Cols());
-
-    const hn::ScalableTag<BF16> dbf;
-    const size_t NBF = hn::Lanes(dbf);
-
-    static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
-
-    const auto do_range =
-        [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
-            HWY_ATTR {
-      MMZone mm_zone;
-      mm_zone.MaybeEnter(worker, zone, env, &autotune);
-
-      const size_t col0 = range_K.begin();
-      const size_t cols = range_K.Num();
-      // Must be a vector multiple, or the last range before row
-      // padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
-      HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
-      for (size_t row_a : range_M) {
-        const PackedSpan<const float> from =
-            MakeSpan(A.Row(row_a) + col0, cols);
-        BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
-        DecompressAndZeroPad(dbf, from, 0, to, cols);
-        // Verify that we zero-padded.
-        if constexpr (HWY_IS_DEBUG_BUILD) {
-          for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
-          }
-        }
-      }
-    };
-
-    switch (par_a) {
-      case MMParA::kNone:
-        do_range(all_M, all_K, Worker(env, options.cluster_idx));
-        break;
-
-      case MMParA::kK1:
-      case MMParA::kK2:
-      case MMParA::kK4: {
-        const size_t inner_tasks = static_cast<size_t>(par_a);
-        // At least one vector, otherwise DecompressAndZeroPad will add
-        // padding, which might overwrite neighboring tasks. Also a whole cache
-        // line to avoid false sharing.
-        const size_t multiple_K =
-            HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
-
-        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
-          parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
-                        options.cluster_idx,
-                        [&](const IndexRange& range_K, size_t worker) {
-                          do_range(all_M, range_K, worker);
-                        });
-        });
-        break;
-      }
-      case MMParA::kM:
-        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
-          parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
-                              [&](size_t row_a, size_t worker) {
-                                do_range(IndexRange(row_a, row_a + 1), all_K,
-                                         worker);
-                              });
-        });
-        break;
-    }
-  }
-
-  // Autotuning wrapper for `DoDecompressA`.
-  static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
-                                     const StridedViewBF A_view,
-                                     MMAutoTune<MMParA>& autotune,
-                                     const MatMulEnv& env,
-                                     const MMOptions& options) {
-    if (HWY_LIKELY(autotune.Best())) {
-      return DoDecompressA(A, A_view, autotune, *autotune.Best(), env, options);
-    }
-
-    // First call: generate candidates.
-    if (HWY_UNLIKELY(!autotune.HasCandidates())) {
-      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
-      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
-                                        other};
-      autotune.SetCandidates(candidates);
-    }
-
-    const MMParA& par_a = autotune.NextConfig();
-    const uint64_t t0 = hwy::timer::Start();
-    DoDecompressA(A, A_view, autotune, par_a, env, options);
-    const uint64_t t1 =
-        env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
-    const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
-    if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
-      fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
-              static_cast<double>(min_elapsed) /
-                  hwy::platform::InvariantTicksPerSecond() * 1E6);
-    }
-  }
-
-  template <typename TA>
-  static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
-                                                   MMAutoTune<MMParA>& autotune,
-                                                   const MatMulEnv& env,
-                                                   MMOptions options) {
-    if constexpr (IsBF16<TA>()) {
-      // We can use a view, regardless of columns/padding, because `LoopKC`
-      // supports non-vector multiples.
-      return MMKernel::View(A, 0, 0, A.Cols());
-    } else {
-      // Always decompress. To reduce code size/compile time, we no longer
-      // support a separate F32 kernel; most A are already BF16. We also only
-      // have a single MMStorage.
-      HWY_ASSERT(options.cluster_idx == 0);
-      const StridedViewBF A_view = env.storage.A(A.Extents());
-      DecompressA(A, A_view, autotune, env, options);
-      return A_view;
-    }
-  }
 };
 
 // Defines several variants of the outer M/N/K loops (see `MMOrder`).
@@ -885,7 +887,7 @@ class MMLoops {
                   RowPtrs<TC> C_rows, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
     PROFILER_ZONE3(args.env.ctx.profiler,
-                   MMImpl::Worker(args.env, args.options.cluster_idx), zone);
+                   args.env.ctx.Worker(args.options.cluster_idx), zone);
 
     DispatchParallelism(
         args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
@@ -931,7 +933,7 @@ class MMLoops {
     for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
          row_b += kNR) {
       StridedViewBF B_view =
-          MMKernel::DecompressB(B, row_b, range_K, B_storage_view);
+          MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
       MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(),
                      args, C_rows);
     }
@@ -1026,8 +1028,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
   const size_t cluster_idx = options.cluster_idx;
   HWY_DASSERT(cluster_idx < env.row_ptrs.size());
-  PROFILER_ZONE3(env.ctx.profiler,
-                 cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), zone);
+  PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);
 
   RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
 
@@ -1041,7 +1042,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 
   // (Also auto-tunes, hence outside the timed section to prevent interference.)
   const StridedViewBF A_view =
-      MMImpl::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
+      MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
 
   MMAutoTune<MMConfig>& tuner = per_key.autotune;
   if (HWY_LIKELY(tuner.Best())) {
ops/matmul.h (18 changed lines)
@@ -116,7 +116,7 @@ struct MMParallelNone {
             size_t /*n_multiple*/, size_t inner_tasks, size_t cluster_idx,
             const Func& func) const {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t worker = ctx.Worker(cluster_idx);
     func(range_n, worker);
   }
 
@@ -125,7 +125,7 @@ struct MMParallelNone {
               const IndexRangePartition& ranges_mc,
               const IndexRangePartition& ranges_nc, size_t cluster_idx,
               const Func& func) const {
-    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t worker = ctx.Worker(cluster_idx);
 
     for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
       const IndexRange range_mc = ranges_mc.Range(i);
@@ -139,7 +139,7 @@ struct MMParallelNone {
   template <class Func>
   void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
                   size_t cluster_idx, const Func& func) const {
-    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t worker = ctx.Worker(cluster_idx);
     for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
       func(row_a, worker);
     }
@@ -154,7 +154,7 @@ struct MMParallelWithinCluster {
 
     const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t base = ctx.Worker(cluster_idx);
 
     const IndexRangePartition worker_ranges = StaticPartition(
         range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
@@ -171,7 +171,7 @@ struct MMParallelWithinCluster {
                 const Func& func) const {
     const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t base = ctx.Worker(cluster_idx);
 
     // Low-batch: avoid Divide/Remainder.
     if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
@@ -192,7 +192,7 @@ struct MMParallelWithinCluster {
                   size_t cluster_idx, const Func& func) const {
     const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t base = ctx.Worker(cluster_idx);
 
     cluster.Run(
         range_mc.begin(), range_mc.end(),
@@ -233,8 +233,7 @@ struct MMParallelHierarchical {
         n_ranges, all_clusters,
         [&](const IndexRange& n_range, const size_t cluster_idx) {
           hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-          const size_t cluster_base =
-              cluster_idx * ctx.pools.MaxWorkersPerCluster();
+          const size_t cluster_base = ctx.Worker(cluster_idx);
           // Parallel-for over sub-ranges of `cluster_range` within the cluster.
           const IndexRangePartition worker_ranges = StaticPartition(
               n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
@@ -284,8 +283,7 @@ struct MMParallelHierarchical {
     ParallelizeOneRange(
         ranges_nc, all_clusters,
         [&](const IndexRange range_nc, size_t cluster_idx) {
-          const size_t cluster_base =
-              cluster_idx * ctx.pools.MaxWorkersPerCluster();
+          const size_t cluster_base = ctx.Worker(cluster_idx);
           hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
           ParallelizeOneRange(ranges_mc, cluster,
                               [&](const IndexRange& range_mc, size_t worker) {
@@ -97,6 +97,13 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
 struct ThreadingContext {
   explicit ThreadingContext(const ThreadingArgs& args);
 
+  // Returns a worker index compatible with those from `ParallelFor`, assuming
+  // the current thread is running on one thread per cluster, which happens
+  // when `ParallelismStrategy` is `kAcrossClusters`.
+  size_t Worker(size_t cluster_idx) const {
+    return cluster_idx * pools.MaxWorkersPerCluster();
+  }
+
   // Singleton; pass around a reference to reduce overhead.
   hwy::Profiler& profiler;
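Usage note, hedged: single-threaded paths pass Worker(cluster_idx) directly wherever a per-worker index is needed, for example the profiler zone in the MatMul hunk above, while the kWithinCluster case below adds it as the base to the cluster-local index so that worker indices stay unique across clusters.

// Excerpt from the ops/matmul-inl.h hunk above; `zone` is the profiler zone
// registered just before this call.
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);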
@@ -158,7 +165,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
 
   switch (parallelism) {
     case ParallelismStrategy::kNone: {
-      const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+      const size_t worker = ctx.Worker(cluster_idx);
       for (size_t task = 0; task < num_tasks; ++task) {
         func(task, worker);
       }
@@ -173,7 +180,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
     case ParallelismStrategy::kWithinCluster: {
      // Ensure the worker argument is unique across clusters, because it is
      // used for TLS indexing for example in profiler.h.
-      const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+      const size_t base = ctx.Worker(cluster_idx);
       return ctx.pools.Cluster(pkg_idx, cluster_idx)
           .Run(0, num_tasks, [&](uint64_t task, size_t worker) {
             func(task, base + worker);
@@ -193,8 +200,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
 
       return ctx.pools.AllClusters(pkg_idx).Run(
           0, num_tasks, [&](uint64_t task, size_t cluster_idx) {
-            const size_t worker =
-                cluster_idx * ctx.pools.MaxWorkersPerCluster();
+            const size_t worker = ctx.Worker(cluster_idx);
             func(task, worker);
           });
     }