Prepare profiler annotations for new API

PiperOrigin-RevId: 792808391
This commit is contained in:
Jan Wassenberg 2025-08-08 16:50:54 -07:00 committed by Copybara-Service
parent 2e9c93a609
commit eef564e8f0
1 changed file with 25 additions and 25 deletions

View File

@ -912,7 +912,7 @@ class MMPerPackage {
// Single M and K ranges, parallel N. Fills all of C directly.
template <typename TB, typename TC>
HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT");
static const auto zone = PROFILER_ADD_ZONE("MM.NT");
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
const IndexRange& range_M = ranges_mc_.Range(0);
@ -926,8 +926,8 @@ class MMPerPackage {
args_.env->parallel.ForNP(
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMZone zone;
zone.MaybeEnter(worker, zone_id, args_);
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_);
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const StridedViewBF B_storage_view(B_storage, K, B_stride);
@ -947,7 +947,7 @@ class MMPerPackage {
// Single M range, parallel N, sequential K. Fills all of partial.
template <typename TB, typename TC>
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K");
static const auto zone = PROFILER_ADD_ZONE("MM.NT_K");
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
const IndexRange& range_mc = ranges_mc_.Range(0);
@ -975,8 +975,8 @@ class MMPerPackage {
args_.env->parallel.ForNP(
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMZone zone;
zone.MaybeEnter(worker, zone_id, args_);
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_);
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
@ -991,16 +991,16 @@ class MMPerPackage {
});
if (out_ == MMOut::kCopy) {
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.Copy");
static const auto zone = PROFILER_ADD_ZONE("MM.NT_K.FillC.Copy");
MMZone fill_zone;
fill_zone.MaybeEnter(0, zone_id, args_);
fill_zone.MaybeEnter(0, zone, args_);
MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
} else if (out_ == MMOut::kParM) {
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.ParM");
static const auto zone = PROFILER_ADD_ZONE("MM.NT_K.FillC.ParM");
args_.env->parallel.ForRangeMC(
range_mc, pkg_idx_, [&](size_t row_a, size_t worker) HWY_ATTR {
MMZone fill_zone;
fill_zone.MaybeEnter(worker, zone_id, args_);
fill_zone.MaybeEnter(worker, zone, args_);
MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
args_, C_rows);
});
@ -1013,7 +1013,7 @@ class MMPerPackage {
// Fills `mc x nc` sections of C directly, in parallel.
template <typename TB, typename TC>
HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT");
static const auto zone = PROFILER_ADD_ZONE("MM.NT_MT");
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
const IndexRange& range_K = ranges_kc_.Range(0);
const size_t K = range_K.Num();
@ -1026,8 +1026,8 @@ class MMPerPackage {
ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) HWY_ATTR {
MMZone zone;
zone.MaybeEnter(worker, zone_id, args_);
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_);
const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
@ -1049,8 +1049,8 @@ class MMPerPackage {
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
template <typename TB, typename TC>
HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K");
static const uint32_t fill_zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K.FillC");
static const auto zone = PROFILER_ADD_ZONE("MM.NT_MT_K");
static const auto fill_zone = PROFILER_ADD_ZONE("MM.NT_MT_K.FillC");
const size_t kc_max = ranges_kc_.TaskSize();
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
const size_t B_stride =
@ -1078,8 +1078,8 @@ class MMPerPackage {
ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) HWY_ATTR {
MMZone zone;
zone.MaybeEnter(worker, zone_id, args_);
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_);
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
@ -1098,8 +1098,8 @@ class MMPerPackage {
// Already in parallel section, hence no `kParM`, and
// `kDirect` is only used with `kNT_MT`.
HWY_DASSERT(out_ == MMOut::kCopy);
MMZone fill_zone;
fill_zone.MaybeEnter(worker, fill_zone_id, args_);
MMZone fill_mm_zone;
fill_mm_zone.MaybeEnter(worker, fill_zone, args_);
MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
});
}
@ -1116,13 +1116,13 @@ class MMPerPackage {
const size_t NBF = hn::Lanes(dbf);
static_assert(hwy::IsSameEither<TA, BF16, float>(), "Can seek");
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA");
static const auto zone = PROFILER_ADD_ZONE("MM.DecompressA");
const auto do_range = [&](const IndexRange& range_M,
const IndexRange& range_K,
size_t worker) HWY_ATTR {
MMZone zone;
zone.MaybeEnter(worker, zone_id, args_);
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_);
const size_t col0 = range_K.begin();
const size_t cols = range_K.Num();
@ -1280,14 +1280,14 @@ struct MMImpl {
RowPtrs<TC> C_rows, const MMArgs& args,
const MMConfig& config) {
PROFILER_ZONE("MM.DoMatMul");
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul.PerPkg");
static const auto zone = PROFILER_ADD_ZONE("MM.DoMatMul.PerPkg");
if constexpr (kMaxPackages > 1) {
// Outermost loop: static NUMA-aware partition of B rows across packages.
args.env->parallel.ForPkg(
args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
MMZone matmul_zone;
matmul_zone.MaybeEnter(pkg_idx, zone_id, args);
MMZone mm_zone;
mm_zone.MaybeEnter(pkg_idx, zone, args);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
});