mirror of https://github.com/google/gemma.cpp.git
Add NestedPools: one per socket/cluster
Use in dot_test app.h: add new flags and rename num_threads to max_threads matmul: Parallelize MatMulSlow and enable spinning, more large/fewer medium test cases PiperOrigin-RevId: 683216386
This commit is contained in:
parent
bd53b0f7c3
commit
2c28b18eb0
12
BUILD.bazel
12
BUILD.bazel
|
|
@ -48,6 +48,17 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "threading_test",
|
||||
srcs = ["util/threading_test.cc"],
|
||||
deps = [
|
||||
":threading",
|
||||
"@googletest//:gtest_main",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ops",
|
||||
hdrs = [
|
||||
|
|
@ -306,6 +317,7 @@ cc_library(
|
|||
":args",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":threading",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:thread_pool",
|
||||
|
|
|
|||
|
|
@ -157,20 +157,21 @@ enable_testing()
|
|||
include(GoogleTest)
|
||||
|
||||
set(GEMMA_TEST_FILES
|
||||
backprop/backward_test.cc
|
||||
backprop/backward_scalar_test.cc
|
||||
backprop/backward_test.cc
|
||||
backprop/optimize_test.cc
|
||||
compression/compress_test.cc
|
||||
compression/distortion_test.cc
|
||||
compression/sfp_test.cc
|
||||
compression/nuq_test.cc
|
||||
ops/dot_test.cc
|
||||
ops/ops_test.cc
|
||||
ops/matmul_test.cc
|
||||
ops/gemma_matvec_test.cc
|
||||
compression/sfp_test.cc
|
||||
evals/gemma_test.cc
|
||||
ops/dot_test.cc
|
||||
ops/gemma_matvec_test.cc
|
||||
ops/matmul_test.cc
|
||||
ops/ops_test.cc
|
||||
paligemma/image_test.cc
|
||||
paligemma/paligemma_test.cc
|
||||
util/threading_test.cc
|
||||
)
|
||||
|
||||
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
|
|||
|
||||
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
const AppArgs& app)
|
||||
: pools_(app.max_clusters, app.num_threads, app.pin) {
|
||||
: pools_(app.max_clusters, app.max_threads, app.pin) {
|
||||
InferenceArgs mutable_inference = inference;
|
||||
AbortIfInvalidArgs(mutable_inference);
|
||||
LoaderArgs mutable_loader = loader;
|
||||
|
|
@ -232,6 +232,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
|||
char cpu100[100] = "unknown";
|
||||
(void)hwy::platform::GetCpuString(cpu100);
|
||||
|
||||
// TODO: call TopologyString() once we have NestedPools.
|
||||
const std::vector<hwy::LogicalProcessorSet>& clusters =
|
||||
pools.CoresPerCluster();
|
||||
const size_t per_cluster =
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
// Instantiate model and KV Cache
|
||||
gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
|
||||
gcpp::PerClusterPools pools(app.max_clusters, app.max_threads, app.pin);
|
||||
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
|
||||
gcpp::KVCache kv_cache =
|
||||
gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
|
||||
// Note that num_threads is an upper bound; we also limit to the number of
|
||||
// detected and enabled cores.
|
||||
PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
|
||||
PerClusterPools pools(app.max_clusters, app.max_threads, app.pin);
|
||||
|
||||
Gemma model = CreateGemma(loader, pools);
|
||||
KVCache kv_cache =
|
||||
|
|
|
|||
109
ops/dot_test.cc
109
ops/dot_test.cc
|
|
@ -814,7 +814,7 @@ class DotStats {
|
|||
// Forward relative error, lower is better.
|
||||
void CheckRel() const {
|
||||
ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 3.7E-3);
|
||||
ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 0.4f);
|
||||
ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 1.23f);
|
||||
|
||||
// Compensated and Double are very accurate.
|
||||
ASSERT_LESS(kCompensated, s_rels[kCompensated].Min(), 1E-8f);
|
||||
|
|
@ -1096,63 +1096,66 @@ void TestAllDot() {
|
|||
return;
|
||||
}
|
||||
|
||||
const hn::ScalableTag<float> df;
|
||||
{ // ensure no profiler zones are active
|
||||
const hn::ScalableTag<float> df;
|
||||
|
||||
constexpr size_t kMaxWorkers = 15;
|
||||
std::mt19937 rngs[kMaxWorkers];
|
||||
for (size_t i = 0; i < kMaxWorkers; ++i) {
|
||||
rngs[i].seed(12345 + 65537 * i);
|
||||
}
|
||||
|
||||
constexpr size_t kReps = hn::AdjustedReps(40);
|
||||
const size_t num = 24 * 1024;
|
||||
PerClusterPools pools(/*max_clusters=*/1, kMaxWorkers - 1, /*pin=*/1);
|
||||
RowVectorBatch<float> a(kMaxWorkers, num);
|
||||
RowVectorBatch<float> b(kMaxWorkers, num);
|
||||
RowVectorBatch<double> bufs(kMaxWorkers, num);
|
||||
std::array<DotStats, kMaxWorkers> all_stats;
|
||||
|
||||
pools.Inner(0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
|
||||
float* HWY_RESTRICT pa = a.Batch(thread);
|
||||
float* HWY_RESTRICT pb = b.Batch(thread);
|
||||
double* HWY_RESTRICT buf = bufs.Batch(thread);
|
||||
const PackedSpan<const float> a_span(pa, num);
|
||||
DotStats& stats = all_stats[thread];
|
||||
const double cond = GenerateIllConditionedInputs(num, pa, pb, rngs[thread]);
|
||||
|
||||
const float dot_exact = ExactDot(pa, pb, num, buf);
|
||||
|
||||
float dots[kVariants] = {};
|
||||
double times[kVariants] = {};
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
constexpr size_t kTimeReps = hn::AdjustedReps(10);
|
||||
std::array<double, kTimeReps> elapsed;
|
||||
for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
|
||||
const double start = hwy::platform::Now();
|
||||
dots[variant] += CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
|
||||
hwy::PreventElision(*pa);
|
||||
elapsed[time_rep] = hwy::platform::Now() - start;
|
||||
}
|
||||
dots[variant] /= kTimeReps;
|
||||
times[variant] = TrimmedMean(elapsed.data(), kTimeReps);
|
||||
constexpr size_t kMaxWorkers = 15;
|
||||
std::mt19937 rngs[kMaxWorkers];
|
||||
for (size_t i = 0; i < kMaxWorkers; ++i) {
|
||||
rngs[i].seed(12345 + 65537 * i);
|
||||
}
|
||||
|
||||
stats.NotifyTimes(times);
|
||||
stats.NotifyRep(num, cond, dot_exact, dots);
|
||||
stats.NotifyRatios();
|
||||
});
|
||||
constexpr size_t kReps = hn::AdjustedReps(40);
|
||||
const size_t num = 24 * 1024;
|
||||
NestedPools pools(kMaxWorkers - 1, /*pin=*/1, BoundedSlice(0, 1),
|
||||
BoundedSlice(0, 1));
|
||||
RowVectorBatch<float> a(kMaxWorkers, num);
|
||||
RowVectorBatch<float> b(kMaxWorkers, num);
|
||||
RowVectorBatch<double> bufs(kMaxWorkers, num);
|
||||
std::array<DotStats, kMaxWorkers> all_stats;
|
||||
|
||||
DotStats& stats = all_stats[0];
|
||||
for (size_t i = 1; i < kMaxWorkers; ++i) {
|
||||
stats.Assimilate(all_stats[i]);
|
||||
}
|
||||
static bool once = true;
|
||||
if (once) {
|
||||
once = false;
|
||||
stats.Print();
|
||||
}
|
||||
stats.Check();
|
||||
pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
|
||||
float* HWY_RESTRICT pa = a.Batch(thread);
|
||||
float* HWY_RESTRICT pb = b.Batch(thread);
|
||||
double* HWY_RESTRICT buf = bufs.Batch(thread);
|
||||
const PackedSpan<const float> a_span(pa, num);
|
||||
DotStats& stats = all_stats[thread];
|
||||
const double cond =
|
||||
GenerateIllConditionedInputs(num, pa, pb, rngs[thread]);
|
||||
|
||||
const float dot_exact = ExactDot(pa, pb, num, buf);
|
||||
|
||||
float dots[kVariants] = {};
|
||||
double times[kVariants] = {};
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
constexpr size_t kTimeReps = hn::AdjustedReps(10);
|
||||
std::array<double, kTimeReps> elapsed;
|
||||
for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
|
||||
const double start = hwy::platform::Now();
|
||||
dots[variant] += CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
|
||||
hwy::PreventElision(*pa);
|
||||
elapsed[time_rep] = hwy::platform::Now() - start;
|
||||
}
|
||||
dots[variant] /= kTimeReps;
|
||||
times[variant] = TrimmedMean(elapsed.data(), kTimeReps);
|
||||
}
|
||||
|
||||
stats.NotifyTimes(times);
|
||||
stats.NotifyRep(num, cond, dot_exact, dots);
|
||||
stats.NotifyRatios();
|
||||
});
|
||||
|
||||
DotStats& stats = all_stats[0];
|
||||
for (size_t i = 1; i < kMaxWorkers; ++i) {
|
||||
stats.Assimilate(all_stats[i]);
|
||||
}
|
||||
static bool once = true;
|
||||
if (once) {
|
||||
once = false;
|
||||
stats.Print();
|
||||
}
|
||||
stats.Check();
|
||||
}
|
||||
PROFILER_PRINT_RESULTS();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@
|
|||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#endif
|
||||
|
||||
#include "ops/matmul.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
|
|
@ -162,42 +164,41 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
|
|||
}
|
||||
}
|
||||
|
||||
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup.
|
||||
template <typename MatTA, typename MatTB, HWY_IF_NOT_BF16(MatTA)>
|
||||
template <typename MatTA, typename MatTB>
|
||||
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
|
||||
const MatTA* HWY_RESTRICT a,
|
||||
const MatTB* HWY_RESTRICT b_trans, const float scale,
|
||||
const float* add, float* HWY_RESTRICT out) {
|
||||
const hn::ScalableTag<float> df;
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
float* HWY_RESTRICT out) {
|
||||
// MatTA can be any Packed except NuqStream because it uses pointer
|
||||
// arithmetic, because it is the second argument to Dot, which does not
|
||||
// support a v_ofs.
|
||||
static_assert(sizeof(MatTA) >= sizeof(BF16), "A matrix must be BF16/f32");
|
||||
|
||||
const hn::ScalableTag<float> df; // lane type is ignored
|
||||
const PackedSpan<const MatTB> b_span =
|
||||
MakeSpan(b_trans, cols_a_rows_b * cols_bc);
|
||||
for (size_t i = 0; i < rows_ac; ++i) {
|
||||
for (size_t j = 0; j < cols_bc; ++j) {
|
||||
out[i * cols_bc + j] = scale * Dot(df, b_span, j * cols_a_rows_b,
|
||||
a + i * cols_a_rows_b, cols_a_rows_b);
|
||||
}
|
||||
if (add != nullptr) {
|
||||
for (size_t j = 0; j < cols_bc; ++j) {
|
||||
out[i * cols_bc + j] += add[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
env.Pools().Outer().Run(
|
||||
0, rows_ac, [&](const uint64_t i, size_t o_thread) HWY_ATTR {
|
||||
hwy::ThreadPool& inner = env.Pools().Inner(o_thread);
|
||||
if (add != nullptr) {
|
||||
inner.Run(0, cols_bc, [&](const uint64_t j, size_t i_thread) {
|
||||
out[i * cols_bc + j] =
|
||||
scale * Dot(df, b_span, j * cols_a_rows_b,
|
||||
a + i * cols_a_rows_b, cols_a_rows_b) +
|
||||
add[j];
|
||||
});
|
||||
} else {
|
||||
inner.Run(0, cols_bc, [&](const uint64_t j, size_t i_thread) {
|
||||
out[i * cols_bc + j] =
|
||||
scale * Dot(df, b_span, j * cols_a_rows_b,
|
||||
a + i * cols_a_rows_b, cols_a_rows_b);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// The above overload can handle A=f32 and any B; handle A=bf16 via Decompress.
|
||||
template <typename MatTA, typename MatTB, HWY_IF_BF16(MatTA)>
|
||||
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
|
||||
const MatTA* HWY_RESTRICT a,
|
||||
const MatTB* HWY_RESTRICT b_trans, const float scale,
|
||||
const float* add, float* HWY_RESTRICT out) {
|
||||
const size_t num_a = cols_a_rows_b * rows_ac;
|
||||
FloatPtr a_raw = hwy::AllocateAligned<float>(num_a);
|
||||
HWY_ASSERT(a_raw);
|
||||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(df, MakeSpan(a, num_a), 0, a_raw.get(), num_a);
|
||||
MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a_raw.get(), b_trans, scale, add,
|
||||
out);
|
||||
}
|
||||
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
|
||||
size_t cols_bc, double elapsed) {
|
||||
const size_t num_b = cols_a_rows_b * cols_bc;
|
||||
|
|
@ -233,7 +234,7 @@ void TestMatMul(MatMulEnv& env) {
|
|||
GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
|
||||
const double start_slow = hwy::platform::Now();
|
||||
MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
|
||||
kAdd ? add->data() : nullptr, c_slow->data());
|
||||
kAdd ? add->data() : nullptr, env, c_slow->data());
|
||||
if (want_bench) {
|
||||
PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
|
||||
hwy::platform::Now() - start_slow);
|
||||
|
|
@ -265,22 +266,26 @@ void TestAllMatMul() {
|
|||
|
||||
PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1);
|
||||
MatMulEnv env(pools);
|
||||
pools.StartSpinning();
|
||||
|
||||
using F32 = float;
|
||||
using SFP = SfpStream;
|
||||
|
||||
// large-scale test
|
||||
TestMatMul<64, 24576, 3072, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<64, 3072, 24576, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<64, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
|
||||
// large-scale test: batch_size=128 is better than 64 or 256 for SKX.
|
||||
TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<128, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<1, 24576, 3072, /*kAdd=*/false, F32, F32>(env);
|
||||
TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, F32>(env);
|
||||
|
||||
// medium-sized square test
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env);
|
||||
// medium-sized square test - temporarily disabled for faster testing.
|
||||
if constexpr (false) {
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env);
|
||||
}
|
||||
|
||||
// minimal non-square test. kColsARowsB must be at least 2 vectors.
|
||||
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
|
|
|
|||
36
util/app.h
36
util/app.h
|
|
@ -28,6 +28,7 @@
|
|||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // For CreateGemma
|
||||
#include "util/args.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h" // HWY_IS_ASAN
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -57,9 +58,15 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
|
||||
int verbosity;
|
||||
|
||||
size_t num_threads; // divided among the detected clusters
|
||||
size_t max_clusters;
|
||||
size_t max_threads; // divided among the detected clusters
|
||||
int pin; // -1 = auto, 0 = no, 1 = yes
|
||||
// For BoundedSlice:
|
||||
size_t skip_packages;
|
||||
size_t max_packages;
|
||||
size_t skip_clusters;
|
||||
size_t max_clusters;
|
||||
size_t skip_lps;
|
||||
size_t max_lps;
|
||||
|
||||
std::string eot_line;
|
||||
|
||||
|
|
@ -71,11 +78,23 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
"developer/debug info).\n Default = 1.",
|
||||
2);
|
||||
|
||||
visitor(num_threads, "num_threads", size_t{0},
|
||||
// The exact meaning is more subtle: see the comment at NestedPools ctor.
|
||||
visitor(max_threads, "num_threads", size_t{0},
|
||||
"Maximum number of threads to use; default 0 = unlimited.", 2);
|
||||
visitor(max_clusters, "max_clusters", size_t{0},
|
||||
"Maximum number of sockets/CCXs to use; default 0 = unlimited.", 2);
|
||||
visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
visitor(skip_packages, "skip_packages", size_t{0},
|
||||
"Index of the first socket to use; default 0 = unlimited.", 2);
|
||||
visitor(max_packages, "max_packages", size_t{0},
|
||||
"Maximum number of sockets to use; default 0 = unlimited.", 2);
|
||||
visitor(skip_clusters, "skip_clusters", size_t{0},
|
||||
"Index of the first CCX to use; default 0 = unlimited.", 2);
|
||||
visitor(max_clusters, "max_clusters", size_t{0},
|
||||
"Maximum number of CCXs to use; default 0 = unlimited.", 2);
|
||||
// These are only used when CPU topology is unknown.
|
||||
visitor(skip_lps, "skip_lps", size_t{0},
|
||||
"Index of the first LP to use; default 0 = unlimited.", 2);
|
||||
visitor(max_lps, "max_lps", size_t{0},
|
||||
"Maximum number of LPs to use; default 0 = unlimited.", 2);
|
||||
|
||||
visitor(
|
||||
eot_line, "eot_line", std::string(""),
|
||||
|
|
@ -87,6 +106,13 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
}
|
||||
};
|
||||
|
||||
static inline NestedPools CreatePools(const AppArgs& app) {
|
||||
return NestedPools(app.max_threads, app.pin,
|
||||
BoundedSlice(app.skip_packages, app.max_packages),
|
||||
BoundedSlice(app.skip_clusters, app.max_clusters),
|
||||
BoundedSlice(app.skip_lps, app.max_lps));
|
||||
}
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
||||
|
|
|
|||
403
util/threading.h
403
util/threading.h
|
|
@ -22,7 +22,8 @@
|
|||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::sort
|
||||
#include <memory>
|
||||
#include <memory> // std::unique_ptr
|
||||
#include <utility> // std::move
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
|
|
@ -31,6 +32,7 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// DEPRECATED, will be replaced by NestedPools once MatMul is updated.
|
||||
// Owns 'inner' thread pools, one per 'cluster' (CCX or socket), plus an
|
||||
// 'outer' thread pool with one worker per cluster.
|
||||
//
|
||||
|
|
@ -245,6 +247,405 @@ class PerClusterPools {
|
|||
std::vector<std::unique_ptr<hwy::ThreadPool>> inner_pools_;
|
||||
};
|
||||
|
||||
// A slice of a 1D integer range such as the indices of packages or clusters.
|
||||
// This allows assigning them to multiple instances of our binary.
|
||||
struct BoundedSlice {
|
||||
// Defaults to "use all detected".
|
||||
BoundedSlice(size_t skip = 0, size_t max = 0) : skip(skip), max(max) {}
|
||||
|
||||
// How many to skip, or equivalently, index of the first to use. It is an
|
||||
// error if this is >= `detected`, because that would leave none for this
|
||||
// instance to use.
|
||||
size_t skip;
|
||||
|
||||
// Upper bound on the number to use, or zero if no limit.
|
||||
size_t max;
|
||||
|
||||
// STL-style one past the end.
|
||||
size_t End(size_t detected) const {
|
||||
return (max == 0) ? detected : HWY_MIN(detected, skip + max);
|
||||
}
|
||||
|
||||
// Number of elements in the slice.
|
||||
size_t Num(size_t detected) const { return End(detected) - skip; }
|
||||
|
||||
template <class Func>
|
||||
void ForEach(const char* name, size_t detected, const Func& func) {
|
||||
if (skip >= detected) {
|
||||
HWY_ABORT("Invalid skip=%zu for %s, detected=%zu", skip, name, detected);
|
||||
}
|
||||
for (size_t i = skip; i < End(detected); ++i) {
|
||||
func(i);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// "LP" is a logical processor, a 0-based index passed to the OS.
|
||||
using LPS = hwy::LogicalProcessorSet;
|
||||
|
||||
// Wraps hwy::Topology and only keeps the subset of packages and clusters
|
||||
// apportioned by BoundedSlice, further limited by the OS affinity mask.
|
||||
// NOTE: if topology is unknown or the OS affinity is too restrictive, we fall
|
||||
// back to a single package and cluster.
|
||||
class BoundedTopology {
|
||||
// Sort packages/clusters by descending size so that users who only use one
|
||||
// get the largest.
|
||||
template <class Group>
|
||||
static void SortByDescendingLPs(std::vector<Group>& groups) {
|
||||
std::sort(groups.begin(), groups.end(), [](const Group& a, const Group& b) {
|
||||
return a.num_lps > b.num_lps;
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
struct Cluster {
|
||||
// Simple version when topology is unknown.
|
||||
explicit Cluster(size_t num_workers) : num_lps(num_workers) {
|
||||
HWY_ASSERT(num_lps != 0);
|
||||
}
|
||||
|
||||
Cluster(const std::vector<hwy::Topology::LP>& all_lps, const LPS& enabled,
|
||||
size_t package_lp, const hwy::Topology::Cluster& cluster,
|
||||
LPS& package_lps) {
|
||||
// All first-hyperthread LPs from the cluster that are enabled and not
|
||||
// already in use as the package representative.
|
||||
cluster.lps.Foreach([&](size_t lp) {
|
||||
if (all_lps[lp].smt == 0 && enabled.Get(lp) && lp != package_lp) {
|
||||
HWY_ASSERT(!lps.Get(lp));
|
||||
lps.Set(lp);
|
||||
HWY_ASSERT(!package_lps.Get(lp));
|
||||
package_lps.Set(lp);
|
||||
}
|
||||
});
|
||||
num_lps = lps.Count(); // = 0 if all disabled.
|
||||
}
|
||||
|
||||
LPS lps;
|
||||
size_t num_lps;
|
||||
// Set by caller to the first of `lps` if there are multiple clusters in a
|
||||
// package.
|
||||
size_t cluster_lp = 0;
|
||||
};
|
||||
|
||||
struct Package {
|
||||
// Simple version when topology is unknown.
|
||||
explicit Package(size_t num_workers) {
|
||||
package_lp = 0;
|
||||
num_lps = num_workers;
|
||||
clusters.push_back(Cluster(num_workers));
|
||||
}
|
||||
|
||||
Package(size_t package_idx, const hwy::Topology& topology,
|
||||
const LPS& enabled, BoundedSlice cluster_slice) {
|
||||
const hwy::Topology::Package& package = topology.packages[package_idx];
|
||||
package_lp = package.clusters[0].lps.First();
|
||||
cluster_slice.ForEach(
|
||||
"cluster", package.clusters.size(), [&](size_t cluster_idx) {
|
||||
Cluster cluster(topology.lps, enabled, package_lp,
|
||||
package.clusters[cluster_idx], lps);
|
||||
if (HWY_LIKELY(cluster.num_lps != 0)) {
|
||||
num_lps += cluster.num_lps; // before std::move
|
||||
clusters.push_back(std::move(cluster));
|
||||
}
|
||||
});
|
||||
|
||||
// Note that it is possible for `clusters` to be empty if its LPs are all
|
||||
// disabled. If so, the caller will ignore topology and create a single
|
||||
// package and cluster.
|
||||
|
||||
SortByDescendingLPs(clusters);
|
||||
|
||||
// If there are multiple clusters, set their first LP to represent the
|
||||
// cluster and mark them as unavailable for its pool.
|
||||
if (clusters.size() > 1) {
|
||||
for (Cluster& cluster : clusters) {
|
||||
cluster.cluster_lp = cluster.lps.First();
|
||||
// Nonzero because if lp == 0 were enabled, it would be used as
|
||||
// `package_lp` and excluded from `cluster.lps`.
|
||||
HWY_ASSERT(cluster.cluster_lp != 0);
|
||||
HWY_ASSERT(cluster.cluster_lp != package_lp);
|
||||
cluster.lps.Clear(cluster.cluster_lp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t package_lp;
|
||||
LPS lps;
|
||||
size_t num_lps = 0;
|
||||
std::vector<Cluster> clusters;
|
||||
};
|
||||
|
||||
BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice,
|
||||
BoundedSlice lp_slice) {
|
||||
const bool have_threading_support = hwy::HaveThreadingSupport();
|
||||
LPS enabled_lps; // LPs not disabled via OS, taskset, or numactl.
|
||||
bool missing_cluster = false;
|
||||
|
||||
if (HWY_LIKELY(have_threading_support)) {
|
||||
(void)GetThreadAffinity(enabled_lps); // failure = all disabled
|
||||
|
||||
// No effect if topology is unknown or `enabled_lps` is empty.
|
||||
package_slice.ForEach(
|
||||
"package", topology_.packages.size(), [&](size_t package_idx) {
|
||||
Package package(package_idx, topology_, enabled_lps, cluster_slice);
|
||||
// Skip if empty - can happen due to `enabled_lps`.
|
||||
if (HWY_LIKELY(!package.clusters.empty())) {
|
||||
total_lps_ += package.num_lps; // before std::move
|
||||
packages_.push_back(std::move(package));
|
||||
}
|
||||
});
|
||||
|
||||
for (Package& package : packages_) {
|
||||
missing_cluster = package.clusters.empty();
|
||||
if (HWY_UNLIKELY(missing_cluster)) {
|
||||
fprintf(
|
||||
stderr,
|
||||
"Warning, found no clusters for package with %zu LPs.\nWe will "
|
||||
"ignore topology and assume a single package/cluster.\n",
|
||||
package.num_lps);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Topology unknown or any package ended up empty: create a single package
|
||||
// with one cluster.
|
||||
if (HWY_UNLIKELY(packages_.empty() || missing_cluster)) {
|
||||
// We do not bother to detect hyperthreads. Not all CPUs have two per
|
||||
// core, so instead of dividing, rely on the user's `lp_slice.max`. This
|
||||
// works because Linux groups LPs by HT.
|
||||
const size_t num_lps = have_threading_support
|
||||
? lp_slice.Num(hwy::TotalLogicalProcessors())
|
||||
: 1;
|
||||
packages_.clear();
|
||||
packages_.push_back(Package(num_lps));
|
||||
total_lps_ = num_lps;
|
||||
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu", num_lps);
|
||||
} else {
|
||||
SortByDescendingLPs(packages_);
|
||||
|
||||
const hwy::Topology::Package& tpackage0 = topology_.packages[0];
|
||||
HWY_ASSERT(!tpackage0.clusters.empty());
|
||||
const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0];
|
||||
const Package& package0 = GetPackage(0);
|
||||
const Cluster& cluster0 = GetCluster(0, 0);
|
||||
snprintf(topology_string_, sizeof(topology_string_),
|
||||
"%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(),
|
||||
tpackage0.clusters.size(), tcluster0.lps.Count(),
|
||||
packages_.size(), package0.clusters.size(), cluster0.num_lps);
|
||||
}
|
||||
|
||||
HWY_ASSERT(NumPackages() != 0);
|
||||
for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
|
||||
HWY_ASSERT(NumClusters(package_idx) != 0);
|
||||
}
|
||||
}
|
||||
|
||||
const char* TopologyString() const { return topology_string_; }
|
||||
|
||||
size_t NumPackages() const { return packages_.size(); }
|
||||
const Package& GetPackage(size_t package_idx) const {
|
||||
HWY_ASSERT(package_idx < NumPackages());
|
||||
return packages_[package_idx];
|
||||
}
|
||||
Package& GetPackage(size_t package_idx) {
|
||||
HWY_ASSERT(package_idx < NumPackages());
|
||||
return packages_[package_idx];
|
||||
}
|
||||
|
||||
size_t NumClusters(size_t package_idx) const {
|
||||
return GetPackage(package_idx).clusters.size();
|
||||
}
|
||||
const Cluster& GetCluster(size_t package_idx, size_t cluster_idx) const {
|
||||
const Package& package = GetPackage(package_idx);
|
||||
HWY_ASSERT(cluster_idx < package.clusters.size());
|
||||
return package.clusters[cluster_idx];
|
||||
}
|
||||
Cluster& GetCluster(size_t package_idx, size_t cluster_idx) {
|
||||
Package& package = GetPackage(package_idx);
|
||||
HWY_ASSERT(cluster_idx < package.clusters.size());
|
||||
return package.clusters[cluster_idx];
|
||||
}
|
||||
|
||||
// Returns number of logical processors, for allocating per-thread buffers.
|
||||
size_t NumLP() const { return total_lps_; }
|
||||
|
||||
private:
|
||||
hwy::Topology topology_;
|
||||
size_t total_lps_ = 0;
|
||||
std::vector<Package> packages_;
|
||||
char topology_string_[96];
|
||||
};
|
||||
|
||||
// Creates a hierarchy of thread pools according to BoundedTopology: one with a
|
||||
// thread per enabled package; for each of those, one with a thread per enabled
|
||||
// cluster (CCX/shared L3), and for each of those, the remaining enabled cores
|
||||
// in that cluster. The cores representing each package and cluster are not
|
||||
// included in the per-cluster pool because we support spin-waiting, hence
|
||||
// there should be at most one thread per HW core.
|
||||
//
|
||||
// Useful when there are tasks which should be parallelized by workers sharing a
|
||||
// cache, or on the same NUMA node. In both cases, individual pools have lower
|
||||
// barrier synchronization latency than one large pool. However, to utilize all
|
||||
// cores, call sites will have to use nested parallel-for loops.
|
||||
class NestedPools {
|
||||
public:
|
||||
// Neither move nor copy.
|
||||
NestedPools() = delete;
|
||||
NestedPools(const NestedPools&) = delete;
|
||||
NestedPools& operator=(const NestedPools&) = delete;
|
||||
NestedPools(NestedPools&&) = delete;
|
||||
NestedPools& operator=(NestedPools&&) = delete;
|
||||
|
||||
// `max_threads` is the maximum number of threads to divide among all
|
||||
// clusters. It does not include the package and cluster representatives.
|
||||
// This is more intuitive than a per-cluster limit for users who may not be
|
||||
// aware of the CPU topology.
|
||||
//
|
||||
// To ensure we do not create more threads than there are HW cores, which
|
||||
// would cause huge slowdowns when spinning, `BoundedSlice` imposes upper
|
||||
// bounds on the number of detected packages and clusters rather than
|
||||
// defining an exact amount.
|
||||
//
|
||||
// `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically.
|
||||
NestedPools(size_t max_threads, int pin = -1,
|
||||
BoundedSlice package_slice = BoundedSlice(),
|
||||
BoundedSlice cluster_slice = BoundedSlice(),
|
||||
BoundedSlice lp_slice = BoundedSlice())
|
||||
: topology_(package_slice, cluster_slice, lp_slice) {
|
||||
if (pin == -1) pin = topology_.NumLP() >= 12;
|
||||
|
||||
packages_.resize(topology_.NumPackages());
|
||||
all_packages_ = MakePool(packages_.size());
|
||||
const size_t max_workers_per_package = max_threads / packages_.size();
|
||||
// Parallel to ensure we also pin the calling (main) thread.
|
||||
all_packages_->Run(
|
||||
0, all_packages_->NumWorkers(),
|
||||
[&](uint64_t package_idx, size_t thread) {
|
||||
HWY_ASSERT(package_idx == thread); // each thread has one task
|
||||
packages_[package_idx] = Package(
|
||||
topology_, package_idx, max_workers_per_package, pin, lp_slice);
|
||||
});
|
||||
}
|
||||
|
||||
// Spinning reduces the latency of barrier synchronization, but wastes lots
|
||||
// of energy for long waits, so only do it during generation. This might
|
||||
// also be unsafe in virtualized environments because we require threads to
|
||||
// be running on their own core and thus responsive to the barrier
|
||||
// synchronization.
|
||||
void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
|
||||
void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }
|
||||
|
||||
hwy::ThreadPool& AllPackages() { return *all_packages_; }
|
||||
hwy::ThreadPool& AllClusters(size_t package_idx) {
|
||||
HWY_ASSERT(package_idx < AllPackages().NumWorkers());
|
||||
return *packages_[package_idx].all_clusters;
|
||||
}
|
||||
hwy::ThreadPool& Cluster(size_t package_idx, size_t cluster_idx) {
|
||||
HWY_ASSERT(cluster_idx < AllClusters(package_idx).NumWorkers());
|
||||
return *packages_[package_idx].clusters[cluster_idx];
|
||||
}
|
||||
|
||||
const char* TopologyString() const { return topology_.TopologyString(); }
|
||||
|
||||
// Returns number of logical processors, for allocating per-thread buffers.
|
||||
size_t NumLP() const { return topology_.NumLP(); }
|
||||
|
||||
private:
|
||||
// `max_or_zero` == 0 means no limit.
|
||||
static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
|
||||
return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
|
||||
}
|
||||
|
||||
// We want vectors of hwy::ThreadPool, which is unfortunately not movable,
|
||||
// hence we wrap them in unique_ptr.
|
||||
using PoolPtr = std::unique_ptr<hwy::ThreadPool>;
|
||||
|
||||
static PoolPtr MakePool(size_t num_workers) {
|
||||
// `ThreadPool` expects the number of threads to create, which is one less
|
||||
// than the number of workers, but avoid underflow if zero.
|
||||
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
|
||||
return std::make_unique<hwy::ThreadPool>(num_threads);
|
||||
}
|
||||
|
||||
class Package {
|
||||
static PoolPtr CreateClusterPool(const BoundedTopology::Cluster& cluster,
|
||||
size_t max_cluster_workers, int pin,
|
||||
BoundedSlice lp_slice) {
|
||||
PoolPtr pool =
|
||||
MakePool(CapIfNonZero(cluster.num_lps, max_cluster_workers));
|
||||
|
||||
if (!pin) return pool;
|
||||
// Else: pin all new threads AND the calling thread from `all_clusters`.
|
||||
|
||||
// We know the topology: pin to this cluster's cores, including the
|
||||
// calling thread from `all_clusters`.
|
||||
if (cluster.lps.Any()) {
|
||||
std::vector<size_t> lps;
|
||||
lps.reserve(cluster.num_lps);
|
||||
cluster.lps.Foreach([&lps](size_t lp) { lps.push_back(lp); });
|
||||
|
||||
pool->Run(0, pool->NumWorkers(), [&lps](uint64_t task, size_t thread) {
|
||||
HWY_ASSERT(task == thread); // each worker has one task
|
||||
hwy::PinThreadToLogicalProcessor(lps[task]);
|
||||
});
|
||||
} else {
|
||||
// Pin to consecutive LPs.
|
||||
pool->Run(0, pool->NumWorkers(),
|
||||
[lp_slice](uint64_t task, size_t thread) {
|
||||
HWY_ASSERT(task == thread); // each worker has one task
|
||||
hwy::PinThreadToLogicalProcessor(lp_slice.skip + thread);
|
||||
});
|
||||
}
|
||||
return pool;
|
||||
}
|
||||
|
||||
public:
|
||||
Package() = default; // for vector
|
||||
Package(const BoundedTopology& topology, size_t package_idx,
|
||||
size_t max_workers_per_package, int pin, BoundedSlice lp_slice) {
|
||||
clusters.resize(topology.NumClusters(package_idx));
|
||||
const size_t max_workers_per_cluster =
|
||||
max_workers_per_package / clusters.size();
|
||||
|
||||
all_clusters = MakePool(clusters.size());
|
||||
// Parallel so we also pin the calling thread from `all_packages_`.
|
||||
all_clusters->Run(
|
||||
0, all_clusters->NumWorkers(),
|
||||
[&](size_t cluster_idx, size_t thread) {
|
||||
HWY_ASSERT(cluster_idx == thread); // each thread has one task
|
||||
const BoundedTopology::Cluster& cluster =
|
||||
topology.GetCluster(package_idx, cluster_idx);
|
||||
clusters[cluster_idx] = CreateClusterPool(
|
||||
cluster, max_workers_per_cluster, pin, lp_slice);
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<PoolPtr> clusters;
|
||||
PoolPtr all_clusters;
|
||||
};
|
||||
|
||||
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
|
||||
all_packages_->SetWaitMode(wait_mode);
|
||||
for (Package& package : packages_) {
|
||||
package.all_clusters->SetWaitMode(wait_mode);
|
||||
for (PoolPtr& cluster : package.clusters) {
|
||||
cluster->SetWaitMode(wait_mode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BoundedTopology topology_;
|
||||
|
||||
std::vector<Package> packages_;
|
||||
PoolPtr all_packages_;
|
||||
};
|
||||
|
||||
static inline NestedPools CreateSinglePool(size_t max_threads, int pin = -1) {
|
||||
const BoundedSlice one(0, 1);
|
||||
return NestedPools(max_threads, pin, one, one);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
|
||||
|
|
|
|||
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "util/threading.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
TEST(ThreadingTest, TestBoundedSlice) {
|
||||
const char* name = "test";
|
||||
// No args = no limit.
|
||||
{
|
||||
BoundedSlice slice;
|
||||
std::vector<size_t> expected;
|
||||
slice.ForEach(name, 10, [&](size_t i) { expected.push_back(i); });
|
||||
EXPECT_EQ(10, slice.Num(10));
|
||||
EXPECT_THAT(expected, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
}
|
||||
|
||||
// One arg: skip first N
|
||||
{
|
||||
BoundedSlice slice(3);
|
||||
std::vector<size_t> expected;
|
||||
slice.ForEach(name, 9, [&](size_t i) { expected.push_back(i); });
|
||||
EXPECT_EQ(6, slice.Num(9));
|
||||
EXPECT_THAT(expected, ElementsAre(3, 4, 5, 6, 7, 8));
|
||||
}
|
||||
|
||||
// Both args: skip first N, then use at most M
|
||||
{
|
||||
BoundedSlice slice(3, 2);
|
||||
std::vector<size_t> expected;
|
||||
slice.ForEach(name, 9, [&](size_t i) { expected.push_back(i); });
|
||||
EXPECT_EQ(2, slice.Num(9));
|
||||
EXPECT_THAT(expected, ElementsAre(3, 4));
|
||||
}
|
||||
|
||||
// Both args, but `max > detected - skip`: fewer than limit. Note that
|
||||
// `skip >= detected` is an error.
|
||||
{
|
||||
BoundedSlice slice(3, 2);
|
||||
std::vector<size_t> expected;
|
||||
slice.ForEach(name, 4, [&](size_t i) { expected.push_back(i); });
|
||||
EXPECT_EQ(1, slice.Num(4));
|
||||
EXPECT_THAT(expected, ElementsAre(3));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ThreadingTest, TestBoundedTopology) {
|
||||
const BoundedSlice all;
|
||||
const BoundedSlice one(0, 1);
|
||||
// All
|
||||
{
|
||||
BoundedTopology topology(all, all, all);
|
||||
fprintf(stderr, "%s\n", topology.TopologyString());
|
||||
ASSERT_NE(0, topology.NumPackages());
|
||||
ASSERT_NE(0, topology.NumClusters(0));
|
||||
}
|
||||
|
||||
// Max one package
|
||||
{
|
||||
BoundedTopology topology(one, all, all);
|
||||
fprintf(stderr, "%s\n", topology.TopologyString());
|
||||
ASSERT_EQ(1, topology.NumPackages());
|
||||
ASSERT_NE(0, topology.NumClusters(0));
|
||||
}
|
||||
|
||||
// Max one cluster
|
||||
{
|
||||
BoundedTopology topology(all, one, all);
|
||||
fprintf(stderr, "%s\n", topology.TopologyString());
|
||||
ASSERT_NE(0, topology.NumPackages());
|
||||
ASSERT_EQ(1, topology.NumClusters(0));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
Loading…
Reference in New Issue