Add NestedPools: one per socket/cluster

Use in dot_test
app.h: add new flags and rename num_threads to max_threads
matmul: Parallelize MatMulSlow and enable spinning, more large/fewer medium test cases
PiperOrigin-RevId: 683216386
This commit is contained in:
Jan Wassenberg 2024-10-07 09:39:48 -07:00 committed by Copybara-Service
parent bd53b0f7c3
commit 2c28b18eb0
10 changed files with 660 additions and 110 deletions

View File

@ -48,6 +48,17 @@ cc_library(
], ],
) )
cc_test(
name = "threading_test",
srcs = ["util/threading_test.cc"],
deps = [
":threading",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
],
)
cc_library( cc_library(
name = "ops", name = "ops",
hdrs = [ hdrs = [
@ -306,6 +317,7 @@ cc_library(
":args", ":args",
":common", ":common",
":gemma_lib", ":gemma_lib",
":threading",
"//compression:io", "//compression:io",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:thread_pool", "@hwy//:thread_pool",

View File

@ -157,20 +157,21 @@ enable_testing()
include(GoogleTest) include(GoogleTest)
set(GEMMA_TEST_FILES set(GEMMA_TEST_FILES
backprop/backward_test.cc
backprop/backward_scalar_test.cc backprop/backward_scalar_test.cc
backprop/backward_test.cc
backprop/optimize_test.cc backprop/optimize_test.cc
compression/compress_test.cc compression/compress_test.cc
compression/distortion_test.cc compression/distortion_test.cc
compression/sfp_test.cc
compression/nuq_test.cc compression/nuq_test.cc
ops/dot_test.cc compression/sfp_test.cc
ops/ops_test.cc
ops/matmul_test.cc
ops/gemma_matvec_test.cc
evals/gemma_test.cc evals/gemma_test.cc
ops/dot_test.cc
ops/gemma_matvec_test.cc
ops/matmul_test.cc
ops/ops_test.cc
paligemma/image_test.cc paligemma/image_test.cc
paligemma/paligemma_test.cc paligemma/paligemma_test.cc
util/threading_test.cc
) )
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES) foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)

View File

@ -59,7 +59,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app) const AppArgs& app)
: pools_(app.max_clusters, app.num_threads, app.pin) { : pools_(app.max_clusters, app.max_threads, app.pin) {
InferenceArgs mutable_inference = inference; InferenceArgs mutable_inference = inference;
AbortIfInvalidArgs(mutable_inference); AbortIfInvalidArgs(mutable_inference);
LoaderArgs mutable_loader = loader; LoaderArgs mutable_loader = loader;
@ -232,6 +232,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
char cpu100[100] = "unknown"; char cpu100[100] = "unknown";
(void)hwy::platform::GetCpuString(cpu100); (void)hwy::platform::GetCpuString(cpu100);
// TODO: call TopologyString() once we have NestedPools.
const std::vector<hwy::LogicalProcessorSet>& clusters = const std::vector<hwy::LogicalProcessorSet>& clusters =
pools.CoresPerCluster(); pools.CoresPerCluster();
const size_t per_cluster = const size_t per_cluster =

View File

@ -55,7 +55,7 @@ int main(int argc, char** argv) {
} }
// Instantiate model and KV Cache // Instantiate model and KV Cache
gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin); gcpp::PerClusterPools pools(app.max_clusters, app.max_threads, app.pin);
gcpp::Gemma model = gcpp::CreateGemma(loader, pools); gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
gcpp::KVCache kv_cache = gcpp::KVCache kv_cache =
gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size); gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);

View File

@ -194,7 +194,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
// Note that num_threads is an upper bound; we also limit to the number of // Note that num_threads is an upper bound; we also limit to the number of
// detected and enabled cores. // detected and enabled cores.
PerClusterPools pools(app.max_clusters, app.num_threads, app.pin); PerClusterPools pools(app.max_clusters, app.max_threads, app.pin);
Gemma model = CreateGemma(loader, pools); Gemma model = CreateGemma(loader, pools);
KVCache kv_cache = KVCache kv_cache =

