mirror of https://github.com/google/gemma.cpp.git
Refactor: move Worker to ThreadingContext, factor out MMDecompress
PiperOrigin-RevId: 804909921
This commit is contained in:
parent
461a9c7d1b
commit
24b1760f03
361
ops/matmul-inl.h
361
ops/matmul-inl.h
|
|
@ -219,6 +219,181 @@ class MMStoreHorizontalSumsIntoC {
|
|||
}
|
||||
}; // MMStoreHorizontalSumsIntoC
|
||||
|
||||
// Stateless, wraps member functions.
|
||||
class MMDecompress {
|
||||
public:
|
||||
// Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
|
||||
// col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
|
||||
// thanks to its large table lookups, and less so on other targets.
|
||||
template <typename TB>
|
||||
static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
|
||||
const IndexRange& range_kc,
|
||||
const StridedViewBF B_view) {
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
|
||||
|
||||
// Neither A nor B require padding because `LoopKC` handles remainders.
|
||||
if constexpr (hwy::IsSame<TB, BF16>()) {
|
||||
return View(B, row_b, range_kc.begin(), range_kc.Num());
|
||||
}
|
||||
|
||||
const PackedSpan<const TB> B_span = B.PaddedSpan();
|
||||
|
||||
const size_t kc = range_kc.Num();
|
||||
const size_t col0 = range_kc.begin();
|
||||
|
||||
for (size_t r = 0; r < kNR; ++r) {
|
||||
const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
|
||||
BF16* HWY_RESTRICT to = B_view.Row(r);
|
||||
DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
|
||||
// Verify that we zero-padded.
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
return B_view;
|
||||
}
|
||||
|
||||
template <typename TA>
|
||||
static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
|
||||
MMAutoTune<MMParA>& autotune,
|
||||
const MatMulEnv& env,
|
||||
MMOptions options) {
|
||||
if constexpr (IsBF16<TA>()) {
|
||||
// We can use a view, regardless of columns/padding, because
|
||||
// `MMKernel::LoopKC` supports non-vector multiples.
|
||||
return View(A, 0, 0, A.Cols());
|
||||
} else {
|
||||
// Always decompress. To reduce code size/compile time, we no longer
|
||||
// support a separate F32 kernel; most A are already BF16. We also only
|
||||
// have a single MMStorage.
|
||||
HWY_ASSERT(options.cluster_idx == 0);
|
||||
const StridedViewBF A_view = env.storage.A(A.Extents());
|
||||
AutotuneDecompressA(A, A_view, autotune, env, options);
|
||||
return A_view;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
||||
template <typename T>
|
||||
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
|
||||
size_t cols) {
|
||||
HWY_DASSERT(c < AB.Cols());
|
||||
HWY_DASSERT(cols <= AB.Cols() - c);
|
||||
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
|
||||
}
|
||||
|
||||
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
|
||||
static HWY_NOINLINE void DecompressA(const MatPtrT<float>& A,
|
||||
const StridedViewBF A_view,
|
||||
MMAutoTune<MMParA>& autotune,
|
||||
MMParA par_a, const MatMulEnv& env,
|
||||
const MMOptions& options) {
|
||||
const IndexRange all_M(0, A.Rows());
|
||||
const IndexRange all_K(0, A.Cols());
|
||||
HWY_DASSERT(all_K.Num() == A_view.Cols());
|
||||
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const size_t NBF = hn::Lanes(dbf);
|
||||
|
||||
static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
|
||||
|
||||
const auto do_range =
|
||||
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
|
||||
HWY_ATTR {
|
||||
MMZone mm_zone;
|
||||
mm_zone.MaybeEnter(worker, zone, env, &autotune);
|
||||
|
||||
const size_t col0 = range_K.begin();
|
||||
const size_t cols = range_K.Num();
|
||||
// Must be a vector multiple, or the last range before row
|
||||
// padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
|
||||
HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
|
||||
for (size_t row_a : range_M) {
|
||||
const PackedSpan<const float> from =
|
||||
MakeSpan(A.Row(row_a) + col0, cols);
|
||||
BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
|
||||
DecompressAndZeroPad(dbf, from, 0, to, cols);
|
||||
// Verify that we zero-padded.
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
switch (par_a) {
|
||||
case MMParA::kNone:
|
||||
do_range(all_M, all_K, env.ctx.Worker(options.cluster_idx));
|
||||
break;
|
||||
|
||||
case MMParA::kK1:
|
||||
case MMParA::kK2:
|
||||
case MMParA::kK4: {
|
||||
const size_t inner_tasks = static_cast<size_t>(par_a);
|
||||
// At least one vector, otherwise DecompressAndZeroPad will add
|
||||
// padding, which might overwrite neighboring tasks. Also a whole cache
|
||||
// line to avoid false sharing.
|
||||
const size_t multiple_K =
|
||||
HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
|
||||
|
||||
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
|
||||
parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
|
||||
options.cluster_idx,
|
||||
[&](const IndexRange& range_K, size_t worker) {
|
||||
do_range(all_M, range_K, worker);
|
||||
});
|
||||
});
|
||||
break;
|
||||
}
|
||||
case MMParA::kM:
|
||||
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
|
||||
parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
|
||||
[&](size_t row_a, size_t worker) {
|
||||
do_range(IndexRange(row_a, row_a + 1), all_K,
|
||||
worker);
|
||||
});
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Autotuning wrapper for `DoDecompressA`.
|
||||
static HWY_INLINE void AutotuneDecompressA(const MatPtrT<float>& A,
|
||||
const StridedViewBF A_view,
|
||||
MMAutoTune<MMParA>& autotune,
|
||||
const MatMulEnv& env,
|
||||
const MMOptions& options) {
|
||||
if (HWY_LIKELY(autotune.Best())) {
|
||||
return DecompressA(A, A_view, autotune, *autotune.Best(), env, options);
|
||||
}
|
||||
|
||||
// First call: generate candidates.
|
||||
if (HWY_UNLIKELY(!autotune.HasCandidates())) {
|
||||
const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
|
||||
std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
|
||||
other};
|
||||
autotune.SetCandidates(candidates);
|
||||
}
|
||||
|
||||
const MMParA& par_a = autotune.NextConfig();
|
||||
const uint64_t t0 = hwy::timer::Start();
|
||||
DecompressA(A, A_view, autotune, par_a, env, options);
|
||||
const uint64_t t1 =
|
||||
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
|
||||
if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
|
||||
fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
|
||||
static_cast<double>(min_elapsed) /
|
||||
hwy::platform::InvariantTicksPerSecond() * 1E6);
|
||||
}
|
||||
}
|
||||
}; // MMDecompress
|
||||
|
||||
// Stateless, wraps member functions. Contains the innermost 2-4 loops.
|
||||
class MMKernel {
|
||||
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
|
||||
|
|
@ -288,49 +463,6 @@ class MMKernel {
|
|||
|
||||
static constexpr size_t B_storage_max = kNR * B_stride_max;
|
||||
|
||||
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
||||
template <typename T>
|
||||
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
|
||||
size_t cols) {
|
||||
HWY_DASSERT(c < AB.Cols());
|
||||
HWY_DASSERT(cols <= AB.Cols() - c);
|
||||
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
|
||||
}
|
||||
|
||||
// Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
|
||||
// col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
|
||||
// thanks to its large table lookups, and less so on other targets.
|
||||
template <typename TB>
|
||||
static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
|
||||
const IndexRange& range_kc,
|
||||
const StridedViewBF B_view) {
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
|
||||
|
||||
// Neither A nor B require padding because `LoopKC` handles remainders.
|
||||
if constexpr (hwy::IsSame<TB, BF16>()) {
|
||||
return View(B, row_b, range_kc.begin(), range_kc.Num());
|
||||
}
|
||||
|
||||
const PackedSpan<const TB> B_span = B.PaddedSpan();
|
||||
|
||||
const size_t kc = range_kc.Num();
|
||||
const size_t col0 = range_kc.begin();
|
||||
|
||||
for (size_t r = 0; r < kNR; ++r) {
|
||||
const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
|
||||
BF16* HWY_RESTRICT to = B_view.Row(r);
|
||||
DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
|
||||
// Verify that we zero-padded.
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
return B_view;
|
||||
}
|
||||
|
||||
// Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
|
||||
// `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
|
||||
// `ForeachKC` and when there is only a single KC task.
|
||||
|
|
@ -350,7 +482,8 @@ class MMKernel {
|
|||
|
||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
|
||||
StridedViewBF B_view =
|
||||
MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
|
||||
A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows);
|
||||
}
|
||||
}
|
||||
|
|
@ -742,137 +875,6 @@ class MMImpl {
|
|||
HWY_ASSERT(hwy::IsAligned(A.RowBytes(1), vector_bytes));
|
||||
}
|
||||
}
|
||||
|
||||
static size_t Worker(const MatMulEnv& env, size_t cluster_idx) {
|
||||
return cluster_idx * env.ctx.pools.MaxWorkersPerCluster();
|
||||
}
|
||||
|
||||
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
|
||||
static HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
|
||||
const StridedViewBF A_view,
|
||||
MMAutoTune<MMParA>& autotune,
|
||||
MMParA par_a, const MatMulEnv& env,
|
||||
const MMOptions& options) {
|
||||
const IndexRange all_M(0, A.Rows());
|
||||
const IndexRange all_K(0, A.Cols());
|
||||
HWY_DASSERT(all_K.Num() == A_view.Cols());
|
||||
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const size_t NBF = hn::Lanes(dbf);
|
||||
|
||||
static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
|
||||
|
||||
const auto do_range =
|
||||
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
|
||||
HWY_ATTR {
|
||||
MMZone mm_zone;
|
||||
mm_zone.MaybeEnter(worker, zone, env, &autotune);
|
||||
|
||||
const size_t col0 = range_K.begin();
|
||||
const size_t cols = range_K.Num();
|
||||
// Must be a vector multiple, or the last range before row
|
||||
// padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
|
||||
HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
|
||||
for (size_t row_a : range_M) {
|
||||
const PackedSpan<const float> from =
|
||||
MakeSpan(A.Row(row_a) + col0, cols);
|
||||
BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
|
||||
DecompressAndZeroPad(dbf, from, 0, to, cols);
|
||||
// Verify that we zero-padded.
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
switch (par_a) {
|
||||
case MMParA::kNone:
|
||||
do_range(all_M, all_K, Worker(env, options.cluster_idx));
|
||||
break;
|
||||
|
||||
case MMParA::kK1:
|
||||
case MMParA::kK2:
|
||||
case MMParA::kK4: {
|
||||
const size_t inner_tasks = static_cast<size_t>(par_a);
|
||||
// At least one vector, otherwise DecompressAndZeroPad will add
|
||||
// padding, which might overwrite neighboring tasks. Also a whole cache
|
||||
// line to avoid false sharing.
|
||||
const size_t multiple_K =
|
||||
HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
|
||||
|
||||
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
|
||||
parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
|
||||
options.cluster_idx,
|
||||
[&](const IndexRange& range_K, size_t worker) {
|
||||
do_range(all_M, range_K, worker);
|
||||
});
|
||||
});
|
||||
break;
|
||||
}
|
||||
case MMParA::kM:
|
||||
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
|
||||
parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
|
||||
[&](size_t row_a, size_t worker) {
|
||||
do_range(IndexRange(row_a, row_a + 1), all_K,
|
||||
worker);
|
||||
});
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Autotuning wrapper for `DoDecompressA`.
|
||||
static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
|
||||
const StridedViewBF A_view,
|
||||
MMAutoTune<MMParA>& autotune,
|
||||
const MatMulEnv& env,
|
||||
const MMOptions& options) {
|
||||
if (HWY_LIKELY(autotune.Best())) {
|
||||
return DoDecompressA(A, A_view, autotune, *autotune.Best(), env, options);
|
||||
}
|
||||
|
||||
// First call: generate candidates.
|
||||
if (HWY_UNLIKELY(!autotune.HasCandidates())) {
|
||||
const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
|
||||
std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
|
||||
other};
|
||||
autotune.SetCandidates(candidates);
|
||||
}
|
||||
|
||||
const MMParA& par_a = autotune.NextConfig();
|
||||
const uint64_t t0 = hwy::timer::Start();
|
||||
DoDecompressA(A, A_view, autotune, par_a, env, options);
|
||||
const uint64_t t1 =
|
||||
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
|
||||
if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
|
||||
fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
|
||||
static_cast<double>(min_elapsed) /
|
||||
hwy::platform::InvariantTicksPerSecond() * 1E6);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TA>
|
||||
static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
|
||||
MMAutoTune<MMParA>& autotune,
|
||||
const MatMulEnv& env,
|
||||
MMOptions options) {
|
||||
if constexpr (IsBF16<TA>()) {
|
||||
// We can use a view, regardless of columns/padding, because `LoopKC`
|
||||
// supports non-vector multiples.
|
||||
return MMKernel::View(A, 0, 0, A.Cols());
|
||||
} else {
|
||||
// Always decompress. To reduce code size/compile time, we no longer
|
||||
// support a separate F32 kernel; most A are already BF16. We also only
|
||||
// have a single MMStorage.
|
||||
HWY_ASSERT(options.cluster_idx == 0);
|
||||
const StridedViewBF A_view = env.storage.A(A.Extents());
|
||||
DecompressA(A, A_view, autotune, env, options);
|
||||
return A_view;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Defines several variants of the outer M/N/K loops (see `MMOrder`).
|
||||
|
|
@ -885,7 +887,7 @@ class MMLoops {
|
|||
RowPtrs<TC> C_rows, const MMArgs& args) {
|
||||
static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
|
||||
PROFILER_ZONE3(args.env.ctx.profiler,
|
||||
MMImpl::Worker(args.env, args.options.cluster_idx), zone);
|
||||
args.env.ctx.Worker(args.options.cluster_idx), zone);
|
||||
|
||||
DispatchParallelism(
|
||||
args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
|
||||
|
|
@ -931,7 +933,7 @@ class MMLoops {
|
|||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
StridedViewBF B_view =
|
||||
MMKernel::DecompressB(B, row_b, range_K, B_storage_view);
|
||||
MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
|
||||
MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(),
|
||||
args, C_rows);
|
||||
}
|
||||
|
|
@ -1026,8 +1028,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
|
||||
const size_t cluster_idx = options.cluster_idx;
|
||||
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
|
||||
PROFILER_ZONE3(env.ctx.profiler,
|
||||
cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), zone);
|
||||
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);
|
||||
|
||||
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
|
||||
|
||||
|
|
@ -1041,7 +1042,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
|
||||
// (Also auto-tunes, hence outside the timed section to prevent interference.)
|
||||
const StridedViewBF A_view =
|
||||
MMImpl::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
|
||||
MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
|
||||
|
||||
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
||||
if (HWY_LIKELY(tuner.Best())) {
|
||||
|
|
|
|||
18
ops/matmul.h
18
ops/matmul.h
|
|
@ -116,7 +116,7 @@ struct MMParallelNone {
|
|||
size_t /*n_multiple*/, size_t inner_tasks, size_t cluster_idx,
|
||||
const Func& func) const {
|
||||
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t worker = ctx.Worker(cluster_idx);
|
||||
func(range_n, worker);
|
||||
}
|
||||
|
||||
|
|
@ -125,7 +125,7 @@ struct MMParallelNone {
|
|||
const IndexRangePartition& ranges_mc,
|
||||
const IndexRangePartition& ranges_nc, size_t cluster_idx,
|
||||
const Func& func) const {
|
||||
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t worker = ctx.Worker(cluster_idx);
|
||||
|
||||
for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
|
||||
const IndexRange range_mc = ranges_mc.Range(i);
|
||||
|
|
@ -139,7 +139,7 @@ struct MMParallelNone {
|
|||
template <class Func>
|
||||
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
||||
size_t cluster_idx, const Func& func) const {
|
||||
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t worker = ctx.Worker(cluster_idx);
|
||||
for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
|
||||
func(row_a, worker);
|
||||
}
|
||||
|
|
@ -154,7 +154,7 @@ struct MMParallelWithinCluster {
|
|||
|
||||
const size_t pkg_idx = 0;
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t base = ctx.Worker(cluster_idx);
|
||||
|
||||
const IndexRangePartition worker_ranges = StaticPartition(
|
||||
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
|
||||
|
|
@ -171,7 +171,7 @@ struct MMParallelWithinCluster {
|
|||
const Func& func) const {
|
||||
const size_t pkg_idx = 0;
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t base = ctx.Worker(cluster_idx);
|
||||
|
||||
// Low-batch: avoid Divide/Remainder.
|
||||
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
|
||||
|
|
@ -192,7 +192,7 @@ struct MMParallelWithinCluster {
|
|||
size_t cluster_idx, const Func& func) const {
|
||||
const size_t pkg_idx = 0;
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t base = ctx.Worker(cluster_idx);
|
||||
|
||||
cluster.Run(
|
||||
range_mc.begin(), range_mc.end(),
|
||||
|
|
@ -233,8 +233,7 @@ struct MMParallelHierarchical {
|
|||
n_ranges, all_clusters,
|
||||
[&](const IndexRange& n_range, const size_t cluster_idx) {
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||
const size_t cluster_base =
|
||||
cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t cluster_base = ctx.Worker(cluster_idx);
|
||||
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
|
||||
const IndexRangePartition worker_ranges = StaticPartition(
|
||||
n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
|
||||
|
|
@ -284,8 +283,7 @@ struct MMParallelHierarchical {
|
|||
ParallelizeOneRange(
|
||||
ranges_nc, all_clusters,
|
||||
[&](const IndexRange range_nc, size_t cluster_idx) {
|
||||
const size_t cluster_base =
|
||||
cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t cluster_base = ctx.Worker(cluster_idx);
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||
ParallelizeOneRange(ranges_mc, cluster,
|
||||
[&](const IndexRange& range_mc, size_t worker) {
|
||||
|
|
|
|||
|
|
@ -97,6 +97,13 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
|
|||
struct ThreadingContext {
|
||||
explicit ThreadingContext(const ThreadingArgs& args);
|
||||
|
||||
// Returns a worker index compatible with those from `ParallelFor`, assuming
|
||||
// the current thread is running on one thread per cluster, which happens
|
||||
// when `ParallelismStrategy` is `kAcrossClusters`.
|
||||
size_t Worker(size_t cluster_idx) const {
|
||||
return cluster_idx * pools.MaxWorkersPerCluster();
|
||||
}
|
||||
|
||||
// Singleton; pass around a reference to reduce overhead.
|
||||
hwy::Profiler& profiler;
|
||||
|
||||
|
|
@ -158,7 +165,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
|
|||
|
||||
switch (parallelism) {
|
||||
case ParallelismStrategy::kNone: {
|
||||
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t worker = ctx.Worker(cluster_idx);
|
||||
for (size_t task = 0; task < num_tasks; ++task) {
|
||||
func(task, worker);
|
||||
}
|
||||
|
|
@ -173,7 +180,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
|
|||
case ParallelismStrategy::kWithinCluster: {
|
||||
// Ensure the worker argument is unique across clusters, because it is
|
||||
// used for TLS indexing for example in profiler.h.
|
||||
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t base = ctx.Worker(cluster_idx);
|
||||
return ctx.pools.Cluster(pkg_idx, cluster_idx)
|
||||
.Run(0, num_tasks, [&](uint64_t task, size_t worker) {
|
||||
func(task, base + worker);
|
||||
|
|
@ -193,8 +200,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
|
|||
|
||||
return ctx.pools.AllClusters(pkg_idx).Run(
|
||||
0, num_tasks, [&](uint64_t task, size_t cluster_idx) {
|
||||
const size_t worker =
|
||||
cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
const size_t worker = ctx.Worker(cluster_idx);
|
||||
func(task, worker);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue