Update instrumentation for new Highway wall-time profiler

Pass the thread index through and use the new zone_id.

PiperOrigin-RevId: 773344242
Jan Wassenberg 2025-06-19 07:45:30 -07:00 committed by Copybara-Service
parent 1665ecc5c2
commit 4f5785b0fd
8 changed files with 97 additions and 57 deletions
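
For reference, the zone pattern this change migrates to, as a minimal sketch. It is assembled only from the calls visible in this diff (PROFILER_ADD_ZONE, PROFILER_ZONE2, hwy::Zone, PROFILER_ENABLED) and is not a complete description of the Highway profiler API; the Example function and zone name are hypothetical:

#include <stdint.h>

#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"

void Example(hwy::ThreadPool& pool, size_t num_tasks) {
  // Register the zone name once; all workers reuse the returned id.
  static const uint32_t zone_id = PROFILER_ADD_ZONE("Example.Par");
  pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
#if PROFILER_ENABLED
    // Wall-time zone keyed by the worker's thread index.
    const hwy::Zone zone(thread, zone_id);
#endif
    // ... per-task work using `task` ...
    (void)task;
  });
  // One-step alternative that registers and enters in a single macro, as used
  // in the weights-reading changes below:
  //   PROFILER_ZONE2(thread, "Example.Par");
}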

View File

@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6 EXCLUDE_FROM_ALL)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 01019e979cd098f2ee618f39bb6718f1b4a3d901 EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(highway)
 ## Note: absl needs to be installed by sentencepiece. This will only happen if

View File

@@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
 # Require a more recent version.
 git_override(
     module_name = "highway",
-    commit = "12d9fa908e0c1d3346c298d472584687a24e4ce6",
+    commit = "01019e979cd098f2ee618f39bb6718f1b4a3d901",
     remote = "https://github.com/google/highway",
 )
@@ -71,6 +71,7 @@ pip.parse(
     requirements_lock = "//compression/python:requirements.txt",
 )
 use_repo(pip, "compression_deps")
+
 pip.parse(
     hub_name = "python_deps",
     python_version = "3.11",

View File

@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 01019e979cd098f2ee618f39bb6718f1b4a3d901)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
@@ -32,7 +32,7 @@ if (NOT BUILD_MODE)
 endif()
 if (BUILD_MODE STREQUAL "local")
   # Relative path to gemma.cpp from examples/hello_world/build/
   FetchContent_Declare(gemma SOURCE_DIR ../../..)
 else()
   FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
 endif()

View File

@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 01019e979cd098f2ee618f39bb6718f1b4a3d901)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
@@ -32,7 +32,7 @@ if (NOT BUILD_MODE)
 endif()
 if (BUILD_MODE STREQUAL "local")
   # Relative path to gemma.cpp from examples/simplified_gemma/build/
   FetchContent_Declare(gemma SOURCE_DIR ../../..)
 else()
   FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
 endif()

View File

@@ -164,7 +164,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                            const LayerWeightsPtrs& layer,
                            AttentionActivations& activations, QBatch& qbatch,
                            NestedPools& pools) {
-  PROFILER_ZONE("Gen.Attention.DotSoftmax");
+  PROFILER_ZONE("Gen.Attention.DotSoftmax.misc");
+  static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
+      PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
+
   const hwy::Divisor div_qbatch(qbatch.Size());
   const LayerConfig& layer_config = layer.layer_config;
   const size_t qkv_dim = layer_config.qkv_dim;
@@ -186,9 +189,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   ParallelizeOneRange(
       tq_ranges, pools.AllPackages(),
       [&](const IndexRange& tq_range, const size_t pkg_idx) {
+        const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
         pools.AllClusters(pkg_idx).Run(
             tq_range.begin(), tq_range.end(),
             [&](const size_t tq_idx, const size_t cluster_idx) {
+              const HWY_MAYBE_UNUSED size_t cluster_base =
+                  pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
               const size_t qi = div_qbatch.Remainder(tq_idx);
               const size_t batch_idx = div_qbatch.Divide(tq_idx);
               auto& kv_cache = qbatch.KV(qi).kv_cache;
@@ -209,6 +215,11 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
             .Run(
                 0, layer_config.heads,
                 [&](const size_t head, size_t thread) HWY_ATTR {
+#if PROFILER_ENABLED
+                  const hwy::Zone zone(cluster_base + thread,
+                                       zone_id_par);
+#endif
                   const size_t head_offset =
                       (head / kHeadGroups) * qkv_dim * 2;
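
The pkg_base/cluster_base arithmetic above flattens (package, cluster, thread) into a single global worker index, so each worker gets a distinct profiler slot. A worked example with hypothetical pool sizes:

// Assume MaxWorkersPerPackage() == 8 and MaxWorkersPerCluster() == 4,
// i.e. two clusters of four workers per package (hypothetical sizes).
// For pkg_idx = 1, cluster_idx = 1, thread = 2:
//   pkg_base     = 1 * 8     = 8
//   cluster_base = 8 + 1 * 4 = 12
//   worker index = 12 + 2    = 14  // unique across all NestedPools workers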

View File

@@ -385,9 +385,8 @@ static void DecompressToBF16(MatPtr& mat,
 static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
                           const BlobReader& reader, hwy::ThreadPool& pool) {
-  PROFILER_ZONE("Startup.Weights.ReadBF16");
-
-  pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
+  pool.Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
+    PROFILER_ZONE2(thread, "Startup.Weights.ReadBF16");
     const TensorToRead& tensor = tensors[task];
     MatPtr& mat = *tensor.mat;
@@ -465,9 +464,9 @@ static std::vector<IOBatch> MakeBatches(
 static void ReadBatches(const BlobReader& reader,
                         const std::vector<IOBatch>& batches,
                         hwy::ThreadPool& pool) {
-  PROFILER_ZONE("Startup.Weights.Read");
   // >5x speedup from parallel reads when cached.
-  pool.Run(0, batches.size(), [&](uint64_t i, size_t /*thread*/) {
+  pool.Run(0, batches.size(), [&](uint64_t i, size_t thread) {
+    PROFILER_ZONE2(thread, "Startup.Weights.Read");
     const IOBatch& batch = batches[i];
     const std::string& key = reader.Keys()[batch.KeyIdx()];
     const uint64_t bytes_read = batch.Read(reader.file());
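
Both hunks above follow the same recipe: the single PROFILER_ZONE opened on the calling thread is replaced by a PROFILER_ZONE2 inside the parallel closure, so each pool worker attributes its own wall time to the zone instead of all work being charged to the caller:

// Before: zone on the calling thread only.
//   PROFILER_ZONE("Startup.Weights.Read");
//   pool.Run(0, n, [&](uint64_t i, size_t /*thread*/) { ... });
// After: each worker opens the zone under its own thread index.
//   pool.Run(0, n, [&](uint64_t i, size_t thread) {
//     PROFILER_ZONE2(thread, "Startup.Weights.Read");
//     ...
//   });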

View File

@@ -875,8 +875,9 @@ class MMPerPackage {
         inner_tasks_(config.InnerTasks()),
         out_(config.Out()),
         line_bytes_(args.env->ctx.allocator.LineBytes()) {
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA");
     MMZone zone;
-    zone.MaybeEnter("MM.DecompressA", args_);
+    zone.MaybeEnter(pkg_idx, zone_id, args_);
     A_ = DecompressA(A);
   }
@@ -914,8 +915,7 @@ class MMPerPackage {
   // Single M and K ranges, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
   HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    MMZone zone;
-    zone.MaybeEnter("MM.NT", args_);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_M = ranges_mc_.Range(0);
@@ -928,7 +928,10 @@ class MMPerPackage {
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
     args_.env->parallel.ForNP(
         range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
-        [&](const IndexRange& range_nc) HWY_ATTR {
+        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+          MMZone zone;
+          zone.MaybeEnter(worker, zone_id, args_);
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           const StridedViewBF B_storage_view(B_storage, K, B_stride);
@@ -947,8 +950,7 @@ class MMPerPackage {
   // Single M range, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
   HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    MMZone zone;
-    zone.MaybeEnter("MM.NT_K", args_);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     const IndexRange& range_mc = ranges_mc_.Range(0);
@@ -975,7 +977,10 @@ class MMPerPackage {
     args_.env->parallel.ForNP(
         range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
-        [&](const IndexRange& range_nc) HWY_ATTR {
+        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+          MMZone zone;
+          zone.MaybeEnter(worker, zone_id, args_);
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           // Peel off the first iteration of the kc loop: avoid
@@ -988,14 +993,17 @@ class MMPerPackage {
           });
         });
-    MMZone fill_zone;
     if (out_ == MMOut::kCopy) {
-      fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
+      static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.Copy");
+      MMZone fill_zone;
+      fill_zone.MaybeEnter(0, zone_id, args_);
       MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
     } else if (out_ == MMOut::kParM) {
-      fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
+      static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.ParM");
       args_.env->parallel.ForRangeMC(
-          range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
+          range_mc, pkg_idx_, [&](size_t row_a, size_t worker) HWY_ATTR {
+            MMZone fill_zone;
+            fill_zone.MaybeEnter(worker, zone_id, args_);
             MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
                                     args_, C_rows);
           });
@@ -1008,8 +1016,7 @@ class MMPerPackage {
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
   HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    MMZone zone;
-    zone.MaybeEnter("MM.NT_MT", args_);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT");
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_K = ranges_kc_.Range(0);
     const size_t K = range_K.Num();
@@ -1020,7 +1027,11 @@ class MMPerPackage {
     // except for the profiler strings and `out_tag`.
     args_.env->parallel.ForRangesMC_NC(
         ranges_mc_, ranges_nc_, pkg_idx_,
-        [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
+        [&](const IndexRange& range_mc, const IndexRange& range_nc,
+            size_t worker) HWY_ATTR {
+          MMZone zone;
+          zone.MaybeEnter(worker, zone_id, args_);
           const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           const StridedViewBF B_storage_view(B_storage, K, B_stride);
@@ -1041,8 +1052,8 @@ class MMPerPackage {
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
   HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    MMZone zone;
-    zone.MaybeEnter("MM.NT_MT_K", args_);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K");
+    static const uint32_t fill_zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K.FillC");
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
     const size_t B_stride =
@@ -1068,7 +1079,11 @@ class MMPerPackage {
     };  // loop_nc
     args_.env->parallel.ForRangesMC_NC(
         ranges_mc_, ranges_nc_, pkg_idx_,
-        [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
+        [&](const IndexRange& range_mc, const IndexRange& range_nc,
+            size_t worker) HWY_ATTR {
+          MMZone zone;
+          zone.MaybeEnter(worker, zone_id, args_);
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
@@ -1087,7 +1102,7 @@ class MMPerPackage {
           // `kDirect` is only used with `kNT_MT`.
           HWY_DASSERT(out_ == MMOut::kCopy);
           MMZone fill_zone;
-          fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
+          fill_zone.MaybeEnter(worker, fill_zone_id, args_);
           MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
         });
   }
@@ -1139,13 +1154,16 @@ class MMPerPackage {
         args_.env->parallel.ForNP(
             all_K, multiple_K, inner_tasks, pkg_idx_,
-            [&](const IndexRange& range_K) { do_range(all_M, range_K); });
+            [&](const IndexRange& range_K, size_t /*worker*/) {
+              do_range(all_M, range_K);
+            });
         break;
       }
       case MMParA::kM:
-        args_.env->parallel.ForRangeMC(all_M, pkg_idx_, [&](size_t row_a) {
-          do_range(IndexRange(row_a, row_a + 1), all_K);
-        });
+        args_.env->parallel.ForRangeMC(
+            all_M, pkg_idx_, [&](size_t row_a, size_t /*worker*/) {
+              do_range(IndexRange(row_a, row_a + 1), all_K);
+            });
         break;
     }
   }
@@ -1261,12 +1279,13 @@ struct MMImpl {
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     RowPtrs<TC> C_rows, const MMArgs& args,
                                     const MMConfig& config) {
-    MMZone matmul_zone;
-    matmul_zone.MaybeEnter("MM.DoMatMul", args);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul");
     // Outermost loop: static NUMA-aware partition of B rows across packages.
     args.env->parallel.ForPkg(
         args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
+          MMZone matmul_zone;
+          matmul_zone.MaybeEnter(pkg_idx, zone_id, args);
          const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
          MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
        });
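
The recurring shape in this file: each kernel registers its zone name once through a function-scope static (C++11 guarantees thread-safe one-time initialization), then every worker lambda constructs an MMZone keyed by its global worker index. A condensed sketch using names from this diff; the zone name is hypothetical and the surrounding kernel is elided:

static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.Example");  // once
args_.env->parallel.ForNP(
    range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
    [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
      MMZone zone;
      zone.MaybeEnter(worker, zone_id, args_);  // no-op unless WantProfile()
      // ... compute on range_nc ...
    });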

View File

@@ -88,11 +88,13 @@ class MMParallel {
   }
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
-  // the granularity of per-cluster tasks. Calls `func(worker_range)`.
+  // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
   template <class Func>
   void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks,
              size_t pkg_idx, const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
+
     // Single cluster: parallel-for over static partition of `range_np`.
     hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
     const size_t num_clusters = all_clusters.NumWorkers();
@@ -102,8 +104,8 @@ class MMParallel {
           range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
       return ParallelizeOneRange(
           worker_ranges, cluster,
-          [&](const IndexRange& worker_range, size_t /*thread*/) {
-            func(worker_range);
+          [&](const IndexRange& worker_range, size_t thread) {
+            func(worker_range, pkg_base + thread);
           });
     }
@@ -114,21 +116,26 @@ class MMParallel {
         nx_ranges, all_clusters,
         [&](const IndexRange& nx_range, const size_t cluster_idx) {
           hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
+          const size_t cluster_base =
+              pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
           // Parallel-for over sub-ranges of `cluster_range` within the cluster.
           const IndexRangePartition worker_ranges = StaticPartition(
               nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
-          ParallelizeOneRange(worker_ranges, cluster,
-                              [&](const IndexRange& worker_range,
-                                  size_t /*thread*/) { func(worker_range); });
+          ParallelizeOneRange(
+              worker_ranges, cluster,
+              [&](const IndexRange& worker_range, size_t thread) {
+                func(worker_range, cluster_base + thread);
+              });
         });
   }
   // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
-  // rows). Calls `func(range_mc, range_nc)`.
+  // rows). Calls `func(range_mc, range_nc, worker)`.
   template <class Func>
   void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
                       const IndexRangePartition& ranges_nc, size_t pkg_idx,
                       const Func& func) {
+    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
     hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
     // `all_clusters` is a pool with one worker per cluster in a package.
     const size_t num_clusters = all_clusters.NumWorkers();
@@ -140,15 +147,16 @@ class MMParallel {
       // Low-batch: avoid Divide/Remainder.
      if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
        return ParallelizeOneRange(
-            ranges_nc, cluster,
-            [&](const IndexRange& range_nc, size_t /*thread*/) {
-              func(ranges_mc.Range(0), range_nc);
+            ranges_nc, cluster, [&](const IndexRange& range_nc, size_t thread) {
+              func(ranges_mc.Range(0), range_nc, pkg_base + thread);
            });
      } else {
        return ParallelizeTwoRanges(
            ranges_mc, ranges_nc, cluster,
            [&](const IndexRange& range_mc, const IndexRange& range_nc,
-                size_t /*thread*/) { func(range_mc, range_nc); });
+                size_t thread) {
+              func(range_mc, range_nc, pkg_base + thread);
+            });
      }
    }
@@ -157,22 +165,24 @@ class MMParallel {
     ParallelizeOneRange(
         ranges_nc, all_clusters,
         [&](const IndexRange range_nc, size_t cluster_idx) {
+          const size_t cluster_base =
+              pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
           hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
-          ParallelizeOneRange(
-              ranges_mc, cluster,
-              [&](const IndexRange& range_mc, size_t /*thread*/) {
-                func(range_mc, range_nc);
-              });
+          ParallelizeOneRange(ranges_mc, cluster,
+                              [&](const IndexRange& range_mc, size_t thread) {
+                                func(range_mc, range_nc, cluster_base + thread);
+                              });
         });
   }
-  // Calls `func(row_a)` in parallel.
+  // Calls `func(row_a, worker)` in parallel.
   template <class Func>
   void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
                   const Func& func) {
+    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
     ctx_.pools.Pool(pkg_idx).Run(
         range_mc.begin(), range_mc.end(),
-        [&](uint64_t row_a, size_t /*thread*/) { func(row_a); });
+        [&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
   }
  private:
@@ -714,9 +724,9 @@ class MMZone {
   }
   // `name` must be a string literal.
-  void MaybeEnter(const char* name, const MMArgs& args) {
+  void MaybeEnter(size_t thread_id, uint32_t zone_id, const MMArgs& args) {
     if (args.per_key->WantProfile()) {
-      new (&data_) Zone(name);
+      new (&data_) Zone(thread_id, zone_id);
       used_ = true;
     }
   }
@@ -727,7 +737,7 @@ class MMZone {
 };
 #else
 struct MMZone {
-  void MaybeEnter(const char*, const MMArgs&) {}
+  void MaybeEnter(size_t, uint32_t, const MMArgs&) {}
 };
 #endif  // PROFILER_ENABLED
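
Note that the #else stub mirrors the new three-argument signature, so call sites need no #ifdefs of their own: when PROFILER_ENABLED is 0, MaybeEnter is an empty inline function that the compiler optimizes away entirely.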