View File

@ -814,7 +814,7 @@ class DotStats {
// Forward relative error, lower is better. // Forward relative error, lower is better.
void CheckRel() const { void CheckRel() const {
ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 3.7E-3); ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 3.7E-3);
ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 0.4f); ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 1.23f);
// Compensated and Double are very accurate. // Compensated and Double are very accurate.
ASSERT_LESS(kCompensated, s_rels[kCompensated].Min(), 1E-8f); ASSERT_LESS(kCompensated, s_rels[kCompensated].Min(), 1E-8f);
@ -1096,6 +1096,7 @@ void TestAllDot() {
return; return;
} }
{ // ensure no profiler zones are active
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
constexpr size_t kMaxWorkers = 15; constexpr size_t kMaxWorkers = 15;
@ -1106,19 +1107,21 @@ void TestAllDot() {
constexpr size_t kReps = hn::AdjustedReps(40); constexpr size_t kReps = hn::AdjustedReps(40);
const size_t num = 24 * 1024; const size_t num = 24 * 1024;
PerClusterPools pools(/*max_clusters=*/1, kMaxWorkers - 1, /*pin=*/1); NestedPools pools(kMaxWorkers - 1, /*pin=*/1, BoundedSlice(0, 1),
BoundedSlice(0, 1));
RowVectorBatch<float> a(kMaxWorkers, num); RowVectorBatch<float> a(kMaxWorkers, num);
RowVectorBatch<float> b(kMaxWorkers, num); RowVectorBatch<float> b(kMaxWorkers, num);
RowVectorBatch<double> bufs(kMaxWorkers, num); RowVectorBatch<double> bufs(kMaxWorkers, num);
std::array<DotStats, kMaxWorkers> all_stats; std::array<DotStats, kMaxWorkers> all_stats;
pools.Inner(0).Run(0, kReps, [&](const uint32_t rep, size_t thread) { pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Batch(thread); float* HWY_RESTRICT pa = a.Batch(thread);
float* HWY_RESTRICT pb = b.Batch(thread); float* HWY_RESTRICT pb = b.Batch(thread);
double* HWY_RESTRICT buf = bufs.Batch(thread); double* HWY_RESTRICT buf = bufs.Batch(thread);
const PackedSpan<const float> a_span(pa, num); const PackedSpan<const float> a_span(pa, num);
DotStats& stats = all_stats[thread]; DotStats& stats = all_stats[thread];
const double cond = GenerateIllConditionedInputs(num, pa, pb, rngs[thread]); const double cond =
GenerateIllConditionedInputs(num, pa, pb, rngs[thread]);
const float dot_exact = ExactDot(pa, pb, num, buf); const float dot_exact = ExactDot(pa, pb, num, buf);
@ -1152,7 +1155,7 @@ void TestAllDot() {
stats.Print(); stats.Print();
} }
stats.Check(); stats.Check();
}
PROFILER_PRINT_RESULTS(); PROFILER_PRINT_RESULTS();
} }

View File

