mirror of https://github.com/google/gemma.cpp.git
Default to disabling per-socket parallelization

weights: default to Read for small-batch (only look at qbatch, not the larger prefill tbatch)

PiperOrigin-RevId: 790787643
parent b56b2f05e4
commit 701841897b
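In short: per-socket (package) parallelism is now gated by a compile-time constant kMaxPackages in util/basics.h (default 1), the runtime --max_packages flag defaults to 1, and the weights-mode heuristic in ChooseMode consults only decode_qbatch_size rather than the larger prefill tbatch. Below is a minimal standalone sketch of how these two defaults interact; Args, WantBF16 and the free-function ForPkg are simplified, hypothetical stand-ins for the real ThreadingArgs/InferenceArgs and MMParallel::ForPkg, not the library code:

#include <cstddef>
#include <cstdio>

// Compile-time cap on CPU sockets, mirroring the new constant in util/basics.h.
constexpr size_t kMaxPackages = 1;

// Hypothetical stand-in for ThreadingArgs/InferenceArgs.
struct Args {
  size_t max_packages = 1;         // new runtime default: a single socket
  size_t decode_qbatch_size = 16;  // only qbatch is consulted now
};

// Mirrors the ChooseMode heuristic: bf16 conversion only pays off when decode
// is compute-bound (large qbatch); prefill_tbatch_size is no longer considered.
bool WantBF16(const Args& args) { return args.decode_qbatch_size >= 128; }

// Mirrors the idea in MMParallel::ForPkg / MMImpl: the per-package outer loop
// compiles away entirely when kMaxPackages == 1.
template <class Func>
void ForPkg(size_t max_packages, const Func& func) {
  if constexpr (kMaxPackages > 1) {
    for (size_t pkg_idx = 0; pkg_idx < max_packages; ++pkg_idx) func(pkg_idx);
  } else {
    (void)max_packages;
    func(/*pkg_idx=*/0);
  }
}

int main() {
  Args args;
  ForPkg(args.max_packages, [&](size_t pkg_idx) {
    std::printf("package %zu, to_bf16=%d\n", pkg_idx, WantBF16(args) ? 1 : 0);
  });
  return 0;
}

With single-socket now the default everywhere, the diff below also removes the old UpdateArgs() shim that forced max_packages = 1 only for large decode qbatch.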
@@ -51,7 +51,7 @@ int main(int argc, char** argv) {
   }

   // Instantiate model and KV Cache
-  gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
+  gcpp::ThreadingContext ctx(threading);
   gcpp::MatMulEnv env(ctx);
   gcpp::Gemma gemma(loader, inference, ctx);
   gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
@@ -35,7 +35,7 @@ class SimplifiedGemma {
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
                   const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                   const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : ctx_(UpdateArgs(threading, inference)),
+      : ctx_(threading),
         env_(ctx_),
         gemma_(loader, inference, ctx_),
         kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
@@ -228,10 +228,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,

   {
     PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
-    const size_t pkg_idx = 0;
     // Full parallelism is helpful, SmallParallelFor is insufficient.
     ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
-                pools, pkg_idx, func);
+                pools, func);
   }
 }

@@ -103,7 +103,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
                            int max_generated_tokens)
     : inference_args(inference_args),
       threading_args(threading_args),
-      ctx(UpdateArgs(threading_args, inference_args)),
+      ctx(threading_args),
       matmul_env(ctx),
       active_conversation_name("default"),
       model(loader, inference_args, matmul_env.ctx) {
@@ -66,9 +66,7 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1,
 template <class Mat>
 void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
   using T = typename Mat::T;
-  const size_t pkg_idx = 0;
-  SmallParallelFor(
-      c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
+  SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
     // Cast to correct type so type deduction works.
     Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
                c1.Cols(), worker);
@@ -80,16 +78,12 @@ HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
                                     const Mat* c2, NestedPools& pools) {
   using T = typename Mat::T;
   HWY_DASSERT(c1.SameShape(*c2));
-  const size_t pkg_idx = 0;
   if (c2 && c2->HasPtr()) {
-    SmallParallelFor(c1.Rows(), pools, pkg_idx,
-                     [&](uint64_t task, size_t worker) {
-                       Activation(activation, c1.Row(task), c2->Row(task),
-                                  c1.Cols(), worker);
+    SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
+      Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), worker);
     });
   } else {  // No multiplier
-    SmallParallelFor(
-        c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
+    SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
       Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
                  c1.Cols(), worker);
     });
@@ -258,16 +258,6 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }
 };

-static inline ThreadingArgs UpdateArgs(const ThreadingArgs& threading_args,
-                                       const InferenceArgs& inference_args) {
-  if (inference_args.decode_qbatch_size >= 256) {
-    ThreadingArgs copy = threading_args;
-    copy.max_packages = 1;
-    return copy;
-  }
-  return threading_args;
-}
-
 }  // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
@@ -253,7 +253,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
          const InferenceArgs& inference) {
   PROFILER_ZONE("Run.misc");

-  ThreadingContext ctx(UpdateArgs(threading, inference));
+  ThreadingContext ctx(threading);
   MatMulEnv env(ctx);
   if (inference.verbosity >= 2) env.print_best = true;
   const Gemma gemma(loader, inference, ctx);
@@ -278,9 +278,8 @@ static WeightsPtrs::Mode ChooseMode(uint64_t file_bytes,

   if (to_bf16 == Tristate::kDefault) {
     // Heuristic: sub-bf16 compression is not helpful if compute-bound.
-    const size_t batch_size =
-        HWY_MAX(inference.prefill_tbatch_size, inference.decode_qbatch_size);
-    to_bf16 = (batch_size >= 128) ? Tristate::kTrue : Tristate::kFalse;
+    to_bf16 = (inference.decode_qbatch_size >= 128) ? Tristate::kTrue
+                                                    : Tristate::kFalse;
   }

   if (map == Tristate::kDefault) {
@@ -1282,6 +1282,7 @@ struct MMImpl {
     PROFILER_ZONE("MM.DoMatMul");
     static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul.PerPkg");

+    if constexpr (kMaxPackages > 1) {
     // Outermost loop: static NUMA-aware partition of B rows across packages.
     args.env->parallel.ForPkg(
         args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
@@ -1290,6 +1291,12 @@ struct MMImpl {
           const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
           MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
         });
+    } else {
+      const size_t pkg_idx = 0;
+      HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
+      const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
+      MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
+    }
   }
 };

@@ -1333,7 +1340,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   if (HWY_UNLIKELY(index < 0)) {
     env.keys.Append(key, allocator);

-    size_t max_packages = MMParallel::kMaxPackages;
+    size_t max_packages = kMaxPackages;
     // For low-batch, multiple sockets only help if binding is enabled.
     if (!allocator.ShouldBind() && M <= 4) {
       max_packages = 1;
@@ -441,7 +441,7 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
   PROFILER_ZONE("Startup.BindB");

   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR);
+      parallel.RangesOfNP(kMaxPackages, B.Rows(), sizeof_TC, kNR);
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
     const size_t node = parallel.Node(pkg_idx);
@@ -464,8 +464,8 @@ void BindC(MatPtr& C, MMParallel& parallel) {

   PROFILER_ZONE("Startup.BindC");

-  const IndexRangePartition ranges_np = parallel.RangesOfNP(
-      MMParallel::kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
+  const IndexRangePartition ranges_np =
+      parallel.RangesOfNP(kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
   bool ok = true;
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& cols_c = ranges_np.Range(pkg_idx);
ops/matmul.h
@@ -57,11 +57,12 @@ static constexpr size_t kMaxMR = 4;
 // the ThreadingContext to shorten call sites.
 class MMParallel {
  public:
-  static constexpr size_t kMaxPackages = 4;
-
   // `ctx` must outlive this object.
   MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
-    HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages);
+    if (ctx_.pools.NumPackages() > kMaxPackages) {
+      HWY_WARN("CPU and --max_packages allow %zu > matmul.h kMaxPackages %zu.",
+               ctx_.pools.NumPackages(), kMaxPackages);
+    }
   }

   Allocator& allocator() const { return ctx_.allocator; }
@@ -78,6 +79,7 @@ class MMParallel {
   // Calls `func(pkg_idx)` for each package in parallel.
   template <class Func>
   void ForPkg(const size_t max_packages, const Func& func) {
+    if constexpr (kMaxPackages > 1) {
     ctx_.pools.AllPackages().Run(
         0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
         [&](uint64_t task, size_t pkg_idx) {
@@ -85,6 +87,9 @@ class MMParallel {
           (void)task;
           func(pkg_idx);
         });
+    } else {
+      func(/*pkg_idx=*/0);
+    }
   }

   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
@@ -257,7 +262,7 @@ class MMStorage {
         partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
     // Per-package allocation so each can decompress A into its own copy.
     // Must be padded, see `DoDecompressA`.
-    parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
+    parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) {
       pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
           "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));

@@ -287,7 +292,7 @@ class MMStorage {
   StridedViewD Partial() const { return partial_; }

  private:
-  std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
+  std::unique_ptr<MatStorageT<BF16>> pkg_A_[kMaxPackages];
   MatStorageT<double> partial_storage_;
   StridedViewD partial_;
 };
@@ -646,7 +651,9 @@ class MMKeys {
 struct MMPerKey {
   MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
            MMParallel& parallel)
-      : ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {}
+      : ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {
+    HWY_DASSERT(ranges_np.NumTasks() <= max_packages);
+  }

   // Only profile if enabled and the main autotuner finished (the par_a
   // autotuner is per-package and we want to avoid synchronization).
@@ -654,7 +661,7 @@ struct MMPerKey {

   const IndexRangePartition ranges_np;
   MMAutoTune<MMConfig> autotune;
-  MMAutoTune<MMParA> autotune_par_a[MMParallel::kMaxPackages];
+  MMAutoTune<MMParA> autotune_par_a[kMaxPackages];
 };

 // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
@@ -264,7 +264,7 @@ void TestTiny() {
   MatMulEnv env(ctx);
   NestedPools& pools = env.ctx.pools;

-    if constexpr (GEMMA_DISABLE_TOPOLOGY) {
+    if constexpr (GEMMA_DISABLE_TOPOLOGY || kMaxPackages == 1) {
       if (max_packages == 2) break;  // we only have one package
     } else {
       // If less than the limit, we have already tested all num_packages.
@@ -576,8 +576,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
   HWY_DASSERT(activations.SameShape(out));

   CallUpcasted(&weights, [&](const auto* weights_t) {
-    const size_t pkg_idx = 0;
-    SmallParallelFor(activations.Rows(), ctx.pools, pkg_idx,
+    SmallParallelFor(activations.Rows(), ctx.pools,
                      [&](uint64_t token_idx, size_t worker) {
                        RMSNorm(activations.Row(token_idx),
                                weights_t->PackedScale1(), 0, out.Row(token_idx),
@@ -593,8 +592,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
   HWY_DASSERT(weights.Cols() == inout.Cols());

   CallUpcasted(&weights, [&](const auto* weights_t) {
-    const size_t pkg_idx = 0;
-    SmallParallelFor(inout.Rows(), ctx.pools, pkg_idx,
+    SmallParallelFor(inout.Rows(), ctx.pools,
                      [&](uint64_t token_idx, size_t worker) {
                        RMSNormInplace(weights_t->PackedScale1(), 0,
                                       inout.Row(token_idx), inout.Cols(),
@@ -624,9 +622,8 @@ template <typename XT>
 static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
                                       ThreadingContext& ctx) {
   HWY_DASSERT(out.SameShape(x));
-  const size_t pkg_idx = 0;
   SmallParallelFor(
-      out.Rows(), ctx.pools, pkg_idx, [&](uint64_t token_idx, size_t worker) {
+      out.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
         AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), worker);
       });
 }

@@ -30,6 +30,11 @@

 namespace gcpp {

+// Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the
+// runtime `max_packages` does not exceed this. MatMul's outer per-package loop
+// is disabled if this is 1.
+constexpr size_t kMaxPackages = 1;
+
 enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };

 static inline const char* ToString(Tristate t) {
@@ -324,9 +324,9 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
 // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
 // over clusters of ONE package, then within each cluster.
 template <class Func>
-void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
-                 const Func& func) {
-  const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
+void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
+  // Even if there are multiple packages, we only use the first.
+  const size_t pkg_idx = 0;

   // If few tasks, run on a single cluster. Also avoids a bit of overhead if
   // there is only one cluster.
@@ -335,7 +335,7 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
   hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, 0);
   if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
     return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
-      func(task, pkg_base + thread);
+      func(task, thread);
     });
   }

@@ -346,8 +346,7 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
       ranges, all_clusters,
       [&](const IndexRange& range, const size_t cluster_idx) {
         hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, cluster_idx);
-        const size_t cluster_base =
-            pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
+        const size_t cluster_base = cluster_idx * pools.MaxWorkersPerCluster();
         cluster.Run(range.begin(), range.end(),
                     [&](uint64_t task, size_t thread) {
                       func(task, cluster_base + thread);
@@ -357,13 +356,12 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,

 // As above, but for lightweight tasks. Uses only one pool.
 template <class Func>
-void SmallParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
-                      const Func& func) {
-  const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
+void SmallParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
+  // Even if there are multiple packages, we only use the first.
+  const size_t pkg_idx = 0;

-  pools.Pool(pkg_idx).Run(0, num_tasks, [&](uint64_t task, size_t thread) {
-    func(task, pkg_base + thread);
-  });
+  pools.Pool(pkg_idx).Run(
+      0, num_tasks, [&](uint64_t task, size_t thread) { func(task, thread); });
 }

 }  // namespace gcpp
@@ -25,7 +25,7 @@
 // IWYU pragma: begin_exports
 #include "util/allocator.h"
 #include "util/args.h"
-#include "util/basics.h"  // Tristate
+#include "util/basics.h"  // Tristate, kMaxPackages
 #include "util/threading.h"
 #include "util/topology.h"
 // IWYU pragma: end_exports
@@ -60,8 +60,9 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
     // all available resources.
     visitor(skip_packages, "skip_packages", size_t{0},
             "Index of the first socket to use; default 0 = unlimited.", 2);
-    visitor(max_packages, "max_packages", size_t{0},
-            "Max sockets to use; default 0 = all unless large batch size.", 2);
+    visitor(max_packages, "max_packages", size_t{1},
+            "Max sockets to use; default = 1, 0 = unlimited.", 2);
+    HWY_ASSERT(max_packages <= kMaxPackages);
     visitor(skip_clusters, "skip_clusters", size_t{0},
             "Index of the first CCX to use; default 0 = unlimited.", 2);
     visitor(max_clusters, "max_clusters", size_t{0},