#include "util/zones.h"

#include <stddef.h>

#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"

namespace gcpp {
namespace {

// Returns the string literal under which `zone` is registered with the
// profiler. Aborts via HWY_ABORT for any value without an explicit case
// (this includes Zones::kNumZones).
const char* ZoneName(Zones zone) {
  switch (zone) {
    case Zones::kFlashAttentionFlashAttention:
      return "FlashAttention.FlashAttention";
    case Zones::kFlashAttentionInclusive:
      return "FlashAttention.Inclusive";
    case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
      return "FlashAttention.RMSNormAndPositionalEncoding";
    case Zones::kFlashAttentionSingleFlashAttention:
      return "FlashAttention.SingleFlashAttention";
    case Zones::kFlashAttentionTileFlashAttention:
      return "FlashAttention.TileFlashAttention";
    case Zones::kFlashAttentionTileFlashAttention4:
      return "FlashAttention.TileFlashAttention4";
    case Zones::kFlashAttentionTransposeQ:
      return "FlashAttention.TransposeQ";
    case Zones::kGenActivation:
      return "Gen.Activation";
    case Zones::kGenActivationFused:
      return "Gen.ActivationFused";
    case Zones::kGenAttention:
      return "Gen.Attention";
    case Zones::kGenAttentionComputeQKV:
      return "Gen.Attention.ComputeQKV";
    case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:
      return "Gen.Attention.DotSoftmaxWeightedSumInclusive";
    case Zones::kGenAttentionDotSoftmaxWeightedSumPar:
      return "Gen.Attention.DotSoftmaxWeightedSum.par";
    case Zones::kGenAttentionQDotK:
      return "Gen.Attention.QDotK";
    case Zones::kGenAttentionSumHeads:
      return "Gen.Attention.SumHeads";
    case Zones::kGenEmbed:
      return "Gen.Embed";
    case Zones::kGenEmbeddingMatmul:
      return "Gen.EmbeddingMatmul";
    case Zones::kGenFFW:
      return "Gen.FFW";
    case Zones::kGenSampleTop1:
      return "Gen.SampleTop1";
    case Zones::kGenSampleTopK:
      return "Gen.SampleTopK";
    case Zones::kGenStats:
      return "Gen.Stats";
    case Zones::kMMDecompressA:
      return "MM.DecompressA";
    case Zones::kMMDispatch:
      return "MM.Dispatch";
    case Zones::kMMMatMul:
      return "MM.MatMul";
    case Zones::kMMNT_K:
      return "MM.NT_K";
    case Zones::kMMNT_MT_K:
      return "MM.NT_MT_K";
    case Zones::kMMNT_MT:
      return "MM.NT_MT";
    case Zones::kMMNT:
      return "MM.NT";
    case Zones::kMMTwoMatMul:
      return "MM.TwoMatMul";
    case Zones::kOpsAddFrom:
      return "Ops.AddFrom";
    case Zones::kOpsLogitsSoftCap:
      return "Ops.LogitsSoftCap";
    // case Zones::kOpsMulByConst:  // removed due to overhead
    // case Zones::kOpsMulByConstAndAdd:  // removed due to overhead
    // case Zones::kOpsMulByConstAndAddTile:  // removed due to overhead
    // case Zones::kOpsMulByConstAndAddTile4:  // removed due to overhead
    // case Zones::kOpsMulByConstAndAddVector:  // removed due to overhead
    case Zones::kOpsMulByConstTo:
      return "Ops.MulByConstTo";
    case Zones::kOpsRmsNorm:
      return "Ops.RMSNorm";
    case Zones::kOpsRmsNormInplace:
      return "Ops.RMSNormInplace";
    case Zones::kOpsRmsNormMul:
      return "Ops.RMSNormMul";
    case Zones::kOpsRope:
      return "Ops.Rope";
    case Zones::kOpsRopeAndMulBy:
      return "Ops.RopeAndMulBy";
    case Zones::kOpsSoftmax:
      return "Ops.Softmax";
    case Zones::kStartupWeightsReadAllToBF16:
      return "Startup.Weights.ReadAllToBF16";
    case Zones::kStartupWeightsReadBatches:
      return "Startup.Weights.ReadBatches";
    default:
      HWY_ABORT("Invalid zone %d.", static_cast<int>(zone));
  }
}

// Returns the profiler flags to register `zone` with: kInclusive for the
// zones listed below, kDefault for every other value.
hwy::ProfilerFlags ZoneFlags(Zones zone) {
  switch (zone) {
    case Zones::kFlashAttentionInclusive:
    case Zones::kGenAttention:
    case Zones::kGenAttentionComputeQKV:
    case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:
    case Zones::kGenAttentionSumHeads:
    case Zones::kGenEmbed:
    case Zones::kGenEmbeddingMatmul:
    case Zones::kGenFFW:
      return hwy::ProfilerFlags::kInclusive;
    default:
      return hwy::ProfilerFlags::kDefault;
  }
}

// Returns the string literal under which `caller` is registered with the
// thread pool. Aborts via HWY_ABORT for any value without an explicit case
// (this includes Callers::kNumCallers).
const char* CallerName(Callers caller) {
  switch (caller) {
    case Callers::kActivationBatched:
      return "ActivationBatched";
    case Callers::kAllocateAndBindAll:
      return "AllocateAndBindAll";
    case Callers::kAttComputeQKV:
      return "Att.ComputeQKV";
    case Callers::kAttDotSoftmaxWeightedSum:
      return "Att.DotSoftmaxWeightedSum";
    case Callers::kBlobWriter:
      return "BlobWriter";
    case Callers::kCompress:
      return "Compress";
    case Callers::kFixupWeights:
      return "FixupWeights";
    case Callers::kFlashAttention:
      return "FlashAttention";
    case Callers::kFlashRMSNormAndPositionalEncoding:
      return "Flash.RMSNormAndPositionalEncoding";
    case Callers::kFlashTransposeQ:
      return "Flash.TransposeQ";
    case Callers::kMMClusterForMC:
      return "MM.ClusterForMC";
    case Callers::kMMClusterForMCNC:
      return "MM.ClusterForMCNC";
    case Callers::kMMClusterForN:
      return "MM.ClusterForN";
    case Callers::kMMHierForMC:
      return "MM.HierForMC";
    case Callers::kMMHierForMCNC:
      return "MM.HierForMCNC";
    case Callers::kMMHierForN:
      return "MM.HierForN";
    case Callers::kOpsAddFromBatched:
      return "Ops.AddFromBatched";
    case Callers::kOpsMaybeLogitsSoftCapBatched:
      return "Ops.MaybeLogitsSoftCapBatched";
    case Callers::kOpsRMSNormBatched:
      return "Ops.RMSNormBatched";
    case Callers::kOpsRMSNormInplaceBatched:
      return "Ops.RMSNormInplaceBatched";
    case Callers::kReadAllToBF16:
      return "ReadAllToBF16";
    case Callers::kReadBatches:
      return "ReadBatches";
    case Callers::kSampleAndStream:
      return "SampleAndStream";
    case Callers::kTensorStats:
      return "TensorStats";
    case Callers::kTest:  // only for unit tests.
      return "Test-only!";
    case Callers::kTunePool:
      return "TunePool";
    case Callers::kVitDotSoftmax1:
      return "Vit.DotSoftmax1";
    case Callers::kVitDotSoftmax2:
      return "Vit.DotSoftmax2";
    case Callers::kVitDotSoftmax3:
      return "Vit.DotSoftmax3";
    case Callers::kVitDotSoftmax4:
      return "Vit.DotSoftmax4";
    default:
      HWY_ABORT("Invalid caller %d.", static_cast<int>(caller));
  }
}

}  // namespace

// Registers every zone in [0, Zones::kNumZones) with `profiler` and stores the
// value returned by AddZone in `handles_`, indexed by the zone's integer
// value.
ProfilerZones::ProfilerZones(hwy::Profiler& profiler) {
  for (size_t i = 0;; ++i) {
    const Zones zone = static_cast<Zones>(i);
    // kNumZones is the end sentinel; it must not be passed to ZoneName, which
    // would abort on it.
    if (zone == Zones::kNumZones) break;
    handles_[i] = profiler.AddZone(ZoneName(zone), ZoneFlags(zone));
  }
}

// Registers every caller in [0, Callers::kNumCallers) via
// hwy::ThreadPool::AddCaller and stores the returned value in `callers_`,
// indexed by the caller's integer value.
PoolCallers::PoolCallers() {
  for (size_t i = 0;; ++i) {
    const Callers caller = static_cast<Callers>(i);
    // kNumCallers is the end sentinel; it must not be passed to CallerName,
    // which would abort on it.
    if (caller == Callers::kNumCallers) break;
    callers_[i] = hwy::ThreadPool::AddCaller(CallerName(caller));
  }
}

}  // namespace gcpp