@ -18,6 +18,8 @@
#define HWY_DISABLED_TARGETS HWY_SCALAR #define HWY_DISABLED_TARGETS HWY_SCALAR
#endif #endif
#include "ops/matmul.h"
#include <stddef.h> #include <stddef.h>
#include <stdio.h> #include <stdio.h>
@ -162,42 +164,41 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
} }
} }
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup. template <typename MatTA, typename MatTB>
template <typename MatTA, typename MatTB, HWY_IF_NOT_BF16(MatTA)>
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
const MatTA* HWY_RESTRICT a, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b_trans, const float scale, const MatTB* HWY_RESTRICT b_trans, const float scale,
const float* add, float* HWY_RESTRICT out) { const float* HWY_RESTRICT add, MatMulEnv& env,
const hn::ScalableTag<float> df; float* HWY_RESTRICT out) {
// MatTA can be any Packed except NuqStream because it uses pointer
// arithmetic, because it is the second argument to Dot, which does not
// support a v_ofs.
static_assert(sizeof(MatTA) >= sizeof(BF16), "A matrix must be BF16/f32");
const hn::ScalableTag<float> df; // lane type is ignored
const PackedSpan<const MatTB> b_span = const PackedSpan<const MatTB> b_span =
MakeSpan(b_trans, cols_a_rows_b * cols_bc); MakeSpan(b_trans, cols_a_rows_b * cols_bc);
for (size_t i = 0; i < rows_ac; ++i) {
for (size_t j = 0; j < cols_bc; ++j) { env.Pools().Outer().Run(
out[i * cols_bc + j] = scale * Dot(df, b_span, j * cols_a_rows_b, 0, rows_ac, [&](const uint64_t i, size_t o_thread) HWY_ATTR {
a + i * cols_a_rows_b, cols_a_rows_b); hwy::ThreadPool& inner = env.Pools().Inner(o_thread);
}
if (add != nullptr) { if (add != nullptr) {
for (size_t j = 0; j < cols_bc; ++j) { inner.Run(0, cols_bc, [&](const uint64_t j, size_t i_thread) {
out[i * cols_bc + j] += add[j]; out[i * cols_bc + j] =
} scale * Dot(df, b_span, j * cols_a_rows_b,
} a + i * cols_a_rows_b, cols_a_rows_b) +
add[j];
});
} else {
inner.Run(0, cols_bc, [&](const uint64_t j, size_t i_thread) {
out[i * cols_bc + j] =
scale * Dot(df, b_span, j * cols_a_rows_b,
a + i * cols_a_rows_b, cols_a_rows_b);
});
} }
});
} }
// The above overload can handle A=f32 and any B; handle A=bf16 via Decompress.
template <typename MatTA, typename MatTB, HWY_IF_BF16(MatTA)>
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b_trans, const float scale,
const float* add, float* HWY_RESTRICT out) {
const size_t num_a = cols_a_rows_b * rows_ac;
FloatPtr a_raw = hwy::AllocateAligned<float>(num_a);
HWY_ASSERT(a_raw);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeSpan(a, num_a), 0, a_raw.get(), num_a);
MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a_raw.get(), b_trans, scale, add,
out);
}
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b, void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
size_t cols_bc, double elapsed) { size_t cols_bc, double elapsed) {
const size_t num_b = cols_a_rows_b * cols_bc; const size_t num_b = cols_a_rows_b * cols_bc;
@ -233,7 +234,7 @@ void TestMatMul(MatMulEnv& env) {
GenerateZeroMat<float, kRowsAC, kColsBC>(pool); GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
const double start_slow = hwy::platform::Now(); const double start_slow = hwy::platform::Now();
MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale, MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
kAdd ? add->data() : nullptr, c_slow->data()); kAdd ? add->data() : nullptr, env, c_slow->data());
if (want_bench) { if (want_bench) {
PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC, PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
hwy::platform::Now() - start_slow); hwy::platform::Now() - start_slow);
@ -265,22 +266,26 @@ void TestAllMatMul() {
PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1); PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1);
MatMulEnv env(pools); MatMulEnv env(pools);
pools.StartSpinning();
using F32 = float; using F32 = float;
using SFP = SfpStream; using SFP = SfpStream;
// large-scale test // large-scale test: batch_size=128 is better than 64 or 256 for SKX.
TestMatMul<64, 24576, 3072, /*kAdd=*/false, BF16, SFP>(env); TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<64, 3072, 24576, /*kAdd=*/false, BF16, SFP>(env); TestMatMul<128, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(env); TestMatMul<1, 24576, 3072, /*kAdd=*/false, F32, F32>(env);
TestMatMul<64, 3072, 24576, /*kAdd=*/false, F32, SFP>(env); TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, F32>(env);
// medium-sized square test // medium-sized square test - temporarily disabled for faster testing.
if constexpr (false) {
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env); TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env); TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env); TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env); TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env); TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env); TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env);
}
// minimal non-square test. kColsARowsB must be at least 2 vectors. // minimal non-square test. kColsARowsB must be at least 2 vectors.
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env); TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env);

View File

