mirror of https://github.com/google/gemma.cpp.git
Prepare profiler annotations for new API
PiperOrigin-RevId: 792808391
This commit is contained in:
parent
2e9c93a609
commit
eef564e8f0
|
|
@@ -912,7 +912,7 @@ class MMPerPackage {
|
||||||
// Single M and K ranges, parallel N. Fills all of C directly.
|
// Single M and K ranges, parallel N. Fills all of C directly.
|
||||||
template <typename TB, typename TC>
|
template <typename TB, typename TC>
|
||||||
HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||||
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT");
|
static const auto zone = PROFILER_ADD_ZONE("MM.NT");
|
||||||
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
||||||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
||||||
const IndexRange& range_M = ranges_mc_.Range(0);
|
const IndexRange& range_M = ranges_mc_.Range(0);
|
||||||
|
|
@@ -926,8 +926,8 @@ class MMPerPackage {
|
||||||
args_.env->parallel.ForNP(
|
args_.env->parallel.ForNP(
|
||||||
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
||||||
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||||
MMZone zone;
|
MMZone mm_zone;
|
||||||
zone.MaybeEnter(worker, zone_id, args_);
|
mm_zone.MaybeEnter(worker, zone, args_);
|
||||||
|
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||||
const StridedViewBF B_storage_view(B_storage, K, B_stride);
|
const StridedViewBF B_storage_view(B_storage, K, B_stride);
|
||||||
|
|
@@ -947,7 +947,7 @@ class MMPerPackage {
|
||||||
// Single M range, parallel N, sequential K. Fills all of partial.
|
// Single M range, parallel N, sequential K. Fills all of partial.
|
||||||
template <typename TB, typename TC>
|
template <typename TB, typename TC>
|
||||||
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||||
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K");
|
static const auto zone = PROFILER_ADD_ZONE("MM.NT_K");
|
||||||
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
||||||
const IndexRange& range_mc = ranges_mc_.Range(0);
|
const IndexRange& range_mc = ranges_mc_.Range(0);
|
||||||
|
|
||||||
|
|
@@ -975,8 +975,8 @@ class MMPerPackage {
|
||||||
args_.env->parallel.ForNP(
|
args_.env->parallel.ForNP(
|
||||||
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
||||||
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||||
MMZone zone;
|
MMZone mm_zone;
|
||||||
zone.MaybeEnter(worker, zone_id, args_);
|
mm_zone.MaybeEnter(worker, zone, args_);
|
||||||
|
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||||
|
|
||||||
|
|
@@ -991,16 +991,16 @@ class MMPerPackage {
|
||||||
});
|
});
|
||||||
|
|
||||||
if (out_ == MMOut::kCopy) {
|
if (out_ == MMOut::kCopy) {
|
||||||
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.Copy");
|
static const auto zone = PROFILER_ADD_ZONE("MM.NT_K.FillC.Copy");
|
||||||
MMZone fill_zone;
|
MMZone fill_zone;
|
||||||
fill_zone.MaybeEnter(0, zone_id, args_);
|
fill_zone.MaybeEnter(0, zone, args_);
|
||||||
MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
|
MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
|
||||||
} else if (out_ == MMOut::kParM) {
|
} else if (out_ == MMOut::kParM) {
|
||||||
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.ParM");
|
static const auto zone = PROFILER_ADD_ZONE("MM.NT_K.FillC.ParM");
|
||||||
args_.env->parallel.ForRangeMC(
|
args_.env->parallel.ForRangeMC(
|
||||||
range_mc, pkg_idx_, [&](size_t row_a, size_t worker) HWY_ATTR {
|
range_mc, pkg_idx_, [&](size_t row_a, size_t worker) HWY_ATTR {
|
||||||
MMZone fill_zone;
|
MMZone fill_zone;
|
||||||
fill_zone.MaybeEnter(worker, zone_id, args_);
|
fill_zone.MaybeEnter(worker, zone, args_);
|
||||||
MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
|
MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
|
||||||
args_, C_rows);
|
args_, C_rows);
|
||||||
});
|
});
|
||||||
|
|
@@ -1013,7 +1013,7 @@ class MMPerPackage {
|
||||||
// Fills `mc x nc` sections of C directly, in parallel.
|
// Fills `mc x nc` sections of C directly, in parallel.
|
||||||
template <typename TB, typename TC>
|
template <typename TB, typename TC>
|
||||||
HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||||
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT");
|
static const auto zone = PROFILER_ADD_ZONE("MM.NT_MT");
|
||||||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
||||||
const IndexRange& range_K = ranges_kc_.Range(0);
|
const IndexRange& range_K = ranges_kc_.Range(0);
|
||||||
const size_t K = range_K.Num();
|
const size_t K = range_K.Num();
|
||||||
|
|
@@ -1026,8 +1026,8 @@ class MMPerPackage {
|
||||||
ranges_mc_, ranges_nc_, pkg_idx_,
|
ranges_mc_, ranges_nc_, pkg_idx_,
|
||||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||||
size_t worker) HWY_ATTR {
|
size_t worker) HWY_ATTR {
|
||||||
MMZone zone;
|
MMZone mm_zone;
|
||||||
zone.MaybeEnter(worker, zone_id, args_);
|
mm_zone.MaybeEnter(worker, zone, args_);
|
||||||
|
|
||||||
const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
|
const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||||
|
|
@@ -1049,8 +1049,8 @@ class MMPerPackage {
|
||||||
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
|
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
|
||||||
template <typename TB, typename TC>
|
template <typename TB, typename TC>
|
||||||
HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||||
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K");
|
static const auto zone = PROFILER_ADD_ZONE("MM.NT_MT_K");
|
||||||
static const uint32_t fill_zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K.FillC");
|
static const auto fill_zone = PROFILER_ADD_ZONE("MM.NT_MT_K.FillC");
|
||||||
const size_t kc_max = ranges_kc_.TaskSize();
|
const size_t kc_max = ranges_kc_.TaskSize();
|
||||||
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
|
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
|
||||||
const size_t B_stride =
|
const size_t B_stride =
|
||||||
|
|
@@ -1078,8 +1078,8 @@ class MMPerPackage {
|
||||||
ranges_mc_, ranges_nc_, pkg_idx_,
|
ranges_mc_, ranges_nc_, pkg_idx_,
|
||||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||||
size_t worker) HWY_ATTR {
|
size_t worker) HWY_ATTR {
|
||||||
MMZone zone;
|
MMZone mm_zone;
|
||||||
zone.MaybeEnter(worker, zone_id, args_);
|
mm_zone.MaybeEnter(worker, zone, args_);
|
||||||
|
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||||
const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
|
const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
|
||||||
|
|
@@ -1098,8 +1098,8 @@ class MMPerPackage {
|
||||||
// Already in parallel section, hence no `kParM`, and
|
// Already in parallel section, hence no `kParM`, and
|
||||||
// `kDirect` is only used with `kNT_MT`.
|
// `kDirect` is only used with `kNT_MT`.
|
||||||
HWY_DASSERT(out_ == MMOut::kCopy);
|
HWY_DASSERT(out_ == MMOut::kCopy);
|
||||||
MMZone fill_zone;
|
MMZone fill_mm_zone;
|
||||||
fill_zone.MaybeEnter(worker, fill_zone_id, args_);
|
fill_mm_zone.MaybeEnter(worker, fill_zone, args_);
|
||||||
MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
|
MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
@@ -1116,13 +1116,13 @@ class MMPerPackage {
|
||||||
const size_t NBF = hn::Lanes(dbf);
|
const size_t NBF = hn::Lanes(dbf);
|
||||||
static_assert(hwy::IsSameEither<TA, BF16, float>(), "Can seek");
|
static_assert(hwy::IsSameEither<TA, BF16, float>(), "Can seek");
|
||||||
|
|
||||||
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA");
|
static const auto zone = PROFILER_ADD_ZONE("MM.DecompressA");
|
||||||
|
|
||||||
const auto do_range = [&](const IndexRange& range_M,
|
const auto do_range = [&](const IndexRange& range_M,
|
||||||
const IndexRange& range_K,
|
const IndexRange& range_K,
|
||||||
size_t worker) HWY_ATTR {
|
size_t worker) HWY_ATTR {
|
||||||
MMZone zone;
|
MMZone mm_zone;
|
||||||
zone.MaybeEnter(worker, zone_id, args_);
|
mm_zone.MaybeEnter(worker, zone, args_);
|
||||||
|
|
||||||
const size_t col0 = range_K.begin();
|
const size_t col0 = range_K.begin();
|
||||||
const size_t cols = range_K.Num();
|
const size_t cols = range_K.Num();
|
||||||
|
|
@@ -1280,14 +1280,14 @@ struct MMImpl {
|
||||||
RowPtrs<TC> C_rows, const MMArgs& args,
|
RowPtrs<TC> C_rows, const MMArgs& args,
|
||||||
const MMConfig& config) {
|
const MMConfig& config) {
|
||||||
PROFILER_ZONE("MM.DoMatMul");
|
PROFILER_ZONE("MM.DoMatMul");
|
||||||
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul.PerPkg");
|
static const auto zone = PROFILER_ADD_ZONE("MM.DoMatMul.PerPkg");
|
||||||
|
|
||||||
if constexpr (kMaxPackages > 1) {
|
if constexpr (kMaxPackages > 1) {
|
||||||
// Outermost loop: static NUMA-aware partition of B rows across packages.
|
// Outermost loop: static NUMA-aware partition of B rows across packages.
|
||||||
args.env->parallel.ForPkg(
|
args.env->parallel.ForPkg(
|
||||||
args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
|
args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
|
||||||
MMZone matmul_zone;
|
MMZone mm_zone;
|
||||||
matmul_zone.MaybeEnter(pkg_idx, zone_id, args);
|
mm_zone.MaybeEnter(pkg_idx, zone, args);
|
||||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||||
MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
|
MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
|
||||||
});
|
});
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue