diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 9f12407..52bd507 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -51,7 +51,7 @@ int main(int argc, char** argv) {
   }
 
   // Instantiate model and KV Cache
-  gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
+  gcpp::ThreadingContext ctx(threading);
   gcpp::MatMulEnv env(ctx);
   gcpp::Gemma gemma(loader, inference, ctx);
   gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp
index 7f6e4c2..7800233 100644
--- a/examples/simplified_gemma/gemma.hpp
+++ b/examples/simplified_gemma/gemma.hpp
@@ -35,7 +35,7 @@ class SimplifiedGemma {
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
                   const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                   const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : ctx_(UpdateArgs(threading, inference)),
+      : ctx_(threading),
         env_(ctx_),
         gemma_(loader, inference, ctx_),
         kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
diff --git a/gemma/attention.cc b/gemma/attention.cc
index 74ea77a..7fc8e76 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -228,10 +228,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 
   {
     PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
-    const size_t pkg_idx = 0;
     // Full parallelism is helpful, SmallParallelFor is insufficient.
     ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
-                pools, pkg_idx, func);
+                pools, func);
   }
 }
 
diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc
index 76ebe1e..a6ebe30 100644
--- a/gemma/bindings/context.cc
+++ b/gemma/bindings/context.cc
@@ -103,7 +103,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
                            int max_generated_tokens)
     : inference_args(inference_args),
       threading_args(threading_args),
-      ctx(UpdateArgs(threading_args, inference_args)),
+      ctx(threading_args),
       matmul_env(ctx),
       active_conversation_name("default"),
       model(loader, inference_args, matmul_env.ctx) {
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index e3f9a19..6c75cd2 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -66,13 +66,11 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1,
 template <class Mat>
 void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
   using T = typename Mat::T;
-  const size_t pkg_idx = 0;
-  SmallParallelFor(
-      c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
-        // Cast to correct type so type deduction works.
-        Activation(activation, c1.Row(task), static_cast<T*>(nullptr),
-                   c1.Cols(), worker);
-      });
+  SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
+    // Cast to correct type so type deduction works.
+    Activation(activation, c1.Row(task), static_cast<T*>(nullptr),
+               c1.Cols(), worker);
+  });
 }
 
 template <class Mat>
@@ -80,19 +78,15 @@ HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
                                     const Mat* c2, NestedPools& pools) {
   using T = typename Mat::T;
   HWY_DASSERT(c1.SameShape(*c2));
-  const size_t pkg_idx = 0;
   if (c2 && c2->HasPtr()) {
-    SmallParallelFor(c1.Rows(), pools, pkg_idx,
-                     [&](uint64_t task, size_t worker) {
-                       Activation(activation, c1.Row(task), c2->Row(task),
-                                  c1.Cols(), worker);
-                     });
+    SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
+      Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), worker);
+    });
   } else {  // No multiplier
-    SmallParallelFor(
-        c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
-          Activation(activation, c1.Row(task), static_cast<T*>(nullptr),
-                     c1.Cols(), worker);
-        });
+    SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
+      Activation(activation, c1.Row(task), static_cast<T*>(nullptr),
+                 c1.Cols(), worker);
+    });
   }
 }
 
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index 70268c7..161e9a5 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -258,16 +258,6 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }
 };
 
-static inline ThreadingArgs UpdateArgs(const ThreadingArgs& threading_args,
-                                       const InferenceArgs& inference_args) {
-  if (inference_args.decode_qbatch_size >= 256) {
-    ThreadingArgs copy = threading_args;
-    copy.max_packages = 1;
-    return copy;
-  }
-  return threading_args;
-}
-
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
diff --git a/gemma/run.cc b/gemma/run.cc
index dd5165c..286c6ee 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -253,7 +253,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
          const InferenceArgs& inference) {
   PROFILER_ZONE("Run.misc");
 
-  ThreadingContext ctx(UpdateArgs(threading, inference));
+  ThreadingContext ctx(threading);
   MatMulEnv env(ctx);
   if (inference.verbosity >= 2) env.print_best = true;
   const Gemma gemma(loader, inference, ctx);
diff --git a/gemma/weights.cc b/gemma/weights.cc
index 2f363e6..721cfb6 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -278,9 +278,8 @@ static WeightsPtrs::Mode ChooseMode(uint64_t file_bytes,
 
   if (to_bf16 == Tristate::kDefault) {
     // Heuristic: sub-bf16 compression is not helpful if compute-bound.
-    const size_t batch_size =
-        HWY_MAX(inference.prefill_tbatch_size, inference.decode_qbatch_size);
-    to_bf16 = (batch_size >= 128) ? Tristate::kTrue : Tristate::kFalse;
+    to_bf16 = (inference.decode_qbatch_size >= 128) ? Tristate::kTrue
+                                                    : Tristate::kFalse;
   }
 
   if (map == Tristate::kDefault) {
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 6712da3..dfd8835 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -1282,14 +1282,21 @@ struct MMImpl {
     PROFILER_ZONE("MM.DoMatMul");
     static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul.PerPkg");
 
-    // Outermost loop: static NUMA-aware partition of B rows across packages.
-    args.env->parallel.ForPkg(
-        args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
-          MMZone matmul_zone;
-          matmul_zone.MaybeEnter(pkg_idx, zone_id, args);
-          const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-          MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
-        });
+    if constexpr (kMaxPackages > 1) {
+      // Outermost loop: static NUMA-aware partition of B rows across packages.
+      args.env->parallel.ForPkg(
+          args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
+            MMZone matmul_zone;
+            matmul_zone.MaybeEnter(pkg_idx, zone_id, args);
+            const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
+            MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
+          });
+    } else {
+      const size_t pkg_idx = 0;
+      HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
+      const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
+      MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
+    }
   }
 };
 
@@ -1333,7 +1340,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B,
 
   if (HWY_UNLIKELY(index < 0)) {
     env.keys.Append(key, allocator);
-    size_t max_packages = MMParallel::kMaxPackages;
+    size_t max_packages = kMaxPackages;
     // For low-batch, multiple sockets only help if binding is enabled.
     if (!allocator.ShouldBind() && M <= 4) {
       max_packages = 1;
diff --git a/ops/matmul.cc b/ops/matmul.cc
index de6d52b..c51acbd 100644
--- a/ops/matmul.cc
+++ b/ops/matmul.cc
@@ -441,7 +441,7 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
   PROFILER_ZONE("Startup.BindB");
 
   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR);
+      parallel.RangesOfNP(kMaxPackages, B.Rows(), sizeof_TC, kNR);
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
     const size_t node = parallel.Node(pkg_idx);
@@ -464,8 +464,8 @@ void BindC(MatPtr& C, MMParallel& parallel) {
   PROFILER_ZONE("Startup.BindC");
 
-  const IndexRangePartition ranges_np = parallel.RangesOfNP(
-      MMParallel::kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
+  const IndexRangePartition ranges_np =
+      parallel.RangesOfNP(kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
   bool ok = true;
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& cols_c = ranges_np.Range(pkg_idx);
diff --git a/ops/matmul.h b/ops/matmul.h
index 99477d3..69e2256 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -57,11 +57,12 @@ static constexpr size_t kMaxMR = 4;
 // the ThreadingContext to shorten call sites.
 class MMParallel {
  public:
-  static constexpr size_t kMaxPackages = 4;
-
   // `ctx` must outlive this object.
   MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
-    HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages);
+    if (ctx_.pools.NumPackages() > kMaxPackages) {
+      HWY_WARN("CPU and --max_packages allow %zu > matmul.h kMaxPackages %zu.",
+               ctx_.pools.NumPackages(), kMaxPackages);
+    }
   }
 
   Allocator& allocator() const { return ctx_.allocator; }
@@ -78,13 +79,17 @@ class MMParallel {
   // Calls `func(pkg_idx)` for each package in parallel.
   template <class Func>
   void ForPkg(const size_t max_packages, const Func& func) {
-    ctx_.pools.AllPackages().Run(
-        0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
-        [&](uint64_t task, size_t pkg_idx) {
-          HWY_DASSERT(task == pkg_idx);
-          (void)task;
-          func(pkg_idx);
-        });
+    if constexpr (kMaxPackages > 1) {
+      ctx_.pools.AllPackages().Run(
+          0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
+          [&](uint64_t task, size_t pkg_idx) {
+            HWY_DASSERT(task == pkg_idx);
+            (void)task;
+            func(pkg_idx);
+          });
+    } else {
+      func(/*pkg_idx=*/0);
+    }
   }
 
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
@@ -257,7 +262,7 @@ class MMStorage {
         partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
     // Per-package allocation so each can decompress A into its own copy.
     // Must be padded, see `DoDecompressA`.
-    parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
+    parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) {
       pkg_A_[pkg_idx].reset(new MatStorageT(
           "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
@@ -287,7 +292,7 @@ class MMStorage {
   StridedViewD Partial() const { return partial_; }
 
  private:
-  std::unique_ptr> pkg_A_[MMParallel::kMaxPackages];
+  std::unique_ptr> pkg_A_[kMaxPackages];
   MatStorageT partial_storage_;
   StridedViewD partial_;
 };
@@ -646,7 +651,9 @@ class MMKeys {
 struct MMPerKey {
   MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
            MMParallel& parallel)
-      : ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {}
+      : ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {
+    HWY_DASSERT(ranges_np.NumTasks() <= max_packages);
+  }
 
   // Only profile if enabled and the main autotuner finished (the par_a
   // autotuner is per-package and we want to avoid synchronization).
@@ -654,7 +661,7 @@ struct MMPerKey {
 
   const IndexRangePartition ranges_np;
   MMAutoTune autotune;
-  MMAutoTune autotune_par_a[MMParallel::kMaxPackages];
+  MMAutoTune autotune_par_a[kMaxPackages];
 };
 
 // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index 2ec77f5..122012e 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -264,7 +264,7 @@ void TestTiny() {
     MatMulEnv env(ctx);
     NestedPools& pools = env.ctx.pools;
 
-    if constexpr (GEMMA_DISABLE_TOPOLOGY) {
+    if constexpr (GEMMA_DISABLE_TOPOLOGY || kMaxPackages == 1) {
       if (max_packages == 2) break;  // we only have one package
     } else {
       // If less than the limit, we have already tested all num_packages.
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 0806846..688450b 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -576,8 +576,7 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights,
   HWY_DASSERT(activations.SameShape(out));
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    const size_t pkg_idx = 0;
-    SmallParallelFor(activations.Rows(), ctx.pools, pkg_idx,
+    SmallParallelFor(activations.Rows(), ctx.pools,
                      [&](uint64_t token_idx, size_t worker) {
                        RMSNorm(activations.Row(token_idx),
                                weights_t->PackedScale1(), 0,
                                out.Row(token_idx),
@@ -593,8 +592,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout,
   HWY_DASSERT(weights.Cols() == inout.Cols());
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    const size_t pkg_idx = 0;
-    SmallParallelFor(inout.Rows(), ctx.pools, pkg_idx,
+    SmallParallelFor(inout.Rows(), ctx.pools,
                      [&](uint64_t token_idx, size_t worker) {
                        RMSNormInplace(weights_t->PackedScale1(), 0,
                                       inout.Row(token_idx), inout.Cols(),
@@ -624,9 +622,8 @@ template 
 static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out,
                                       ThreadingContext& ctx) {
   HWY_DASSERT(out.SameShape(x));
-  const size_t pkg_idx = 0;
   SmallParallelFor(
-      out.Rows(), ctx.pools, pkg_idx, [&](uint64_t token_idx, size_t worker) {
+      out.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
         AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), worker);
       });
 }
diff --git a/util/basics.h b/util/basics.h
index 40545fd..30864b2 100644
--- a/util/basics.h
+++ b/util/basics.h
@@ -30,6 +30,11 @@
 
 namespace gcpp {
 
+// Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the
+// runtime `max_packages` does not exceed this. MatMul's outer per-package loop
+// is disabled if this is 1.
+constexpr size_t kMaxPackages = 1;
+
 enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
 
 static inline const char* ToString(Tristate t) {
diff --git a/util/threading.h b/util/threading.h
index 8d2c013..6c2e187 100644
--- a/util/threading.h
+++ b/util/threading.h
@@ -324,9 +324,9 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
 // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
 // over clusters of ONE package, then within each cluster.
 template <class Func>
-void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
-                 const Func& func) {
-  const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
+void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
+  // Even if there are multiple packages, we only use the first.
+  const size_t pkg_idx = 0;
 
   // If few tasks, run on a single cluster. Also avoids a bit of overhead if
   // there is only one cluster.
@@ -335,7 +335,7 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
   hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, 0);
   if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
     return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
-      func(task, pkg_base + thread);
+      func(task, thread);
     });
   }
 
@@ -346,8 +346,7 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
       ranges, all_clusters,
       [&](const IndexRange& range, const size_t cluster_idx) {
         hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, cluster_idx);
-        const size_t cluster_base =
-            pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
+        const size_t cluster_base = cluster_idx * pools.MaxWorkersPerCluster();
         cluster.Run(range.begin(), range.end(),
                     [&](uint64_t task, size_t thread) {
                       func(task, cluster_base + thread);
@@ -357,13 +356,12 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
 
 // As above, but for lightweight tasks. Uses only one pool.
 template <class Func>
-void SmallParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
-                      const Func& func) {
-  const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
+void SmallParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
+  // Even if there are multiple packages, we only use the first.
+  const size_t pkg_idx = 0;
 
-  pools.Pool(pkg_idx).Run(0, num_tasks, [&](uint64_t task, size_t thread) {
-    func(task, pkg_base + thread);
-  });
+  pools.Pool(pkg_idx).Run(
+      0, num_tasks, [&](uint64_t task, size_t thread) { func(task, thread); });
 }
 
 }  // namespace gcpp
diff --git a/util/threading_context.h b/util/threading_context.h
index 564ea90..8d14fdf 100644
--- a/util/threading_context.h
+++ b/util/threading_context.h
@@ -25,7 +25,7 @@
 // IWYU pragma: begin_exports
 #include "util/allocator.h"
 #include "util/args.h"
-#include "util/basics.h"  // Tristate
+#include "util/basics.h"  // Tristate, kMaxPackages
 #include "util/threading.h"
 #include "util/topology.h"
 // IWYU pragma: end_exports
@@ -60,8 +60,9 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
   // all available resources.
     visitor(skip_packages, "skip_packages", size_t{0},
             "Index of the first socket to use; default 0 = unlimited.", 2);
-    visitor(max_packages, "max_packages", size_t{0},
-            "Max sockets to use; default 0 = all unless large batch size.", 2);
+    visitor(max_packages, "max_packages", size_t{1},
+            "Max sockets to use; default = 1, 0 = unlimited.", 2);
+    HWY_ASSERT(max_packages <= kMaxPackages);
     visitor(skip_clusters, "skip_clusters", size_t{0},
             "Index of the first CCX to use; default 0 = unlimited.", 2);
     visitor(max_clusters, "max_clusters", size_t{0},