@ -28,6 +28,7 @@
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/gemma.h" // For CreateGemma #include "gemma/gemma.h" // For CreateGemma
#include "util/args.h" #include "util/args.h"
#include "util/threading.h"
#include "hwy/base.h" // HWY_IS_ASAN #include "hwy/base.h" // HWY_IS_ASAN
namespace gcpp { namespace gcpp {
@ -57,9 +58,15 @@ class AppArgs : public ArgsBase<AppArgs> {
int verbosity; int verbosity;
size_t num_threads; // divided among the detected clusters size_t max_threads; // divided among the detected clusters
size_t max_clusters;
int pin; // -1 = auto, 0 = no, 1 = yes int pin; // -1 = auto, 0 = no, 1 = yes
// For BoundedSlice:
size_t skip_packages;
size_t max_packages;
size_t skip_clusters;
size_t max_clusters;
size_t skip_lps;
size_t max_lps;
std::string eot_line; std::string eot_line;
@ -71,11 +78,23 @@ class AppArgs : public ArgsBase<AppArgs> {
"developer/debug info).\n Default = 1.", "developer/debug info).\n Default = 1.",
2); 2);
visitor(num_threads, "num_threads", size_t{0}, // The exact meaning is more subtle: see the comment at NestedPools ctor.
visitor(max_threads, "num_threads", size_t{0},
"Maximum number of threads to use; default 0 = unlimited.", 2); "Maximum number of threads to use; default 0 = unlimited.", 2);
visitor(max_clusters, "max_clusters", size_t{0},
"Maximum number of sockets/CCXs to use; default 0 = unlimited.", 2);
visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(skip_packages, "skip_packages", size_t{0},
"Index of the first socket to use; default 0 = unlimited.", 2);
visitor(max_packages, "max_packages", size_t{0},
"Maximum number of sockets to use; default 0 = unlimited.", 2);
visitor(skip_clusters, "skip_clusters", size_t{0},
"Index of the first CCX to use; default 0 = unlimited.", 2);
visitor(max_clusters, "max_clusters", size_t{0},
"Maximum number of CCXs to use; default 0 = unlimited.", 2);
// These are only used when CPU topology is unknown.
visitor(skip_lps, "skip_lps", size_t{0},
"Index of the first LP to use; default 0 = unlimited.", 2);
visitor(max_lps, "max_lps", size_t{0},
"Maximum number of LPs to use; default 0 = unlimited.", 2);
visitor( visitor(
eot_line, "eot_line", std::string(""), eot_line, "eot_line", std::string(""),
@ -87,6 +106,13 @@ class AppArgs : public ArgsBase<AppArgs> {
} }
}; };
static inline NestedPools CreatePools(const AppArgs& app) {
return NestedPools(app.max_threads, app.pin,
BoundedSlice(app.skip_packages, app.max_packages),
BoundedSlice(app.skip_clusters, app.max_clusters),
BoundedSlice(app.skip_lps, app.max_lps));
}
struct LoaderArgs : public ArgsBase<LoaderArgs> { struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path, LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,

View File

@ -22,7 +22,8 @@
#include <stdio.h> #include <stdio.h>
#include <algorithm> // std::sort #include <algorithm> // std::sort
#include <memory> #include <memory> // std::unique_ptr
#include <utility> // std::move
#include <vector> #include <vector>
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
@ -31,6 +32,7 @@
namespace gcpp { namespace gcpp {
// DEPRECATED, will be replaced by NestedPools once MatMul is updated.
// Owns 'inner' thread pools, one per 'cluster' (CCX or socket), plus an // Owns 'inner' thread pools, one per 'cluster' (CCX or socket), plus an
// 'outer' thread pool with one worker per cluster. // 'outer' thread pool with one worker per cluster.
// //
@ -245,6 +247,405 @@ class PerClusterPools {
std::vector<std::unique_ptr<hwy::ThreadPool>> inner_pools_; std::vector<std::unique_ptr<hwy::ThreadPool>> inner_pools_;
}; };
// A slice of a 1D integer range such as the indices of packages or clusters.
// This allows assigning them to multiple instances of our binary.
struct BoundedSlice {
// Defaults to "use all detected".
BoundedSlice(size_t skip = 0, size_t max = 0) : skip(skip), max(max) {}
// How many to skip, or equivalently, index of the first to use. It is an
// error if this is >= `detected`, because that would leave none for this
// instance to use.
size_t skip;
// Upper bound on the number to use, or zero if no limit.
size_t max;
// STL-style one past the end.
size_t End(size_t detected) const {
return (max == 0) ? detected : HWY_MIN(detected, skip + max);
}
// Number of elements in the slice.
size_t Num(size_t detected) const { return End(detected) - skip; }
template <class Func>
void ForEach(const char* name, size_t detected, const Func& func) {
if (skip >= detected) {
HWY_ABORT("Invalid skip=%zu for %s, detected=%zu", skip, name, detected);
}
for (size_t i = skip; i < End(detected); ++i) {
func(i);
}
}
};
// "LP" is a logical processor, a 0-based index passed to the OS.
using LPS = hwy::LogicalProcessorSet;
// Wraps hwy::Topology and only keeps the subset of packages and clusters
// apportioned by BoundedSlice, further limited by the OS affinity mask.
// NOTE: if topology is unknown or the OS affinity is too restrictive, we fall
// back to a single package and cluster.
class BoundedTopology {
// Sort packages/clusters by descending size so that users who only use one
// get the largest.
template <class Group>
static void SortByDescendingLPs(std::vector<Group>& groups) {
std::sort(groups.begin(), groups.end(), [](const Group& a, const Group& b) {
return a.num_lps > b.num_lps;
});
}
public:
struct Cluster {
// Simple version when topology is unknown.
explicit Cluster(size_t num_workers) : num_lps(num_workers) {
HWY_ASSERT(num_lps != 0);
}
Cluster(const std::vector<hwy::Topology::LP>& all_lps, const LPS& enabled,
size_t package_lp, const hwy::Topology::Cluster& cluster,
LPS& package_lps) {
// All first-hyperthread LPs from the cluster that are enabled and not
// already in use as the package representative.
cluster.lps.Foreach([&](size_t lp) {
if (all_lps[lp].smt == 0 && enabled.Get(lp) && lp != package_lp) {
HWY_ASSERT(!lps.Get(lp));
lps.Set(lp);
HWY_ASSERT(!package_lps.Get(lp));
package_lps.Set(lp);
}
});
num_lps = lps.Count(); // = 0 if all disabled.
}
LPS lps;
size_t num_lps;
// Set by caller to the first of `lps` if there are multiple clusters in a
// package.
size_t cluster_lp = 0;
};
struct Package {
// Simple version when topology is unknown.
explicit Package(size_t num_workers) {
package_lp = 0;
num_lps = num_workers;
clusters.push_back(Cluster(num_workers));
}
Package(size_t package_idx, const hwy::Topology& topology,
const LPS& enabled, BoundedSlice cluster_slice) {
const hwy::Topology::Package& package = topology.packages[package_idx];
package_lp = package.clusters[0].lps.First();
cluster_slice.ForEach(
"cluster", package.clusters.size(), [&](size_t cluster_idx) {
Cluster cluster(topology.lps, enabled, package_lp,
package.clusters[cluster_idx], lps);
if (HWY_LIKELY(cluster.num_lps != 0)) {
num_lps += cluster.num_lps; // before std::move
clusters.push_back(std::move(cluster));
}
});
// Note that it is possible for `clusters` to be empty if its LPs are all
// disabled. If so, the caller will ignore topology and create a single
// package and cluster.
SortByDescendingLPs(clusters);
// If there are multiple clusters, set their first LP to represent the
// cluster and mark them as unavailable for its pool.
if (clusters.size() > 1) {
for (Cluster& cluster : clusters) {
cluster.cluster_lp = cluster.lps.First();
// Nonzero because if lp == 0 were enabled, it would be used as
// `package_lp` and excluded from `cluster.lps`.
HWY_ASSERT(cluster.cluster_lp != 0);
HWY_ASSERT(cluster.cluster_lp != package_lp);
cluster.lps.Clear(cluster.cluster_lp);
}
}
}
size_t package_lp;
LPS lps;
size_t num_lps = 0;
std::vector<Cluster> clusters;
};
BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice,
BoundedSlice lp_slice) {
const bool have_threading_support = hwy::HaveThreadingSupport();
LPS enabled_lps; // LPs not disabled via OS, taskset, or numactl.
bool missing_cluster = false;
if (HWY_LIKELY(have_threading_support)) {
(void)GetThreadAffinity(enabled_lps); // failure = all disabled
// No effect if topology is unknown or `enabled_lps` is empty.
package_slice.ForEach(
"package", topology_.packages.size(), [&](size_t package_idx) {
Package package(package_idx, topology_, enabled_lps, cluster_slice);
// Skip if empty - can happen due to `enabled_lps`.
if (HWY_LIKELY(!package.clusters.empty())) {
total_lps_ += package.num_lps; // before std::move
packages_.push_back(std::move(package));
}
});
for (Package& package : packages_) {
missing_cluster = package.clusters.empty();
if (HWY_UNLIKELY(missing_cluster)) {
fprintf(
stderr,
"Warning, found no clusters for package with %zu LPs.\nWe will "
"ignore topology and assume a single package/cluster.\n",
package.num_lps);
break;
}
}
}
// Topology unknown or any package ended up empty: create a single package
// with one cluster.
if (HWY_UNLIKELY(packages_.empty() || missing_cluster)) {
// We do not bother to detect hyperthreads. Not all CPUs have two per
// core, so instead of dividing, rely on the user's `lp_slice.max`. This
// works because Linux groups LPs by HT.
const size_t num_lps = have_threading_support
? lp_slice.Num(hwy::TotalLogicalProcessors())
: 1;
packages_.clear();
packages_.push_back(Package(num_lps));
total_lps_ = num_lps;
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu", num_lps);
} else {
SortByDescendingLPs(packages_);
const hwy::Topology::Package& tpackage0 = topology_.packages[0];
HWY_ASSERT(!tpackage0.clusters.empty());
const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0];
const Package& package0 = GetPackage(0);
const Cluster& cluster0 = GetCluster(0, 0);
snprintf(topology_string_, sizeof(topology_string_),
"%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(),
tpackage0.clusters.size(), tcluster0.lps.Count(),
packages_.size(), package0.clusters.size(), cluster0.num_lps);
}
HWY_ASSERT(NumPackages() != 0);
for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
HWY_ASSERT(NumClusters(package_idx) != 0);
}
}
const char* TopologyString() const { return topology_string_; }
size_t NumPackages() const { return packages_.size(); }
const Package& GetPackage(size_t package_idx) const {
HWY_ASSERT(package_idx < NumPackages());
return packages_[package_idx];
}
Package& GetPackage(size_t package_idx) {
HWY_ASSERT(package_idx < NumPackages());
return packages_[package_idx];
}
size_t NumClusters(size_t package_idx) const {
return GetPackage(package_idx).clusters.size();
}
const Cluster& GetCluster(size_t package_idx, size_t cluster_idx) const {
const Package& package = GetPackage(package_idx);
HWY_ASSERT(cluster_idx < package.clusters.size());
return package.clusters[cluster_idx];
}
Cluster& GetCluster(size_t package_idx, size_t cluster_idx) {
Package& package = GetPackage(package_idx);
HWY_ASSERT(cluster_idx < package.clusters.size());
return package.clusters[cluster_idx];
}
// Returns number of logical processors, for allocating per-thread buffers.
size_t NumLP() const { return total_lps_; }
private:
hwy::Topology topology_;
size_t total_lps_ = 0;
std::vector<Package> packages_;
char topology_string_[96];
};
// Creates a hierarchy of thread pools according to BoundedTopology: one with a
// thread per enabled package; for each of those, one with a thread per enabled
// cluster (CCX/shared L3), and for each of those, the remaining enabled cores
// in that cluster. The cores representing each package and cluster are not
// included in the per-cluster pool because we support spin-waiting, hence
// there should be at most one thread per HW core.
//
// Useful when there are tasks which should be parallelized by workers sharing a
// cache, or on the same NUMA node. In both cases, individual pools have lower
// barrier synchronization latency than one large pool. However, to utilize all
// cores, call sites will have to use nested parallel-for loops.
class NestedPools {
public:
// Neither move nor copy.
NestedPools() = delete;
NestedPools(const NestedPools&) = delete;
NestedPools& operator=(const NestedPools&) = delete;
NestedPools(NestedPools&&) = delete;
NestedPools& operator=(NestedPools&&) = delete;
// `max_threads` is the maximum number of threads to divide among all
// clusters. It does not include the package and cluster representatives.
// This is more intuitive than a per-cluster limit for users who may not be
// aware of the CPU topology.
//
// To ensure we do not create more threads than there are HW cores, which
// would cause huge slowdowns when spinning, `BoundedSlice` imposes upper
// bounds on the number of detected packages and clusters rather than
// defining an exact amount.
//
// `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically.
NestedPools(size_t max_threads, int pin = -1,
BoundedSlice package_slice = BoundedSlice(),
BoundedSlice cluster_slice = BoundedSlice(),
BoundedSlice lp_slice = BoundedSlice())
: topology_(package_slice, cluster_slice, lp_slice) {
if (pin == -1) pin = topology_.NumLP() >= 12;
packages_.resize(topology_.NumPackages());
all_packages_ = MakePool(packages_.size());
const size_t max_workers_per_package = max_threads / packages_.size();
// Parallel to ensure we also pin the calling (main) thread.
all_packages_->Run(
0, all_packages_->NumWorkers(),
[&](uint64_t package_idx, size_t thread) {
HWY_ASSERT(package_idx == thread); // each thread has one task
packages_[package_idx] = Package(
topology_, package_idx, max_workers_per_package, pin, lp_slice);
});
}
// Spinning reduces the latency of barrier synchronization, but wastes lots
// of energy for long waits, so only do it during generation. This might
// also be unsafe in virtualized environments because we require threads to
// be running on their own core and thus responsive to the barrier
// synchronization.
void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }
hwy::ThreadPool& AllPackages() { return *all_packages_; }
hwy::ThreadPool& AllClusters(size_t package_idx) {
HWY_ASSERT(package_idx < AllPackages().NumWorkers());
return *packages_[package_idx].all_clusters;
}
hwy::ThreadPool& Cluster(size_t package_idx, size_t cluster_idx) {
HWY_ASSERT(cluster_idx < AllClusters(package_idx).NumWorkers());
return *packages_[package_idx].clusters[cluster_idx];
}
const char* TopologyString() const { return topology_.TopologyString(); }
// Returns number of logical processors, for allocating per-thread buffers.
size_t NumLP() const { return topology_.NumLP(); }
private:
// `max_or_zero` == 0 means no limit.
static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
}
// We want vectors of hwy::ThreadPool, which is unfortunately not movable,
// hence we wrap them in unique_ptr.
using PoolPtr = std::unique_ptr<hwy::ThreadPool>;
static PoolPtr MakePool(size_t num_workers) {
// `ThreadPool` expects the number of threads to create, which is one less
// than the number of workers, but avoid underflow if zero.
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
return std::make_unique<hwy::ThreadPool>(num_threads);
}
class Package {
static PoolPtr CreateClusterPool(const BoundedTopology::Cluster& cluster,
size_t max_cluster_workers, int pin,
BoundedSlice lp_slice) {
PoolPtr pool =
MakePool(CapIfNonZero(cluster.num_lps, max_cluster_workers));
if (!pin) return pool;
// Else: pin all new threads AND the calling thread from `all_clusters`.
// We know the topology: pin to this cluster's cores, including the
// calling thread from `all_clusters`.
if (cluster.lps.Any()) {
std::vector<size_t> lps;
lps.reserve(cluster.num_lps);
cluster.lps.Foreach([&lps](size_t lp) { lps.push_back(lp); });
pool->Run(0, pool->NumWorkers(), [&lps](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
hwy::PinThreadToLogicalProcessor(lps[task]);
});
} else {
// Pin to consecutive LPs.
pool->Run(0, pool->NumWorkers(),
[lp_slice](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
hwy::PinThreadToLogicalProcessor(lp_slice.skip + thread);
});
}
return pool;
}
public:
Package() = default; // for vector
Package(const BoundedTopology& topology, size_t package_idx,
size_t max_workers_per_package, int pin, BoundedSlice lp_slice) {
clusters.resize(topology.NumClusters(package_idx));
const size_t max_workers_per_cluster =
max_workers_per_package / clusters.size();
all_clusters = MakePool(clusters.size());
// Parallel so we also pin the calling thread from `all_packages_`.
all_clusters->Run(
0, all_clusters->NumWorkers(),
[&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& cluster =
topology.GetCluster(package_idx, cluster_idx);
clusters[cluster_idx] = CreateClusterPool(
cluster, max_workers_per_cluster, pin, lp_slice);
});
}
std::vector<PoolPtr> clusters;
PoolPtr all_clusters;
};
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
all_packages_->SetWaitMode(wait_mode);
for (Package& package : packages_) {
package.all_clusters->SetWaitMode(wait_mode);
for (PoolPtr& cluster : package.clusters) {
cluster->SetWaitMode(wait_mode);
}
}
}
BoundedTopology topology_;
std::vector<Package> packages_;
PoolPtr all_packages_;
};
static inline NestedPools CreateSinglePool(size_t max_threads, int pin = -1) {
const BoundedSlice one(0, 1);
return NestedPools(max_threads, pin, one, one);
}
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_

