diff --git a/compression/types.h b/compression/types.h index 667265a..661bc42 100644 --- a/compression/types.h +++ b/compression/types.h @@ -191,12 +191,13 @@ constexpr bool SupportsPointerArithmetic() { return !IsNuqStream(); } -// Tensor types for loading weights. -enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 }; +// Tensor types for loading weights. Not all of these are supported weight +// types, some are only used for `Activations`. +enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64 }; // These are used in `ModelConfig.Specifier`, hence the strings will not // change, though new ones may be added. -static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", - "sfp", "nuq", "f64"}; +static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", + "nuq", "f64", "u32", "u64"}; static constexpr size_t kNumTypes = sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); static constexpr size_t kTypeBits[] = { @@ -206,6 +207,8 @@ static constexpr size_t kTypeBits[] = { 8 * sizeof(SfpStream), 4 /* NuqStream, actually 4.5 */, 8 * sizeof(double), + 8 * sizeof(uint32_t), + 8 * sizeof(uint64_t), }; static inline bool EnumValid(Type type) { @@ -226,6 +229,10 @@ Type TypeEnum() { return Type::kNUQ; } else if constexpr (hwy::IsSame()) { return Type::kF64; + } else if constexpr (hwy::IsSame()) { + return Type::kU32; + } else if constexpr (hwy::IsSame()) { + return Type::kU64; } else { HWY_DASSERT(false); return Type::kUnknown; diff --git a/gemma/activations.h b/gemma/activations.h index 14994d3..cd714ae 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -21,14 +21,14 @@ #include #include +#include #include -#include "gemma/configs.h" // ModelConfig -#include "ops/matmul.h" // MatMulEnv -#include "ops/ops.h" // CreateInvTimescale -#include "util/allocator.h" // Allocator -#include "util/basics.h" // BF16 -#include "util/mat.h" // MatStorageT +#include "gemma/configs.h" // ModelConfig +#include "ops/ops.h" // CreateInvTimescale +#include "util/basics.h" // BF16 +#include "util/mat.h" // MatStorageT +#include "util/threading_context.h" namespace gcpp { @@ -150,24 +150,28 @@ struct AttentionActivations { struct Activations { Activations(const ModelConfig& config, size_t batch_size, size_t seq_len, - const Allocator& allocator, + ThreadingContext& ctx, std::vector>& row_ptrs) : layer_config(config.layer_configs[0]), - x(MatFactory("x", batch_size, config.model_dim, allocator)), - x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)), - logits(MatFactory("logits", batch_size, config.vocab_size, allocator)), + x(MatFactory("x", batch_size, config.model_dim, ctx.allocator)), + x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)), + logits( + MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)), pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size, - config.model_dim, allocator)), - C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)), - C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)), - ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)), + config.model_dim, ctx.allocator)), + C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, + ctx.allocator)), + C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, + ctx.allocator)), + ffw_out( + MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)), - attention(config, layer_config, batch_size, seq_len, allocator, + attention(config, layer_config, batch_size, seq_len, 
ctx.allocator, row_ptrs), griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0, - allocator) { + ctx.allocator) { HWY_ASSERT(batch_size != 0); // For MatMul outputs, precompute their row pointers. diff --git a/gemma/attention.cc b/gemma/attention.cc index bd76329..21e5019 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -19,7 +19,6 @@ #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS -#include "util/threading_context.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -29,7 +28,7 @@ #include "gemma/gemma.h" #include "gemma/weights.h" #include "util/threading.h" -#include "hwy/contrib/thread_pool/thread_pool.h" +#include "util/threading_context.h" #include "hwy/profiler.h" // Compiles this file for multiple architectures via "foreach_target.h", to @@ -234,8 +233,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, { PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); // Full parallelism is helpful, kAcrossClusters is insufficient. - NestedParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, - ctx.pools, func); + HierarchicalParallelFor( + num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx.pools, + func); } } @@ -285,9 +285,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Apply positional encodings for K. // Note that 2D parallelism is not worth the fork/join overhead because the // tasks are very lightweight. - env.ctx.pools.Pool(0).Run( - 0, kv_heads * num_interleaved, - [&](uint64_t task, size_t thread) HWY_ATTR { + ParallelFor( + ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx, + /*cluster_idx=*/0, [&](size_t task, size_t worker) HWY_ATTR { const size_t head = task % kv_heads; const size_t interleaved_idx = task / kv_heads; const size_t qi = div_qbatch.Remainder(interleaved_idx); @@ -308,12 +308,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, if (layer.key_norm_scale.HasPtr()) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), kv_f32, qkv_dim, - env.ctx.profiler, thread); + env.ctx.profiler, worker); }); } PositionalEncodingQK(kv_f32, layer_idx, layer, activations, - env.ctx.profiler, thread, pos); + env.ctx.profiler, worker, pos); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); }); diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 80ec0ee..cb7ae6a 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -69,9 +69,9 @@ template void ActivationBatched( ActivationType activation, Mat& c1, ThreadingContext& ctx, size_t cluster_idx = 0, - ParallelismType parallelism = ParallelismType::kAcrossClusters) { + ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { using T = typename Mat::T; - ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx, + ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, [&](uint64_t task, size_t worker) { // Cast to correct type so type deduction works. 
Activation(activation, c1.Row(task), @@ -84,16 +84,16 @@ template HWY_NOINLINE void ActivationBatched( ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, size_t cluster_idx = 0, - ParallelismType parallelism = ParallelismType::kAcrossClusters) { + ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { - ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx, + ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, [&](uint64_t task, size_t worker) { Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), ctx.profiler, worker); }); } else { // No multiplier - ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx, + ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, [&](uint64_t task, size_t worker) { Activation(activation, c1.Row(task), static_cast(nullptr), diff --git a/gemma/gemma.cc b/gemma/gemma.cc index fc1f238..a0949fe 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -574,7 +574,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, const WeightsPtrs& weights, KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) { Activations activations(config, runtime_config.prefill_tbatch_size, - kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs); + kv_cache.SeqLen(), env.ctx, env.row_ptrs); AllQueries all_queries(prompt, pos, prefix_end, hwy::Span(&kv_cache, 1)); @@ -592,7 +592,7 @@ void GenerateBatchT(const ModelConfig& config, const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, runtime_config.prefill_tbatch_size); Activations activations(config, max_batch_size, - all_queries[0].kv_cache.SeqLen(), env.ctx.allocator, + all_queries[0].kv_cache.SeqLen(), env.ctx, env.row_ptrs); for (size_t start = 0; start < all_queries.NumQueries(); @@ -616,8 +616,8 @@ void GenerateImageTokensT(const ModelConfig& config, const size_t num_tokens = vit_config.max_seq_len; prefill_runtime_config.prefill_tbatch_size = num_tokens / (vit_config.pool_dim * vit_config.pool_dim); - Activations prefill_activations(vit_config, num_tokens, num_tokens, - env.ctx.allocator, env.row_ptrs); + Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx, + env.row_ptrs); // Weights are for the full PaliGemma model, not just the ViT part. PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, prefill_activations, env); diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 04d535e..1be2bed 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -111,8 +111,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // Ensure usage conditions are set before autotuning. Both binding and // spinning may materially affect the choice of config. No harm in calling // BindB/C if there is a single package: they will be a no-op. 
- BindB(b_trans, sizeof(TC), env.parallel); - BindC(C, env.parallel); + BindB(env.ctx, b_trans, sizeof(TC)); + BindC(env.ctx, C); C.AllocateAndAttachRowPtrs(env.row_ptrs); Tristate use_spinning = Tristate::kDefault; @@ -160,10 +160,10 @@ void BenchAllMatMul() { ctx.pools.PinString()); MatMulEnv env(ctx); - for (size_t batch_size : {1, 4, 128, 512}) { + for (size_t batch_size : {128, 512}) { constexpr bool kAdd = false; - BenchMatMul(batch_size, 24576, 3072, kAdd, env); - BenchMatMul(batch_size, 3072, 24576, kAdd, env); + BenchMatMul(batch_size, 24576, 3072, kAdd, env); + BenchMatMul(batch_size, 3072, 24576, kAdd, env); } PROFILER_PRINT_RESULTS(); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 152e6ce..c915b14 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -565,46 +565,204 @@ class MMKernel { } }; -// Called on the main thread with the entire N range, or by each package with -// a static partition of N. This class contains several variants of the -// outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC. -// Its member variables avoid long argument lists in Do*(). -class MMPerPackage { - public: - MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config, - size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np) - : args_(args), - pkg_idx_(pkg_idx), - cluster_idx_(cluster_idx), - range_np_(range_np), - mr_(config.MR()), - ranges_mc_(config.RangesOfMC(A.rows)), - ranges_kc_(config.RangesOfKC(A.cols)), - ranges_nc_(config.RangesOfNC(range_np)), - order_(config.Order()), - inner_tasks_(config.InnerTasks()), - line_bytes_(args.env->ctx.allocator.LineBytes()) {} +// Miscellaneous stateless helper functions. +struct MMImpl { + // Returns existing entry for the given key or -1. + static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) { + const hwy::Span all_keys = keys.Keys(); + // TODO: SIMD scan + for (size_t i = 0; i < all_keys.size(); ++i) { + if (all_keys[i] == key) return static_cast(i); + } + return -1; + } - // B and maybe A are decompressed several call layers lower, but not all - // member functions depend on TA/TB, so pass them as an argument instead of - // templating the class. - template - HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy, - const MatPtrT& A, const MatPtrT& B, - RowPtrs C_rows) const { + static size_t Worker(const MMArgs& args) { + return args.options.cluster_idx * + args.env->ctx.pools.MaxWorkersPerCluster(); + } + + template + static void DispatchParallelism(ParallelismStrategy parallelism, + const Func& func) { + switch (parallelism) { + case ParallelismStrategy::kHierarchical: + return func(MMParallelHierarchical()); + case ParallelismStrategy::kNone: + return func(MMParallelNone()); + case ParallelismStrategy::kWithinCluster: + return func(MMParallelWithinCluster()); + default: + HWY_UNREACHABLE; + } + } + + // Decompresses all `M x K` from `A` into padded BF16 `A_view`. 
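// A note on the `par_a` strategies handled in the switch below: `kNone`
// decompresses everything on the calling worker, `kK1`/`kK2`/`kK4` split the
// K (column) range into 1/2/4 inner tasks per worker via the policy's
// `ForNP`, and `kM` hands out one row of A per task via `ForRangeMC`. The
// `DecompressA` wrapper further down autotunes over these candidates.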
+ static HWY_NOINLINE void DoDecompressA(const MatPtrT& A, + const StridedViewBF A_view, + MMParA par_a, const MMArgs& args) { + const IndexRange all_M(0, A.Rows()); + const IndexRange all_K(0, A.Cols()); + HWY_DASSERT(all_K.Num() == A_view.Cols()); + + const hn::ScalableTag dbf; + const size_t NBF = hn::Lanes(dbf); + + static const auto zone = args.env->ctx.profiler.AddZone("MM.DecompressA"); + + const auto do_range = + [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker) + HWY_ATTR { + MMZone mm_zone; + mm_zone.MaybeEnter(worker, zone, args); + + const size_t col0 = range_K.begin(); + const size_t cols = range_K.Num(); + // Must be a vector multiple, or the last range before row + // padding, otherwise `DecompressAndZeroPad` overwrites neighbors. + HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols()); + for (size_t row_a : range_M) { + const PackedSpan from = + MakeSpan(A.Row(row_a) + col0, cols); + BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0; + DecompressAndZeroPad(dbf, from, 0, to, cols); + // Verify that we zero-padded. + if constexpr (HWY_IS_DEBUG_BUILD) { + for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) { + HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); + } + } + } + }; + + switch (par_a) { + case MMParA::kNone: + do_range(all_M, all_K, MMImpl::Worker(args)); + break; + + case MMParA::kK1: + case MMParA::kK2: + case MMParA::kK4: { + const size_t inner_tasks = static_cast(par_a); + // At least one vector, otherwise DecompressAndZeroPad will add + // padding, which might overwrite neighboring tasks. Also a whole cache + // line to avoid false sharing. + const size_t multiple_K = HWY_MAX(NBF, args.line_bytes / sizeof(BF16)); + + DispatchParallelism( + args.options.parallelism, [&](const auto& parallel) { + parallel.ForNP(args.env->ctx, all_K, multiple_K, inner_tasks, + args.options.cluster_idx, + [&](const IndexRange& range_K, size_t worker) { + do_range(all_M, range_K, worker); + }); + }); + break; + } + case MMParA::kM: + DispatchParallelism( + args.options.parallelism, [&](const auto& parallel) { + parallel.ForRangeMC( + args.env->ctx, all_M, args.options.cluster_idx, + [&](size_t row_a, size_t worker) { + do_range(IndexRange(row_a, row_a + 1), all_K, worker); + }); + }); + break; + } + } + + // Autotuning wrapper for `DoDecompressA`. + static HWY_INLINE void DecompressA(const MatPtrT& A, + const StridedViewBF A_view, + const MMArgs& args) { + MMAutoTune& autotune = args.per_key->autotune_par_a[/*pkg_idx=*/0]; + + if (HWY_LIKELY(autotune.Best())) { + return DoDecompressA(A, A_view, *autotune.Best(), args); + } + + // First call: generate candidates. + if (HWY_UNLIKELY(!autotune.HasCandidates())) { + const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM; + std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4, + other}; + autotune.SetCandidates(candidates); + } + + const MMParA& par_a = autotune.NextConfig(); + const uint64_t t0 = hwy::timer::Start(); + DoDecompressA(A, A_view, par_a, args); + const uint64_t t1 = + args.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); + const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); + if (HWY_UNLIKELY(args.env->print_measurement && autotune.ShouldPrint())) { + fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a), + static_cast(min_elapsed) / + hwy::platform::InvariantTicksPerSecond() * 1E6); + } + } + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. 
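// Illustrative example: for a matrix `AB` with stride `S`,
// `MMImpl::View(AB, /*r=*/1, /*c=*/2, /*cols=*/4)` refers to the elements
// starting at `AB.Row(1) + 2`, is 4 columns wide and keeps stride `S`; no
// data is copied, the view aliases `AB`.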
+ template + static StridedView View(const MatPtrT& AB, size_t r, size_t c, + size_t cols) { + HWY_DASSERT(c < AB.Cols()); + HWY_DASSERT(cols <= AB.Cols() - c); + return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); + } + + template + static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT& A, + const MMArgs& args) { if constexpr (IsBF16()) { // We can use a view, regardless of columns/padding, because `LoopKC` // supports non-vector multiples. - DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows); + return View(A, 0, 0, A.Cols()); } else { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. const StridedViewBF A_view = - args_.env->storage[cluster_idx_].A(pkg_idx_, A.Extents()); - DecompressA(A, A_view); - DispatchOrder(parallel_policy, A_view, B, C_rows); + args.env->storage[args.options.cluster_idx].A(/*pkg_idx=*/0, + A.Extents()); + DecompressA(A, A_view, args); + return A_view; } } +}; + +// Contains several variants of the outer M/N/K loops, and calls `A2C0` which +// loops over the inner KC and MC. Member variables avoid long argument lists. +class MMState { + public: + MMState(const Extents2D A, const MMArgs& args, const MMConfig& config) + : args_(args), + range_np_(args.per_key->ranges_np.Range(/*pkg_idx=*/0)), + mr_(config.MR()), + ranges_mc_(config.RangesOfMC(A.rows)), + ranges_kc_(config.RangesOfKC(A.cols)), + ranges_nc_(config.RangesOfNC(range_np_)), + order_(config.Order()), + inner_tasks_(config.InnerTasks()) { + HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); + } + + // Called from `MatMul` from two places: either with the next autotune config, + // or with the best config. + template + HWY_NOINLINE void DispatchParallelism(const StridedViewBF A, + const MatPtrT& B, + RowPtrs C_rows) const { + /* Disabled due to unknown thread-safety issue: + static const auto zone = + args_.env->ctx.profiler.AddZone("MM.DispatchParallelism"); + PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone); + */ + + MMImpl::DispatchParallelism( + args_.options.parallelism, + [&](const auto& parallel) { DispatchOrder(parallel, A, B, C_rows); }); + } private: // Compute size of per-worker storage for `kNR` row ranges of B. Stack @@ -616,40 +774,32 @@ class MMPerPackage { // Granularity of `ForNP`. B rows produce C columns, so we // want a multiple of the line size to prevent false sharing. size_t MultipleNP(size_t sizeof_TC) const { - return HWY_MAX(kNR, line_bytes_ / sizeof_TC); + return HWY_MAX(kNR, args_.line_bytes / sizeof_TC); } - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - template - static StridedView View(const MatPtrT& AB, size_t r, size_t c, - size_t cols) { - HWY_DASSERT(c < AB.Cols()); - HWY_DASSERT(cols <= AB.Cols() - c); - return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); - } - - // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`. - template - HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy, - const StridedView A, const MatPtrT& B, - RowPtrs C_rows) const { + // B is decompressed several call layers lower, but not all member functions + // depend on `TB`, so pass it as an argument instead of templating the class. 
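// For reference, the orders dispatched below map to the Do* variants that
// follow: `kNT` = single M and K range, parallel N; `kNT_K` = single M range,
// parallel N, sequential K (set C, then accumulate); `kNT_MT` = parallel M/N
// blocks, single K; `kNT_MT_K` = parallel M/N blocks with a sequential K loop
// that accumulates into `partial` before writing C.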
+ template + HWY_NOINLINE void DispatchOrder(const ParallelT& parallel_policy, + const StridedViewBF A, const MatPtrT& B, + RowPtrs C_rows) const { switch (order_) { case MMOrder::kNT: - return DoNT(parallel_policy, A, B, C_rows); + return DoNT(parallel_policy, A, B, C_rows); case MMOrder::kNT_K: - return DoNT_K(parallel_policy, A, B, C_rows); + return DoNT_K(parallel_policy, A, B, C_rows); case MMOrder::kNT_MT: - return DoNT_MT(parallel_policy, A, B, C_rows); + return DoNT_MT(parallel_policy, A, B, C_rows); case MMOrder::kNT_MT_K: - return DoNT_MT_K(parallel_policy, A, B, C_rows); + return DoNT_MT_K(parallel_policy, A, B, C_rows); default: HWY_UNREACHABLE; } } // Single M and K ranges, parallel N. Fills all of C directly. - template - HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView A, + template + HWY_INLINE void DoNT(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); @@ -657,14 +807,14 @@ class MMPerPackage { const IndexRange& range_M = ranges_mc_.Range(0); const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); - const StridedView A_view = A.View(range_M.begin(), 0, K); + const StridedViewBF A_view = A.View(range_M.begin(), 0, K); const size_t B_stride = - Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); + Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); // Similar to `loop_nc` below, but here we hoisted `A_view`. - MMParallelPolicyT::ForNP( + parallel.ForNP( args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, - pkg_idx_, cluster_idx_, + args_.options.cluster_idx, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); @@ -683,8 +833,8 @@ class MMPerPackage { } // Single M range, parallel N, sequential K. Sets C, then accumulates. - template - HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView A, + template + HWY_INLINE void DoNT_K(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); @@ -697,11 +847,11 @@ class MMPerPackage { const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const StridedView A_view = + const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc); const StridedViewBF B_storage_view( B_storage, kc, - Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_)); + Stride(MatPadding::kOdd, kc, sizeof(BF16), args_.line_bytes)); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { @@ -711,9 +861,9 @@ class MMPerPackage { } }; - MMParallelPolicyT::ForNP( + parallel.ForNP( args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, - pkg_idx_, cluster_idx_, + args_.options.cluster_idx, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); @@ -733,26 +883,26 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, single K. // Fills `mc x nc` sections of C directly, in parallel. 
- template - HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView A, + template + HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT"); HWY_DASSERT(ranges_kc_.NumTasks() == 1); const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); const size_t B_stride = - Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); + Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); // Sequential loop over NC/MC/KC, similar to `loop_nc` below // except for the profiler strings and `out_tag`. - MMParallelPolicyT::ForRangesMC_NC( - args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_, + parallel.ForRangesMC_NC( + args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); - const StridedView A_view = A.View(range_mc.begin(), 0, K); + const StridedViewBF A_view = A.View(range_mc.begin(), 0, K); HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS const StridedViewBF B_storage_view(B_storage, K, B_stride); @@ -768,14 +918,14 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, sequential K. // Fills `mc x nc` sections of `partial`, then `C`, in parallel. - template - HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView A, + template + HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); const size_t kc_max = ranges_kc_.TaskSize(); HWY_DASSERT(kc_max <= MMStorage::kMaxKC); const size_t B_stride = - Stride(MatPadding::kOdd, kc_max, sizeof(BF16), line_bytes_); + Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes); // Sequential loop over NC/MC/KC, for when the M/N loops are // already parallel. This is B3A2C0 in MOMMS terminology: we read // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`. @@ -785,7 +935,7 @@ class MMPerPackage { const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const StridedView A_view = + const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); @@ -795,8 +945,8 @@ class MMPerPackage { C_rows); } }; // loop_nc - MMParallelPolicyT::ForRangesMC_NC( - args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_, + parallel.ForRangesMC_NC( + args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -816,106 +966,6 @@ class MMPerPackage { }); } - // Decompresses all `M x K` from `A` into padded BF16 `A_view`. 
- template - HWY_NOINLINE void DoDecompressA(const MatPtrT& A, - const StridedViewBF A_view, - MMParA par_a) const { - const IndexRange all_M(0, A.Rows()); - const IndexRange all_K(0, A.Cols()); - HWY_DASSERT(all_K.Num() == A_view.Cols()); - - const hn::ScalableTag dbf; - const size_t NBF = hn::Lanes(dbf); - - static const auto zone = args_.env->ctx.profiler.AddZone("MM.DecompressA"); - - const auto do_range = [&](const IndexRange& range_M, - const IndexRange& range_K, - size_t worker) HWY_ATTR { - MMZone mm_zone; - mm_zone.MaybeEnter(worker, zone, args_); - - const size_t col0 = range_K.begin(); - const size_t cols = range_K.Num(); - // Must be a vector multiple, or the last range before row padding, - // otherwise `DecompressAndZeroPad` overwrites neighbors. - HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols()); - for (size_t row_a : range_M) { - const PackedSpan from = - MakeSpan(A.Row(row_a) + col0, cols); - BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0; - DecompressAndZeroPad(dbf, from, 0, to, cols); - // Verify that we zero-padded. - if constexpr (HWY_IS_DEBUG_BUILD) { - for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) { - HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); - } - } - } - }; - - switch (par_a) { - case MMParA::kNone: - do_range(all_M, all_K, /*worker=*/0); - break; - case MMParA::kK1: - case MMParA::kK2: - case MMParA::kK4: { - const size_t inner_tasks = static_cast(par_a); - // At least one vector, otherwise DecompressAndZeroPad will add - // padding, which might overwrite neighboring tasks. Also a whole cache - // line to avoid false sharing. - const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16)); - - MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks, - pkg_idx_, cluster_idx_, - [&](const IndexRange& range_K, size_t worker) { - do_range(all_M, range_K, worker); - }); - break; - } - case MMParA::kM: - MMParallelPolicyT::ForRangeMC( - args_.env->ctx, all_M, pkg_idx_, cluster_idx_, - [&](size_t row_a, size_t worker) { - do_range(IndexRange(row_a, row_a + 1), all_K, worker); - }); - break; - } - } - - // Autotuning wrapper for `DoDecompressA`. - template - HWY_INLINE void DecompressA(const MatPtrT& A, - const StridedViewBF A_view) const { - MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; - - if (HWY_LIKELY(autotune.Best())) { - return DoDecompressA(A, A_view, *autotune.Best()); - } - - // First call: generate candidates. - if (HWY_UNLIKELY(!autotune.HasCandidates())) { - const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM; - std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4, - other}; - autotune.SetCandidates(candidates); - } - - const MMParA& par_a = autotune.NextConfig(); - const uint64_t t0 = hwy::timer::Start(); - DoDecompressA(A, A_view, par_a); - const uint64_t t1 = - args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); - const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); - if (HWY_UNLIKELY(args_.env->print_measurement && autotune.ShouldPrint())) { - fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a), - static_cast(min_elapsed) / - hwy::platform::InvariantTicksPerSecond() * 1E6); - } - } - // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0, // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL` // thanks to its large table lookups, and less so on other targets. @@ -928,7 +978,7 @@ class MMPerPackage { // Neither A nor B require padding because `LoopKC` handles remainders. 
if constexpr (hwy::IsSame()) { - return View(B, row_b, range_kc.begin(), range_kc.Num()); + return MMImpl::View(B, row_b, range_kc.begin(), range_kc.Num()); } const PackedSpan B_span = B.PaddedSpan(); @@ -951,8 +1001,6 @@ class MMPerPackage { } const MMArgs args_; // copy for locality - const size_t pkg_idx_; - const size_t cluster_idx_; // 0 for sequential and nested. const IndexRange range_np_; // From MMConfig: @@ -962,52 +1010,7 @@ class MMPerPackage { const IndexRangePartition ranges_nc_; const MMOrder order_; const size_t inner_tasks_; - const size_t line_bytes_; -}; // MMPerPackage - -// Stateless, wraps member functions. -struct MMImpl { - // Returns existing entry for the given key or -1. - static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) { - const hwy::Span all_keys = keys.Keys(); - // TODO: SIMD scan - for (size_t i = 0; i < all_keys.size(); ++i) { - if (all_keys[i] == key) return static_cast(i); - } - return -1; - } - - // Called from `MatMul` from two places: either with the next autotune config, - // or with the best config. - template - static HWY_NOINLINE void DoMatMul(const MatPtrT& A, const MatPtrT& B, - RowPtrs C_rows, const MMArgs& args, - const MMConfig& config, MMOptions options) { - PROFILER_ZONE("MM.DoMatMul"); - const size_t pkg_idx = 0; - HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); - const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - - switch (options.parallelism_type) { - case ParallelismType::kNested: - HWY_DASSERT(options.cluster_idx == 0); - MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, - range_np)(MMNestedParallelPolicy(), A, B, C_rows); - break; - case ParallelismType::kSequential: - MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, - range_np)(MMSequentialPolicy(), A, B, C_rows); - case ParallelismType::kWithinCluster: - MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, - range_np)(MMClusterParallelPolicy(), A, B, C_rows); - break; - default: - HWY_ABORT("Parallelism type %s not implemented.", - static_cast(options.parallelism_type)); - break; - } - } -}; +}; // MMState // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`. // @@ -1033,17 +1036,19 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C, MMOptions options = MMOptions()) { - RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]); + HWY_DASSERT(options.cluster_idx < env.row_ptrs.size()); + RowPtrs C_rows = + GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]); const Allocator& allocator = env.ctx.allocator; const size_t M = A.Rows(); const size_t K = A.Cols(); const size_t N = B.Rows(); const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); - intptr_t index = MMImpl::IndexOfKey(key, env.keys); + intptr_t index = MMImpl::IndexOfKey(key, env.keys[options.cluster_idx]); // First time we see this shape/key. if (HWY_UNLIKELY(index < 0)) { - env.keys.Append(key, allocator); + env.keys[options.cluster_idx].Append(key, allocator); size_t max_packages = kMaxPackages; // For low-batch, multiple sockets only help if binding is enabled. 
@@ -1052,16 +1057,19 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, } // invalidates `MMAutoTune::Best()` - index = env.per_key.size(); - env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR)); + std::vector& stored_keys = env.per_key[options.cluster_idx]; + index = stored_keys.size(); + stored_keys.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR)); } - MMPerKey& per_key = env.per_key[index]; + MMPerKey& per_key = env.per_key[options.cluster_idx][index]; MMAutoTune& tuner = per_key.autotune; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), - add); + add, options); if (HWY_LIKELY(tuner.Best())) { - MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), options); + const MMState state(A.Extents(), args, *tuner.Best()); + const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); + state.DispatchParallelism(A_view, B, C_rows); return &per_key; } @@ -1089,7 +1097,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMImpl::DoMatMul(A, B, C_rows, args, cfg, options); + MMState state(A.Extents(), args, cfg); + const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); + state.DispatchParallelism(A_view, B, C_rows); const uint64_t t1 = env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / diff --git a/ops/matmul.cc b/ops/matmul.cc index 83fc036..75b37a2 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -405,20 +405,15 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { // Create storage per cluster. This only applies to in-cluster parallelism. // For nested and sequential parallelism, a single MMStorage is used. - size_t num_packages = ctx.topology.NumPackages(); - size_t num_clusters = 0; - for (size_t pkg_idx = 0; pkg_idx < num_packages; ++pkg_idx) { - num_clusters += ctx.topology.NumClusters(pkg_idx); - } + const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); storage.reserve(num_clusters); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { storage.push_back(MMStorage(ctx)); + row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C } char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); - - row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C } void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) { diff --git a/ops/matmul.h b/ops/matmul.h index e76d37b..16cb51c 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -58,147 +58,127 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, size_t N, size_t sizeof_TC, size_t nr); struct MMOptions { - ParallelismType parallelism_type = ParallelismType::kNested; - uint8_t cluster_idx = 0; + uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. + ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; }; -struct MMSequentialPolicy { - template - static void ForPkg(ThreadingContext& ctx, const size_t max_packages, - const Func& func) { - func(/*pkg_idx=*/0); - } +// Policy classes for parallelism, implementing some of `ParallelismStrategy`. 
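// The three policies below share the same member functions, which is what
// lets `MMImpl::DispatchParallelism` pass any of them generically (sketch of
// the common shape, not an actual base class):
//   ForNP(ctx, range_np, nx_multiple, inner_tasks, cluster_idx, func);
//   ForRangesMC_NC(ctx, ranges_mc, ranges_nc, cluster_idx, func);
//   ForRangeMC(ctx, range_mc, cluster_idx, func);
// Each calls `func(..., worker)` with worker indices offset so they remain
// unique across clusters (cf. the TLS-indexing comment in `ParallelFor`).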
+struct MMParallelNone { template - static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, - size_t cluster_idx, const Func& func) { + void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, size_t cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); - const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() + - cluster_idx * ctx.pools.MaxWorkersPerCluster(); - func(range_np, base_idx); + const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + func(range_np, worker); } template - static void ForRangesMC_NC(ThreadingContext& ctx, - const IndexRangePartition& ranges_mc, - const IndexRangePartition& ranges_nc, - size_t pkg_idx, size_t cluster_idx, - const Func& func) { - const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() + - cluster_idx * ctx.pools.MaxWorkersPerCluster(); + void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, size_t cluster_idx, + const Func& func) const { + const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) { const IndexRange range_mc = ranges_mc.Range(i); for (size_t j = 0; j < ranges_nc.NumTasks(); ++j) { const IndexRange range_nc = ranges_nc.Range(j); - func(range_mc, range_nc, base_idx); + func(range_mc, range_nc, worker); } } } template - static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, - size_t pkg_idx, size_t cluster_idx, const Func& func) { - const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() + - cluster_idx * ctx.pools.MaxWorkersPerCluster(); + void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t cluster_idx, const Func& func) const { + const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) { - func(row_a, base_idx); + func(row_a, worker); } } }; -struct MMClusterParallelPolicy { +struct MMParallelWithinCluster { template - static void ForPkg(ThreadingContext& ctx, const size_t max_packages, - const Func& func) { - func(/*pkg_idx=*/0); - } - - template - static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, - size_t cluster_idx, const Func& func) { + void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, size_t cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + const size_t pkg_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const IndexRangePartition worker_ranges = StaticPartition( range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); ParallelizeOneRange(worker_ranges, cluster, [&](const IndexRange& worker_range, size_t worker) { - func(worker_range, worker); + func(worker_range, base + worker); }); } template - static void ForRangesMC_NC(ThreadingContext& ctx, - const IndexRangePartition& ranges_mc, - const IndexRangePartition& ranges_nc, - size_t pkg_idx, size_t cluster_idx, - const Func& func) { + void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, size_t cluster_idx, + const Func& func) const { + const size_t pkg_idx = 0; hwy::ThreadPool& 
cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); // Low-batch: avoid Divide/Remainder. if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { ParallelizeOneRange(ranges_nc, cluster, - [&](const IndexRange& range_nc, size_t thread) { - func(ranges_mc.Range(0), range_nc, thread); + [&](const IndexRange& range_nc, size_t worker) { + func(ranges_mc.Range(0), range_nc, base + worker); }); } else { ParallelizeTwoRanges( ranges_mc, ranges_nc, cluster, [&](const IndexRange& range_mc, const IndexRange& range_nc, - size_t thread) { func(range_mc, range_nc, thread); }); + size_t worker) { func(range_mc, range_nc, base + worker); }); } } template - static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, - size_t pkg_idx, size_t cluster_idx, const Func& func) { + void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t cluster_idx, const Func& func) const { + const size_t pkg_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); - cluster.Run(range_mc.begin(), range_mc.end(), - [&](uint64_t row_a, size_t thread) { func(row_a, thread); }); + const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + + cluster.Run( + range_mc.begin(), range_mc.end(), + [&](uint64_t row_a, size_t worker) { func(row_a, base + worker); }); } }; -struct MMNestedParallelPolicy { - template - static void ForPkg(ThreadingContext& ctx, const size_t max_packages, - const Func& func) { - if constexpr (kMaxPackages > 1) { - ctx.pools.AllPackages().Run( - 0, HWY_MIN(max_packages, ctx.pools.NumPackages()), - [&](uint64_t task, size_t pkg_idx) { - HWY_DASSERT(task == pkg_idx); - (void)task; - func(pkg_idx); - }); - } else { - func(/*pkg_idx=*/0); - } - } - +struct MMParallelHierarchical { // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`. - // `cluster_idx` is not used here as all clusters within a package are used. template - static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, - size_t /*cluster_idx*/, const Func& func) { + void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, + HWY_MAYBE_UNUSED size_t caller_cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); - const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + HWY_DASSERT(caller_cluster_idx == 0); // Single cluster: parallel-for over static partition of `range_np`. 
+ const size_t pkg_idx = 0; hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); const size_t num_clusters = all_clusters.NumWorkers(); if (num_clusters == 1) { - hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0); + const size_t cluster_idx = 0; + hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const IndexRangePartition worker_ranges = StaticPartition( range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); return ParallelizeOneRange( worker_ranges, cluster, - [&](const IndexRange& worker_range, size_t thread) { - func(worker_range, pkg_base + thread); + [&](const IndexRange& worker_range, size_t worker) { + func(worker_range, worker); }); } @@ -210,28 +190,29 @@ struct MMNestedParallelPolicy { [&](const IndexRange& nx_range, const size_t cluster_idx) { hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const size_t cluster_base = - pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster(); + cluster_idx * ctx.pools.MaxWorkersPerCluster(); // Parallel-for over sub-ranges of `cluster_range` within the cluster. const IndexRangePartition worker_ranges = StaticPartition( nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple); ParallelizeOneRange( worker_ranges, cluster, - [&](const IndexRange& worker_range, size_t thread) { - func(worker_range, cluster_base + thread); + [&](const IndexRange& worker_range, size_t worker) { + func(worker_range, cluster_base + worker); }); }); } // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B // rows). Calls `func(range_mc, range_nc, worker)`. - // `cluster_idx` is not used here as all clusters within a package are used. template - static void ForRangesMC_NC(ThreadingContext& ctx, - const IndexRangePartition& ranges_mc, - const IndexRangePartition& ranges_nc, - size_t pkg_idx, size_t /*cluster_idx*/, - const Func& func) { - const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, + HWY_MAYBE_UNUSED size_t caller_cluster_idx, + const Func& func) const { + const size_t pkg_idx = 0; + HWY_DASSERT(caller_cluster_idx == 0); + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); // `all_clusters` is a pool with one worker per cluster in a package. const size_t num_clusters = all_clusters.NumWorkers(); @@ -243,16 +224,14 @@ struct MMNestedParallelPolicy { // Low-batch: avoid Divide/Remainder. 
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { return ParallelizeOneRange( - ranges_nc, cluster, [&](const IndexRange& range_nc, size_t thread) { - func(ranges_mc.Range(0), range_nc, pkg_base + thread); + ranges_nc, cluster, [&](const IndexRange& range_nc, size_t worker) { + func(ranges_mc.Range(0), range_nc, worker); }); } else { return ParallelizeTwoRanges( ranges_mc, ranges_nc, cluster, [&](const IndexRange& range_mc, const IndexRange& range_nc, - size_t thread) { - func(range_mc, range_nc, pkg_base + thread); - }); + size_t worker) { func(range_mc, range_nc, worker); }); } } @@ -262,25 +241,23 @@ struct MMNestedParallelPolicy { ranges_nc, all_clusters, [&](const IndexRange range_nc, size_t cluster_idx) { const size_t cluster_base = - pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster(); + cluster_idx * ctx.pools.MaxWorkersPerCluster(); hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); ParallelizeOneRange(ranges_mc, cluster, - [&](const IndexRange& range_mc, size_t thread) { - func(range_mc, range_nc, cluster_base + thread); + [&](const IndexRange& range_mc, size_t worker) { + func(range_mc, range_nc, cluster_base + worker); }); }); } // Calls `func(row_a, worker)` in parallel. - // `cluster_idx` is not used here as all clusters within a package are used. template - static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, - size_t pkg_idx, size_t /*cluster_idx*/, - const Func& func) { - const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); - ctx.pools.Pool(pkg_idx).Run( - range_mc.begin(), range_mc.end(), - [&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); }); + void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t caller_cluster_idx, const Func& func) const { + HierarchicalParallelFor(range_mc.Num(), ctx.pools, + [&](size_t task, size_t worker) { + func(range_mc.begin() + task, worker); + }); } }; @@ -340,27 +317,22 @@ class MMStorage { // Internally threaded; must not be called concurrently with the same // `ThreadingContext` (used via `parallel`). MMStorage(ThreadingContext& ctx) { - // Per-package allocation so each can decompress A into its own copy. - // Must be padded, see `DoDecompressA`. - // Default to nested parallel policy. - MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) { - Allocator& allocator = ctx.allocator; + Allocator& allocator = ctx.allocator; + const size_t pkg_idx = 0; - // 0.5 GiB per package. - pkg_A_[pkg_idx].reset( - new MatStorageT("pkg_A", Extents2D(kMaxBatchSize, kMaxK), - allocator, MatPadding::kOdd)); + // 0.5 GiB per package. Must be padded, see `DoDecompressA`. 
+ pkg_A_[pkg_idx].reset(new MatStorageT( + "pkg_A", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd)); - if (allocator.ShouldBind()) { - const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); - size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * - pkg_A_[pkg_idx]->ElementBytes(); - bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes()); - if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) { - HWY_WARN("Failed to bind memory for package %zu", pkg_idx); - } + if (allocator.ShouldBind()) { + const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); + size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * + pkg_A_[pkg_idx]->ElementBytes(); + bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes()); + if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) { + HWY_WARN("Failed to bind memory for package %zu", pkg_idx); } - }); + } } // Returns per-package matrix view. Converting A=F32 to BF16 up-front is @@ -735,16 +707,18 @@ struct MatMulEnv { bool print_best = false; std::vector storage; - MMKeys keys; - std::vector per_key; + MMKeys keys[kMaxClusters]; + std::vector per_key[kMaxClusters]; // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`. // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV // writes to differing KV positions per query / output row. - // The first entry is sufficient for any C argument, but also potentially - // overwritten by each MatMul. Subsequent entries are precomputed for tensors - // and not overwritten. Per-tensor allocations make it likelier that asan - // detects bugs such as use after free, overrun, and dangling references. + // The first `num_clusters` entries are sufficient for any C argument, and + // must be indexed by `options.cluster_idx`. Note that they are potentially + // overwritten by each `MatMul`. Subsequent entries are for specific tensors + // and only written once by their allocator. A per-tensor allocation makes it + // likelier that asan detects bugs such as use after free, overrun, and + // dangling references. std::vector> row_ptrs; }; @@ -752,14 +726,21 @@ struct MatMulEnv { // Reduces register pressure compared to individual values/references. struct MMArgs { MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale, - const float* HWY_RESTRICT add) - : env(&env), per_key(&per_key), scale(scale), add(add) {} + const float* HWY_RESTRICT add, MMOptions options) + : env(&env), + per_key(&per_key), + scale(scale), + add(add), + options(options), + line_bytes(env.ctx.allocator.LineBytes()) {} MatMulEnv* env; MMPerKey* per_key; double scale; const float* HWY_RESTRICT add; + MMOptions options; + size_t line_bytes; }; // Wrapper over hwy::Zone that is only enabled when autotuning finished. 
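As context for the per-cluster `keys`, `per_key` and `row_ptrs` members above: the usage pattern they appear to enable is one independent `MatMul` per cluster. A hypothetical sketch, not part of this change; `a_per_cluster`, `b`, `c_per_cluster` and the element types are assumed:

    void MatMulPerCluster(MatMulEnv& env,
                          std::vector<MatPtrT<BF16>>& a_per_cluster,
                          const MatPtrT<BF16>& b,
                          std::vector<MatPtrT<float>>& c_per_cluster) {
      const size_t num_clusters =
          env.ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
      HWY_ASSERT(a_per_cluster.size() == num_clusters &&
                 c_per_cluster.size() == num_clusters);
      ParallelFor(ParallelismStrategy::kAcrossClusters, num_clusters, env.ctx,
                  /*cluster_idx=*/0, [&](uint64_t task, size_t cluster_idx) {
                    MMOptions options;
                    options.cluster_idx = static_cast<uint32_t>(cluster_idx);
                    options.parallelism = ParallelismStrategy::kWithinCluster;
                    // Uses env.keys/per_key/row_ptrs[cluster_idx], so the
                    // clusters neither contend nor share autotune state.
                    MatMul(a_per_cluster[task], b, /*add=*/nullptr, env,
                           c_per_cluster[task], options);
                  });
    }
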
diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 0173ee8..19a39aa 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -501,12 +501,12 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, HWY_DASSERT(activations.SameShape(out)); CallUpcasted(&weights, [&](const auto* weights_t) { - ParallelFor( - ParallelismType::kAcrossClusters, activations.Rows(), ctx.pools, - cluster_idx, [&](uint64_t token_idx, size_t worker) { - RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), - out.Row(token_idx), activations.Cols(), ctx.profiler, worker); - }); + ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, + cluster_idx, [&](uint64_t token_idx, size_t worker) { + RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), + out.Row(token_idx), activations.Cols(), ctx.profiler, + worker); + }); }); } @@ -517,12 +517,12 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, HWY_DASSERT(weights.Cols() == inout.Cols()); CallUpcasted(&weights, [&](const auto* weights_t) { - ParallelFor( - ParallelismType::kAcrossClusters, inout.Rows(), ctx.pools, cluster_idx, - [&](uint64_t token_idx, size_t worker) { - RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx), - inout.Cols(), ctx.profiler, worker); - }); + ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx, + [&](uint64_t token_idx, size_t worker) { + RMSNormInplace(weights_t->PackedScale1(), + inout.Row(token_idx), inout.Cols(), + ctx.profiler, worker); + }); }); } @@ -548,8 +548,8 @@ static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, ThreadingContext& ctx, size_t cluster_idx = 0) { HWY_DASSERT(out.SameShape(x)); - ParallelFor(ParallelismType::kAcrossClusters, out.Rows(), ctx.pools, - cluster_idx, [&](uint64_t token_idx, size_t worker) { + ParallelFor(ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx, + [&](uint64_t token_idx, size_t worker) { AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx.profiler, worker); }); @@ -782,8 +782,8 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched( const float cap, MatPtrT& x, const hwy::BitSet4096<>& non_eos, ThreadingContext& ctx, size_t cluster_idx = 0) { if (cap == 0.0f) return; - ParallelFor(ParallelismType::kAcrossClusters, x.Rows(), ctx.pools, - cluster_idx, [&](uint64_t task, size_t worker) { + ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, + [&](uint64_t task, size_t worker) { if (non_eos.Get(task)) { LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker); diff --git a/util/basics.h b/util/basics.h index 13d0362..7cdc17c 100644 --- a/util/basics.h +++ b/util/basics.h @@ -35,6 +35,8 @@ namespace gcpp { // is disabled if this is 1. HWY_INLINE_VAR constexpr size_t kMaxPackages = 1; +HWY_INLINE_VAR constexpr size_t kMaxClusters = 128; // TODO: shrink + // TODO: extend to 16k after updating non_eos. HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; diff --git a/util/threading.h b/util/threading.h index ef4f1c7..5dde114 100644 --- a/util/threading.h +++ b/util/threading.h @@ -326,7 +326,8 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1, // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes // over clusters of ONE package, then within each cluster. template -void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { +void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools, + const Func& func) { // Even if there are multiple packages, we only use the first. 
const size_t pkg_idx = 0; @@ -356,59 +357,6 @@ void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { }); } -// Which pool(s) to use for parallelizing: -enum class ParallelismType : uint8_t { - // None: single-threaded loop on the calling thread. - kSequential, - // One thread per cluster within the first package; or one per core if there - // is only one cluster. Use for few or lightweight tasks, or to maximize - // memory bandwidth availability. - kAcrossClusters, - // All cores within the cluster identified by `cluster_idx`. Use if already - // within a `kAcrossClusters` parallel-for, or if latency is more important - // than memory bandwidth. - kWithinCluster, - // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This - // utilizes all cores, but has higher fork-join overhead (two barriers); use - // if there are many or heavy tasks. - kNested, -}; - -// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the -// number/type of workers determined by `parallelism`. `cluster_idx` is only -// used if `parallelism == kWithinCluster`. -template -void ParallelFor(ParallelismType parallelism, size_t num_tasks, - NestedPools& pools, size_t cluster_idx, const Func& func) { - if (cluster_idx != 0) { - // If already running across clusters, must not use across-cluster modes. - HWY_DASSERT(parallelism != ParallelismType::kAcrossClusters && - parallelism != ParallelismType::kNested); - } - - const size_t pkg_idx = 0; - switch (parallelism) { - case ParallelismType::kSequential: - for (size_t task = 0; task < num_tasks; ++task) { - func(task, /*worker=*/0); - } - return; - - case ParallelismType::kAcrossClusters: - return pools.Pool(pkg_idx).Run( - 0, num_tasks, - [&](uint64_t task, size_t worker) { func(task, worker); }); - - case ParallelismType::kWithinCluster: - return pools.Cluster(pkg_idx, cluster_idx) - .Run(0, num_tasks, - [&](uint64_t task, size_t worker) { func(task, worker); }); - - case ParallelismType::kNested: - return NestedParallelFor(num_tasks, pools, func); - } -} - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ diff --git a/util/threading_context.h b/util/threading_context.h index d4fdc17..847ce81 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -116,6 +116,95 @@ struct ThreadingContext { NestedPools pools; }; +// Describes the strategy for distributing parallel work across cores. +enum class ParallelismStrategy : uint8_t { + // Execute using a single-threaded loop on the calling thread. The `worker` + // index passed to the user's `Func` is unique across clusters. + kNone, + // One thread per cluster within the first package. The `worker` index passed + // to the user's `Func` is a `cluster_idx <= NumClusters()`. Some CPUs may + // only have a single cluster, hence `Func` should also contain a nested + // `ParallelFor` with `kWithinCluster`. + kAcrossClusters, + // All cores within the cluster identified by `cluster_idx`. The `worker` + // index passed to the user's `Func` is unique across clusters. Choose this + // strategy if already within a `ParallelFor` call with `kAcrossClusters`, + // or latency is more important than memory bandwidth. + kWithinCluster, + // Equivalent to `kAcrossClusters` if there are multiple clusters, otherwise + // `kWithinCluster`. Use for few or lightweight tasks (this only uses a + // single pool and barrier), or to maximize memory bandwidth availability. + kFlat, + // First statically partitions `kAcrossClusters`, then `kWithinCluster`. 
This + // utilizes all cores, but has higher fork-join overhead (two barriers); use + // if there are many or heavy tasks. + kHierarchical, +}; + +// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the +// number/type of workers determined by `parallelism`. `cluster_idx` is for +// `parallelism == kWithinCluster`, and should be 0 if unknown. +template +void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, + ThreadingContext& ctx, size_t cluster_idx, const Func& func) { + HWY_DASSERT(ctx.topology.NumPackages() == 1); + const size_t pkg_idx = 0; + + HWY_DASSERT(cluster_idx < ctx.topology.NumClusters(pkg_idx)); + if (cluster_idx != 0) { + // If already running across clusters, only use within-cluster modes. + HWY_DASSERT(parallelism == ParallelismStrategy::kNone || + parallelism == ParallelismStrategy::kWithinCluster); + } + + switch (parallelism) { + case ParallelismStrategy::kNone: { + const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + for (size_t task = 0; task < num_tasks; ++task) { + func(task, worker); + } + return; + } + + case ParallelismStrategy::kAcrossClusters: + return ctx.pools.AllClusters(pkg_idx).Run( + 0, num_tasks, + [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); + + case ParallelismStrategy::kWithinCluster: { + // Ensure the worker argument is unique across clusters, because it is + // used for TLS indexing for example in profiler.h. + const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + return ctx.pools.Cluster(pkg_idx, cluster_idx) + .Run(0, num_tasks, [&](uint64_t task, size_t worker) { + func(task, base + worker); + }); + } + + case ParallelismStrategy::kFlat: { + // Check for single cluster; if not, we must compute `cluster_base` for + // consistent and non-overlapping worker indices. + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); + const size_t num_clusters = all_clusters.NumWorkers(); + if (num_clusters == 1) { + return ctx.pools.Cluster(pkg_idx, cluster_idx) + .Run(0, num_tasks, + [&](uint64_t task, size_t worker) { func(task, worker); }); + } + + return ctx.pools.AllClusters(pkg_idx).Run( + 0, num_tasks, [&](uint64_t task, size_t cluster_idx) { + const size_t worker = + cluster_idx * ctx.pools.MaxWorkersPerCluster(); + func(task, worker); + }); + } + + case ParallelismStrategy::kHierarchical: + return HierarchicalParallelFor(num_tasks, ctx.pools, func); + } +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_CONTEXT_H_
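To make the nesting advice in the `kAcrossClusters` comment concrete, a minimal illustrative sketch (the `ProcessPart` callback, `rows` and the per-row task count are hypothetical):

    void ProcessRows(ThreadingContext& ctx, size_t rows) {
      ParallelFor(ParallelismStrategy::kAcrossClusters, rows, ctx,
                  /*cluster_idx=*/0, [&](uint64_t row, size_t cluster_idx) {
                    // Fan out within the cluster that picked up this row. The
                    // inner worker indices include the per-cluster base, so
                    // they stay unique for TLS such as the profiler.
                    ParallelFor(ParallelismStrategy::kWithinCluster,
                                /*num_tasks=*/64, ctx, cluster_idx,
                                [&](uint64_t part, size_t worker) {
                                  ProcessPart(row, part, worker);
                                });
                  });
    }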