mirror of https://github.com/google/gemma.cpp.git
Update instrumentation for new Highway wall-time profiler
Pass the thread index through and use the new zone_id.

PiperOrigin-RevId: 773344242
parent 1665ecc5c2
commit 4f5785b0fd
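The pattern applied throughout the diff: the old Highway profiler identified a zone by a name string alone, while the new wall-time profiler keys each zone to a pre-registered zone id plus the index of the worker thread that enters it. A minimal before/after sketch, using only the macros and the hwy::Zone constructor that appear in the hunks below; the zone name and functions are illustrative, not from the commit:

    #include <cstddef>
    #include <cstdint>

    #include "hwy/profiler.h"  // PROFILER_ZONE, PROFILER_ZONE2, PROFILER_ADD_ZONE

    // Before: zone keyed by name only; the current thread is implicit.
    void Before() { PROFILER_ZONE("Example.Zone"); }

    // After: the caller passes the worker's thread index, so wall time is
    // attributed to the correct thread. `thread` comes from the pool callback.
    void After(size_t thread) { PROFILER_ZONE2(thread, "Example.Zone"); }

    // For zones entered by many workers, register once and construct
    // hwy::Zone directly from (thread, zone_id), as the hunks below do.
    void AfterHoisted(size_t thread) {
      static const uint32_t zone_id = PROFILER_ADD_ZONE("Example.Zone.par");
    #if PROFILER_ENABLED
      const hwy::Zone zone(thread, zone_id);
    #endif
    }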
CMakeLists.txt
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6 EXCLUDE_FROM_ALL)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 01019e979cd098f2ee618f39bb6718f1b4a3d901 EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(highway)
 
 ## Note: absl needs to be installed by sentencepiece. This will only happen if
@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
|
|||
# Require a more recent version.
|
||||
git_override(
|
||||
module_name = "highway",
|
||||
commit = "12d9fa908e0c1d3346c298d472584687a24e4ce6",
|
||||
commit = "01019e979cd098f2ee618f39bb6718f1b4a3d901",
|
||||
remote = "https://github.com/google/highway",
|
||||
)
|
||||
|
||||
|
|
@@ -71,6 +71,7 @@ pip.parse(
     requirements_lock = "//compression/python:requirements.txt",
 )
 use_repo(pip, "compression_deps")
+
 pip.parse(
     hub_name = "python_deps",
     python_version = "3.11",
examples/hello_world/CMakeLists.txt
@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 01019e979cd098f2ee618f39bb6718f1b4a3d901)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
@@ -32,7 +32,7 @@ if (NOT BUILD_MODE)
 endif()
 if (BUILD_MODE STREQUAL "local")
   # Relative path to gemma.cpp from examples/hello_world/build/
-  FetchContent_Declare(gemma SOURCE_DIR ../../..)
+  FetchContent_Declare(gemma SOURCE_DIR ../../..)
 else()
   FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
 endif()
examples/simplified_gemma/CMakeLists.txt
@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 01019e979cd098f2ee618f39bb6718f1b4a3d901)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
@@ -32,7 +32,7 @@ if (NOT BUILD_MODE)
 endif()
 if (BUILD_MODE STREQUAL "local")
   # Relative path to gemma.cpp from examples/simplified_gemma/build/
-  FetchContent_Declare(gemma SOURCE_DIR ../../..)
+  FetchContent_Declare(gemma SOURCE_DIR ../../..)
 else()
   FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
 endif()
@@ -164,7 +164,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                            const LayerWeightsPtrs& layer,
                            AttentionActivations& activations, QBatch& qbatch,
                            NestedPools& pools) {
-  PROFILER_ZONE("Gen.Attention.DotSoftmax");
+  PROFILER_ZONE("Gen.Attention.DotSoftmax.misc");
+  static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
+      PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
+
   const hwy::Divisor div_qbatch(qbatch.Size());
   const LayerConfig& layer_config = layer.layer_config;
   const size_t qkv_dim = layer_config.qkv_dim;
@@ -186,9 +189,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   ParallelizeOneRange(
       tq_ranges, pools.AllPackages(),
       [&](const IndexRange& tq_range, const size_t pkg_idx) {
+        const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
        pools.AllClusters(pkg_idx).Run(
            tq_range.begin(), tq_range.end(),
            [&](const size_t tq_idx, const size_t cluster_idx) {
+              const HWY_MAYBE_UNUSED size_t cluster_base =
+                  pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
              const size_t qi = div_qbatch.Remainder(tq_idx);
              const size_t batch_idx = div_qbatch.Divide(tq_idx);
              auto& kv_cache = qbatch.KV(qi).kv_cache;
@@ -209,6 +215,11 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
             .Run(
                 0, layer_config.heads,
                 [&](const size_t head, size_t thread) HWY_ATTR {
+#if PROFILER_ENABLED
+                  const hwy::Zone zone(cluster_base + thread,
+                                       zone_id_par);
+#endif
+
                   const size_t head_offset =
                       (head / kHeadGroups) * qkv_dim * 2;
 
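Note on the pkg_base/cluster_base arithmetic above: it flattens the package/cluster/thread hierarchy into a single global worker index for the profiler. A self-contained sketch of the same indexing, with made-up capacities standing in for NestedPools' MaxWorkersPerPackage() and MaxWorkersPerCluster():

    #include <cstddef>

    constexpr size_t kMaxWorkersPerPackage = 64;  // hypothetical capacity
    constexpr size_t kMaxWorkersPerCluster = 8;   // hypothetical capacity

    // Mirrors the pkg_base/cluster_base computation in the hunks above:
    // the result is unique per worker across all packages and clusters.
    constexpr size_t GlobalWorker(size_t pkg_idx, size_t cluster_idx,
                                  size_t thread) {
      const size_t pkg_base = pkg_idx * kMaxWorkersPerPackage;
      const size_t cluster_base = pkg_base + cluster_idx * kMaxWorkersPerCluster;
      return cluster_base + thread;
    }

    static_assert(GlobalWorker(1, 2, 3) == 64 + 2 * 8 + 3, "flat worker index");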
@@ -385,9 +385,8 @@ static void DecompressToBF16(MatPtr& mat,
 
 static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
                           const BlobReader& reader, hwy::ThreadPool& pool) {
-  PROFILER_ZONE("Startup.Weights.ReadBF16");
-
-  pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
+  pool.Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
+    PROFILER_ZONE2(thread, "Startup.Weights.ReadBF16");
     const TensorToRead& tensor = tensors[task];
     MatPtr& mat = *tensor.mat;
 
@@ -465,9 +464,9 @@ static std::vector<IOBatch> MakeBatches(
 static void ReadBatches(const BlobReader& reader,
                         const std::vector<IOBatch>& batches,
                         hwy::ThreadPool& pool) {
-  PROFILER_ZONE("Startup.Weights.Read");
   // >5x speedup from parallel reads when cached.
-  pool.Run(0, batches.size(), [&](uint64_t i, size_t /*thread*/) {
+  pool.Run(0, batches.size(), [&](uint64_t i, size_t thread) {
+    PROFILER_ZONE2(thread, "Startup.Weights.Read");
     const IOBatch& batch = batches[i];
     const std::string& key = reader.Keys()[batch.KeyIdx()];
     const uint64_t bytes_read = batch.Read(reader.file());
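Note that the zone also moves from the enclosing function into the pool.Run closure: each worker now opens its own zone for the tasks it executes, instead of one zone on the calling thread spanning the whole fork/join. A sketch of that shape, assuming only the hwy::ThreadPool::Run signature already used above; ReadOneBatch is a hypothetical stand-in for the per-task work:

    #include <cstddef>
    #include <cstdint>

    #include "hwy/contrib/thread_pool/thread_pool.h"
    #include "hwy/profiler.h"

    void ReadOneBatch(uint64_t i) { (void)i; }  // hypothetical per-task work

    void ReadAll(hwy::ThreadPool& pool, size_t num_batches) {
      pool.Run(0, num_batches, [&](uint64_t i, size_t thread) {
        // One zone per task, attributed to the worker that ran it.
        PROFILER_ZONE2(thread, "Startup.Example.Read");
        ReadOneBatch(i);
      });
    }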
@@ -875,8 +875,9 @@ class MMPerPackage {
         inner_tasks_(config.InnerTasks()),
         out_(config.Out()),
         line_bytes_(args.env->ctx.allocator.LineBytes()) {
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA");
     MMZone zone;
-    zone.MaybeEnter("MM.DecompressA", args_);
+    zone.MaybeEnter(pkg_idx, zone_id, args_);
     A_ = DecompressA(A);
   }
 
@@ -914,8 +915,7 @@ class MMPerPackage {
   // Single M and K ranges, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
   HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    MMZone zone;
-    zone.MaybeEnter("MM.NT", args_);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_M = ranges_mc_.Range(0);
@@ -928,7 +928,10 @@ class MMPerPackage {
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
     args_.env->parallel.ForNP(
         range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
-        [&](const IndexRange& range_nc) HWY_ATTR {
+        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+          MMZone zone;
+          zone.MaybeEnter(worker, zone_id, args_);
+
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           const StridedViewBF B_storage_view(B_storage, K, B_stride);
 
@@ -947,8 +950,7 @@ class MMPerPackage {
   // Single M range, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
   HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    MMZone zone;
-    zone.MaybeEnter("MM.NT_K", args_);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     const IndexRange& range_mc = ranges_mc_.Range(0);
 
@@ -975,7 +977,10 @@ class MMPerPackage {
 
     args_.env->parallel.ForNP(
         range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
-        [&](const IndexRange& range_nc) HWY_ATTR {
+        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+          MMZone zone;
+          zone.MaybeEnter(worker, zone_id, args_);
+
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
 
           // Peel off the first iteration of the kc loop: avoid
@@ -988,14 +993,17 @@ class MMPerPackage {
               });
         });
 
-    MMZone fill_zone;
     if (out_ == MMOut::kCopy) {
-      fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
+      static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.Copy");
+      MMZone fill_zone;
+      fill_zone.MaybeEnter(0, zone_id, args_);
       MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
     } else if (out_ == MMOut::kParM) {
-      fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
+      static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.ParM");
       args_.env->parallel.ForRangeMC(
-          range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
+          range_mc, pkg_idx_, [&](size_t row_a, size_t worker) HWY_ATTR {
+            MMZone fill_zone;
+            fill_zone.MaybeEnter(worker, zone_id, args_);
             MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
                                     args_, C_rows);
           });
@@ -1008,8 +1016,7 @@ class MMPerPackage {
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
   HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    MMZone zone;
-    zone.MaybeEnter("MM.NT_MT", args_);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT");
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_K = ranges_kc_.Range(0);
     const size_t K = range_K.Num();
@@ -1020,7 +1027,11 @@ class MMPerPackage {
     // except for the profiler strings and `out_tag`.
     args_.env->parallel.ForRangesMC_NC(
         ranges_mc_, ranges_nc_, pkg_idx_,
-        [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
+        [&](const IndexRange& range_mc, const IndexRange& range_nc,
+            size_t worker) HWY_ATTR {
+          MMZone zone;
+          zone.MaybeEnter(worker, zone_id, args_);
+
           const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           const StridedViewBF B_storage_view(B_storage, K, B_stride);
@@ -1041,8 +1052,8 @@ class MMPerPackage {
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
   HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    MMZone zone;
-    zone.MaybeEnter("MM.NT_MT_K", args_);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K");
+    static const uint32_t fill_zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K.FillC");
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
     const size_t B_stride =
@@ -1068,7 +1079,11 @@ class MMPerPackage {
     };  // loop_nc
     args_.env->parallel.ForRangesMC_NC(
         ranges_mc_, ranges_nc_, pkg_idx_,
-        [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
+        [&](const IndexRange& range_mc, const IndexRange& range_nc,
+            size_t worker) HWY_ATTR {
+          MMZone zone;
+          zone.MaybeEnter(worker, zone_id, args_);
+
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
 
@@ -1087,7 +1102,7 @@ class MMPerPackage {
           // `kDirect` is only used with `kNT_MT`.
           HWY_DASSERT(out_ == MMOut::kCopy);
           MMZone fill_zone;
-          fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
+          fill_zone.MaybeEnter(worker, fill_zone_id, args_);
           MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
         });
   }
@@ -1139,13 +1154,16 @@ class MMPerPackage {
 
       args_.env->parallel.ForNP(
           all_K, multiple_K, inner_tasks, pkg_idx_,
-          [&](const IndexRange& range_K) { do_range(all_M, range_K); });
+          [&](const IndexRange& range_K, size_t /*worker*/) {
+            do_range(all_M, range_K);
+          });
       break;
     }
     case MMParA::kM:
-      args_.env->parallel.ForRangeMC(all_M, pkg_idx_, [&](size_t row_a) {
-        do_range(IndexRange(row_a, row_a + 1), all_K);
-      });
+      args_.env->parallel.ForRangeMC(
+          all_M, pkg_idx_, [&](size_t row_a, size_t /*worker*/) {
+            do_range(IndexRange(row_a, row_a + 1), all_K);
+          });
       break;
   }
 }
@@ -1261,12 +1279,13 @@ struct MMImpl {
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     RowPtrs<TC> C_rows, const MMArgs& args,
                                     const MMConfig& config) {
-    MMZone matmul_zone;
-    matmul_zone.MaybeEnter("MM.DoMatMul", args);
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul");
 
     // Outermost loop: static NUMA-aware partition of B rows across packages.
     args.env->parallel.ForPkg(
         args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
+          MMZone matmul_zone;
+          matmul_zone.MaybeEnter(pkg_idx, zone_id, args);
           const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
           MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
         });

ops/matmul.h
@@ -88,11 +88,13 @@ class MMParallel {
   }
 
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
-  // the granularity of per-cluster tasks. Calls `func(worker_range)`.
+  // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
   template <class Func>
   void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks,
              size_t pkg_idx, const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
+
     // Single cluster: parallel-for over static partition of `range_np`.
     hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
     const size_t num_clusters = all_clusters.NumWorkers();
@@ -102,8 +104,8 @@ class MMParallel {
           range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
       return ParallelizeOneRange(
           worker_ranges, cluster,
-          [&](const IndexRange& worker_range, size_t /*thread*/) {
-            func(worker_range);
+          [&](const IndexRange& worker_range, size_t thread) {
+            func(worker_range, pkg_base + thread);
           });
     }
 
@@ -114,21 +116,26 @@ class MMParallel {
         nx_ranges, all_clusters,
         [&](const IndexRange& nx_range, const size_t cluster_idx) {
           hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
+          const size_t cluster_base =
+              pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
           // Parallel-for over sub-ranges of `cluster_range` within the cluster.
           const IndexRangePartition worker_ranges = StaticPartition(
               nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
-          ParallelizeOneRange(worker_ranges, cluster,
-                              [&](const IndexRange& worker_range,
-                                  size_t /*thread*/) { func(worker_range); });
+          ParallelizeOneRange(
+              worker_ranges, cluster,
+              [&](const IndexRange& worker_range, size_t thread) {
+                func(worker_range, cluster_base + thread);
+              });
         });
   }
 
   // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
-  // rows). Calls `func(range_mc, range_nc)`.
+  // rows). Calls `func(range_mc, range_nc, worker)`.
   template <class Func>
   void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
                       const IndexRangePartition& ranges_nc, size_t pkg_idx,
                       const Func& func) {
+    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
     hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
     // `all_clusters` is a pool with one worker per cluster in a package.
     const size_t num_clusters = all_clusters.NumWorkers();
@@ -140,15 +147,16 @@ class MMParallel {
       // Low-batch: avoid Divide/Remainder.
       if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
         return ParallelizeOneRange(
-            ranges_nc, cluster,
-            [&](const IndexRange& range_nc, size_t /*thread*/) {
-              func(ranges_mc.Range(0), range_nc);
+            ranges_nc, cluster, [&](const IndexRange& range_nc, size_t thread) {
+              func(ranges_mc.Range(0), range_nc, pkg_base + thread);
             });
       } else {
         return ParallelizeTwoRanges(
             ranges_mc, ranges_nc, cluster,
             [&](const IndexRange& range_mc, const IndexRange& range_nc,
-                size_t /*thread*/) { func(range_mc, range_nc); });
+                size_t thread) {
+              func(range_mc, range_nc, pkg_base + thread);
+            });
       }
     }
 
@@ -157,22 +165,24 @@ class MMParallel {
     ParallelizeOneRange(
         ranges_nc, all_clusters,
         [&](const IndexRange range_nc, size_t cluster_idx) {
+          const size_t cluster_base =
+              pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
           hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
-          ParallelizeOneRange(
-              ranges_mc, cluster,
-              [&](const IndexRange& range_mc, size_t /*thread*/) {
-                func(range_mc, range_nc);
-              });
+          ParallelizeOneRange(ranges_mc, cluster,
+                              [&](const IndexRange& range_mc, size_t thread) {
+                                func(range_mc, range_nc, cluster_base + thread);
+                              });
         });
   }
 
-  // Calls `func(row_a)` in parallel.
+  // Calls `func(row_a, worker)` in parallel.
   template <class Func>
   void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
                   const Func& func) {
+    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
     ctx_.pools.Pool(pkg_idx).Run(
         range_mc.begin(), range_mc.end(),
-        [&](uint64_t row_a, size_t /*thread*/) { func(row_a); });
+        [&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
   }
 
  private:
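Every MMParallel entry point now follows the same contract: the user functor receives a global worker index equal to pkg_base (or cluster_base) plus the pool's local thread index. A sequential stand-in for ForRangeMC that shows just that contract; max_workers_per_package mirrors ctx_.pools.MaxWorkersPerPackage() and is a parameter here only for illustration:

    #include <cstddef>

    // Sequential sketch of the new ForRangeMC contract: the functor gets
    // (row, worker), where worker = pkg_base + local pool thread index.
    template <class Func>
    void ForRangeMCSketch(size_t begin, size_t end, size_t pkg_idx,
                          size_t max_workers_per_package, const Func& func) {
      const size_t pkg_base = pkg_idx * max_workers_per_package;
      for (size_t row = begin; row < end; ++row) {
        const size_t local_thread = 0;  // single-threaded stand-in
        func(row, pkg_base + local_thread);
      }
    }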
@@ -714,9 +724,9 @@ class MMZone {
   }
 
   // `name` must be a string literal.
-  void MaybeEnter(const char* name, const MMArgs& args) {
+  void MaybeEnter(size_t thread_id, uint32_t zone_id, const MMArgs& args) {
     if (args.per_key->WantProfile()) {
-      new (&data_) Zone(name);
+      new (&data_) Zone(thread_id, zone_id);
       used_ = true;
     }
   }
@@ -727,7 +737,7 @@ class MMZone {
 };
 #else
 struct MMZone {
-  void MaybeEnter(const char*, const MMArgs&) {}
+  void MaybeEnter(size_t, uint32_t, const MMArgs&) {}
 };
 #endif  // PROFILER_ENABLED
 
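MMZone, per the hunks above, constructs the underlying hwy::Zone only when profiling is requested for this particular matmul, via placement-new into inline storage. A sketch of that pattern detached from the MMArgs plumbing; want_profile stands in for args.per_key->WantProfile(), and the destructor is an assumption about what MMZone does when `used_` is set:

    #include <cstddef>
    #include <cstdint>
    #include <new>

    #include "hwy/profiler.h"

    #if PROFILER_ENABLED
    class LazyZone {
     public:
      void MaybeEnter(size_t thread, uint32_t zone_id, bool want_profile) {
        if (want_profile) {
          new (&data_) hwy::Zone(thread, zone_id);  // timing starts here
          used_ = true;
        }
      }
      ~LazyZone() {  // assumed: leaving scope ends the zone if entered
        if (used_) reinterpret_cast<hwy::Zone*>(&data_)->~Zone();
      }

     private:
      alignas(hwy::Zone) unsigned char data_[sizeof(hwy::Zone)];
      bool used_ = false;
    };
    #endif  // PROFILER_ENABLED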