101
util/threading_test.cc Normal file
View File

@ -0,0 +1,101 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util/threading.h"
#include <stddef.h>
#include <stdio.h>
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp {
namespace {
using ::testing::ElementsAre;
TEST(ThreadingTest, TestBoundedSlice) {
const char* name = "test";
// No args = no limit.
{
BoundedSlice slice;
std::vector<size_t> expected;
slice.ForEach(name, 10, [&](size_t i) { expected.push_back(i); });
EXPECT_EQ(10, slice.Num(10));
EXPECT_THAT(expected, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9));
}
// One arg: skip first N
{
BoundedSlice slice(3);
std::vector<size_t> expected;
slice.ForEach(name, 9, [&](size_t i) { expected.push_back(i); });
EXPECT_EQ(6, slice.Num(9));
EXPECT_THAT(expected, ElementsAre(3, 4, 5, 6, 7, 8));
}
// Both args: skip first N, then use at most M
{
BoundedSlice slice(3, 2);
std::vector<size_t> expected;
slice.ForEach(name, 9, [&](size_t i) { expected.push_back(i); });
EXPECT_EQ(2, slice.Num(9));
EXPECT_THAT(expected, ElementsAre(3, 4));
}
// Both args, but `max > detected - skip`: fewer than limit. Note that
// `skip >= detected` is an error.
{
BoundedSlice slice(3, 2);
std::vector<size_t> expected;
slice.ForEach(name, 4, [&](size_t i) { expected.push_back(i); });
EXPECT_EQ(1, slice.Num(4));
EXPECT_THAT(expected, ElementsAre(3));
}
}
TEST(ThreadingTest, TestBoundedTopology) {
const BoundedSlice all;
const BoundedSlice one(0, 1);
// All
{
BoundedTopology topology(all, all, all);
fprintf(stderr, "%s\n", topology.TopologyString());
ASSERT_NE(0, topology.NumPackages());
ASSERT_NE(0, topology.NumClusters(0));
}
// Max one package
{
BoundedTopology topology(one, all, all);
fprintf(stderr, "%s\n", topology.TopologyString());
ASSERT_EQ(1, topology.NumPackages());
ASSERT_NE(0, topology.NumClusters(0));
}
// Max one cluster
{
BoundedTopology topology(all, one, all);
fprintf(stderr, "%s\n", topology.TopologyString());
ASSERT_NE(0, topology.NumPackages());
ASSERT_EQ(1, topology.NumClusters(0));
}
}
} // namespace
} // namespace gcpp