From f74d4968795d3c84ec5de94ba0c84f308671e2ef Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 27 Nov 2024 01:11:20 -0800 Subject: [PATCH] Threading/infra improvements. * Add Parallelize*Range helpers and partitioning helpers * Refactor Pinning class, store original affinity (required to construct another NestedPools after pinning happened) Compress: * prevent Compress printing stats in tests * zero-pad tensors Matmul: * add matmul_unit_test (TODO) and bench_matmul * matmul_test: change norm to row vectors (that is what is added) and include bf16 rounding error * Prepare for L2/L3 retrieval PiperOrigin-RevId: 700603811 --- BUILD.bazel | 46 +++++++- CMakeLists.txt | 2 + backprop/optimize_test.cc | 2 + compression/compress-inl.h | 8 +- compression/compress.h | 7 +- evals/benchmark_helper.cc | 1 + examples/hello_world/run.cc | 1 + ops/bench_matmul.cc | 214 ++++++++++++++++++++++++++++++++++ ops/dot_test.cc | 1 + ops/matmul-inl.h | 16 --- ops/matmul.h | 35 ++++++ ops/matmul_test.cc | 198 +++++++++++++++++++++---------- ops/matmul_unit_test.cc | 17 +++ util/allocator.cc | 21 +--- util/allocator.h | 78 +------------ util/app.h | 2 + util/basics.h | 34 +++--- util/threading.cc | 218 +++++++++++++++++++--------------- util/threading.h | 169 +++++++++++++++++++++++++-- util/threading_test.cc | 225 ++++++++++++++++++++++++++++++++++++ 20 files changed, 1001 insertions(+), 294 deletions(-) create mode 100644 ops/bench_matmul.cc create mode 100644 ops/matmul_unit_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index 204f24e..4dc8c62 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -71,6 +71,7 @@ cc_test( "@googletest//:gtest_main", "@highway//:hwy", "@highway//:hwy_test_util", + "@highway//:thread_pool", ], ) @@ -166,6 +167,26 @@ cc_test( ], ) +cc_test( + name = "matmul_unit_test", + size = "small", + timeout = "long", + srcs = ["ops/matmul_unit_test.cc"], + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":allocator", + ":basics", + ":ops", + ":test_util", + "@googletest//:gtest_main", # buildcleaner: keep + "//compression:compress", + "@highway//:hwy", + "@highway//:hwy_test_util", + ], +) + cc_test( name = "matmul_test", size = "small", @@ -178,7 +199,28 @@ cc_test( ":allocator", ":basics", ":ops", - ":test_util", + ":threading", + "@googletest//:gtest_main", # buildcleaner: keep + "//compression:compress", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:nanobenchmark", + "@highway//:thread_pool", + ], +) + +cc_test( + name = "bench_matmul", + size = "small", + timeout = "long", + srcs = ["ops/bench_matmul.cc"], + local_defines = ["HWY_IS_TEST"], + # for test_suite. 
+ tags = ["hwy_ops_test"], + deps = [ + ":allocator", + ":basics", + ":ops", ":threading", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", @@ -343,6 +385,7 @@ cc_library( ":basics", ":common", ":gemma_lib", + ":ops", ":threading", "//compression:io", "@highway//:hwy", @@ -624,6 +667,7 @@ cc_test( "mem": "28g", }, deps = [ + ":allocator", ":backprop", ":basics", ":common", diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b16c65..0efece1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -163,9 +163,11 @@ set(GEMMA_TEST_FILES compression/sfp_test.cc evals/gemma_test.cc gemma/tensor_index_test.cc + ops/bench_matmul.cc ops/dot_test.cc ops/gemma_matvec_test.cc ops/matmul_test.cc + ops/matmul_unit_test.cc ops/ops_test.cc paligemma/image_test.cc paligemma/paligemma_test.cc diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index a23ac84..6d3522f 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -33,6 +33,7 @@ #include "gemma/configs.h" #include "gemma/gemma.h" #include "gemma/weights.h" +#include "util/allocator.h" #include "util/basics.h" #include "util/threading.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -42,6 +43,7 @@ namespace gcpp { TEST(OptimizeTest, GradientDescent) { NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1), BoundedSlice(0, 1)); + Allocator::Init(pools.Topology()); hwy::ThreadPool& pool = pools.Pool(); std::mt19937 gen(42); diff --git a/compression/compress-inl.h b/compression/compress-inl.h index be4c5af..6d8ba28 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -51,6 +51,12 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +#ifdef HWY_IS_TEST +static constexpr bool kIsTest = true; +#else +static constexpr bool kIsTest = false; +#endif + // Enables generic code independent of compression type. template // primary, must specialize struct CompressTraits {}; @@ -438,7 +444,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, } } - const bool want_bench = num > 1024 * 1024 || COMPRESS_STATS; + const bool want_bench = COMPRESS_STATS || !kIsTest; const double t0 = want_bench ? hwy::platform::Now() : 0.0; using Traits = CompressTraits; diff --git a/compression/compress.h b/compression/compress.h index ff64b49..fc13bdf 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -38,6 +38,7 @@ #include "util/basics.h" // IWYU pragma: end_exports #include "util/allocator.h" +#include "hwy/per_target.h" #if COMPRESS_STATS #include "compression/distortion.h" #include "hwy/stats.h" @@ -360,7 +361,11 @@ class MatStorageT : public MatPtrT { } else { this->num_elements_ = num_elements; } - data_ = Allocator::Alloc(num_elements); + // Pad to allow overrunning the last row by 2 BF16 vectors, hence at most + // `2 * VectorBytes / sizeof(BF16)` elements of MatT. 
+ const size_t padding = hwy::VectorBytes(); + data_ = Allocator::Alloc(num_elements + padding); + hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT)); this->ptr_ = data_.get(); } diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 60cc61e..8bfc015 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -56,6 +56,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) { GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, const AppArgs& app) : pools_(CreatePools(app)) { + Allocator::Init(pools_.Topology()); InferenceArgs mutable_inference = inference; AbortIfInvalidArgs(mutable_inference); LoaderArgs mutable_loader = loader; diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 70c3654..3951350 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -59,6 +59,7 @@ int main(int argc, char** argv) { // Instantiate model and KV Cache gcpp::NestedPools pools = gcpp::CreatePools(app); + gcpp::Allocator::Init(pools.Topology()); gcpp::Gemma model = gcpp::CreateGemma(loader, pools); gcpp::KVCache kv_cache = gcpp::KVCache::Create(model.GetModelConfig(), diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc new file mode 100644 index 0000000..be58d29 --- /dev/null +++ b/ops/bench_matmul.cc @@ -0,0 +1,214 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Benchmark of large MatMul instances for which the MatMulSlow would be too +// slow. This lacks a reference and is only useful for performance measurement. + +#include "hwy/detect_compiler_arch.h" +#ifndef HWY_DISABLED_TARGETS +// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require +// double-precision support. +#if HWY_ARCH_ARM_V7 +#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON) +#else +#define HWY_DISABLED_TARGETS HWY_SCALAR +#endif +#endif + +#include +#include + +#include + +#include "compression/compress.h" +#include "compression/shared.h" +#include "ops/matmul.h" +#include "util/allocator.h" +#include "util/basics.h" +#include "util/threading.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/timer.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "ops/bench_matmul.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/compress-inl.h" +#include "ops/matmul-inl.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +// For running BenchAllMatMul only once. Defined within HWY_ONCE. +extern int64_t first_target; + +namespace HWY_NAMESPACE { + +using FloatPtr = hwy::AlignedFreeUniquePtr; + +template +using MatStoragePtr = std::unique_ptr>; + +// Generates inputs: deterministic, within max SfpStream range. 
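+// For example (illustrative), a 2x3 matrix gets scale = kMax / 6 and values
+//   0, -s, 2s, -3s, 4s, -5s   (s = scale),
+// so magnitudes stay below SfpStream::kMax and signs alternate.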
+template +MatStoragePtr GenerateMat(const Extents2D extents, + hwy::ThreadPool& pool) { + gcpp::CompressWorkingSet ws; + auto mat = + std::make_unique>("mat", extents.rows, extents.cols); + FloatPtr content = hwy::AllocateAligned(mat->NumElements()); + HWY_ASSERT(content); + const float scale = SfpStream::kMax / (mat->NumElements()); + pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { + for (size_t c = 0; c < extents.cols; c++) { + float f = static_cast(r * extents.cols + c) * scale; + if ((r + c) & 1) f = -f; // Also generate some negative values. + content[r * extents.cols + c] = f; + } + }); + + CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); + mat->set_scale(0.6f); // Arbitrary value, different from 1. + return mat; +} + +// extents describes the transposed matrix. +template +MatStoragePtr GenerateTransposedMat(const Extents2D extents, + hwy::ThreadPool& pool) { + gcpp::CompressWorkingSet ws; + auto mat = + std::make_unique>("trans", extents.rows, extents.cols); + FloatPtr content = hwy::AllocateAligned(mat->NumElements()); + const float scale = SfpStream::kMax / (mat->NumElements()); + pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { + for (size_t c = 0; c < extents.cols; c++) { + float f = static_cast(c * extents.rows + r) * scale; + if ((r + c) & 1) f = -f; // Also generate some negative values. + content[r * extents.cols + c] = f; + } + }); + + CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); + // Arbitrary value, different from 1, must match GenerateMat. + mat->set_scale(0.6f); + return mat; +} + +void PrintSpeed(const char* algo, const Extents2D& A_extents, + const Extents2D& B_extents, double elapsed) { + const size_t num_b = B_extents.Area(); + // 2x because of FMA. + fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo, + elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed); +} + +// Generates inputs and prints observed throughput of MatMul. +template +void BenchMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, + MatMulEnv& env) { + hwy::ThreadPool& pool = env.Pool(); + fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", + rows_ac, cols_a_rows_b, cols_bc, add, TypeName(), + TypeName()); + + const Extents2D A_extents(rows_ac, cols_a_rows_b); + const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed + const Extents2D C_extents(rows_ac, cols_bc); + + MatStoragePtr a = GenerateMat(A_extents, pool); + MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); + RowVectorBatch c_slow_batch(C_extents); + RowVectorBatch c_batch(C_extents); + HWY_ASSERT(a && b_trans); + + std::unique_ptr> add_storage; + if (add) { + add_storage = GenerateMat(Extents2D(1, cols_bc), pool); + HWY_ASSERT(add_storage); + add_storage->set_scale(1.0f); + } + + const auto A = ConstMatFromWeights(*a); + const auto B = ConstMatFromWeights(*b_trans); + const float* add_row = add ? 
add_storage->data_scale1() : nullptr; + const RowPtrF C = RowPtrFromBatch(c_batch); + + double min_elapsed = hwy::HighestValue(); + for (int rep = 0; rep < 3; ++rep) { + const double start_tiled = hwy::platform::Now(); + MatMul(A, B, add_row, env, C); + min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled); + } + PrintSpeed("MatMul", A_extents, B_extents, min_elapsed); +} + +using F32 = float; +using SFP = SfpStream; + +void BenchAllMatMul() { + if (first_target == 0) first_target = HWY_TARGET; + if (HWY_TARGET != first_target) return; + + for (size_t max_packages : {1, 2}) { + const size_t max_threads = 0; // no limit + NestedPools pools(max_threads, Tristate::kDefault, + BoundedSlice(0, max_packages)); +#if GEMMA_DISABLE_TOPOLOGY + if (max_packages == 2) break; // we only have one package +#else + // If less than the limit, we have already tested all num_packages. + if (pools.Topology().FullTopology().packages.size() < max_packages) break; +#endif + fprintf(stderr, "BenchAllMatMul %zu: %s %s\n", max_packages, + pools.TopologyString(), pools.PinString()); + + Tristate use_spinning = Tristate::kDefault; + pools.MaybeStartSpinning(use_spinning); + Allocator::Init(pools.Topology()); + MatMulEnv env(pools); + + for (size_t batch_size : {1, /*4, 64,*/ 128}) { + BenchMatMul(batch_size, 24576, 3072, /*add=*/false, env); + BenchMatMul(batch_size, 3072, 24576, /*add=*/false, env); + BenchMatMul(batch_size, 24576, 3072, /*add=*/false, env); + BenchMatMul(batch_size, 3072, 24576, /*add=*/false, env); + BenchMatMul(batch_size, 24576, 3072, /*add=*/false, env); + BenchMatMul(batch_size, 3072, 24576, /*add=*/false, env); + } + pools.MaybeStopSpinning(use_spinning); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +int64_t first_target = 0; // none run yet +HWY_BEFORE_TEST(BenchMatMul); +HWY_EXPORT_AND_TEST_P(BenchMatMul, BenchAllMatMul); +HWY_AFTER_TEST(); + +} // namespace gcpp + +#endif diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 7fb8514..73d9f8f 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -1109,6 +1109,7 @@ void TestAllDot() { const size_t num = 24 * 1024; NestedPools pools(kMaxWorkers - 1, /*pin=*/Tristate::kDefault, BoundedSlice(0, 1), BoundedSlice(0, 1)); + Allocator::Init(pools.Topology()); RowVectorBatch a(Extents2D(kMaxWorkers, num)); RowVectorBatch b(Extents2D(kMaxWorkers, num)); RowVectorBatch bufs(Extents2D(kMaxWorkers, num)); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 8646f79..18d4e6b 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -38,22 +38,6 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of -// loads, we reuse the same A row for several B columns, which are also loaded -// once for several rows of C. Thus we produce one 'tile' of C at a time of -// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are -// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4 -// enables the `StoreInterleaved4` transpose in `AddHorizontalSums`. We assume -// and verify that `C.cols % kRegCols == 0`. -constexpr size_t kRegCols = 4; - -// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because -// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile. -// In general, `batch_size` (C rows) is not a multiple of `kRegRows`. 
Thus -// functions that load or store a tile are parameterized on `kNumRows`, which is -// generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0). -constexpr size_t kRegRows = kRegCols; - // Loads two vectors at a time with element type hn::TFromD from a row of // transposed B. Called in a loop over col_ab. No bounds checking because // `kRow` is from B columns, which we checked is a multiple of `kRegCols`. diff --git a/ops/matmul.h b/ops/matmul.h index 2eff81f..e77bd93 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -21,6 +21,7 @@ // IWYU pragma: begin_exports #include "util/basics.h" #include "util/threading.h" +#include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" // IWYU pragma: end_exports @@ -28,6 +29,40 @@ namespace gcpp { +// TODO: remove deprecated typedef. +using Range1D = IndexRange; + +// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of +// loads, we reuse the same A row for several B columns, which are also loaded +// once for several rows of C. Thus we produce one 'tile' of C at a time of +// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are +// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4 +// enables the `StoreInterleaved4` transpose in `StoreHorizontalSums`. We assume +// and verify that `C.Cols() % kRegCols == 0`. +constexpr size_t kRegCols = 4; + +// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because +// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile. +// In general, `batch_size` (A/C rows) is not a multiple of `kRegRows`. Thus +// functions that load or store a tile are parameterized on `kRowsPerTile`: +// usually `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0). +constexpr size_t kRegRows = kRegCols; + +struct CacheSizes { + CacheSizes() = default; + CacheSizes(const BoundedTopology::Cluster& cluster) { + // Assumes each package and cluster has the same cache sizes, and uses + // reasonable defaults if unknown. + l1_bytes = 32 * 1024; // typical size, rarely changes + l2_bytes = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) * 1024; + l3_bytes = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) * 1024; + } + + size_t l1_bytes; + size_t l2_bytes; + size_t l3_bytes; +}; + // Allocations and threads, shared across MatMul calls. class MatMulEnv { public: diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 8d36acd..ac79e4b 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -13,6 +13,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +// End to end test of MatMul, comparing against a reference implementation. + +#include "hwy/detect_compiler_arch.h" #ifndef HWY_DISABLED_TARGETS // Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require // double-precision support. @@ -23,14 +26,14 @@ #endif #endif -#include "ops/matmul.h" - #include #include #include #include "compression/compress.h" +#include "compression/shared.h" +#include "ops/matmul.h" #include "util/allocator.h" #include "util/basics.h" #include "util/threading.h" @@ -52,7 +55,11 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { +// For running TestBatchSizes only once. Defined within HWY_ONCE. 
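+// Example: with dynamic dispatch over e.g. AVX3, AVX2 and SSE4, the first call
+// latches `first_target = HWY_TARGET`; re-invocations for the remaining
+// targets return early, so the batch-size sweep runs exactly once.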
+extern int64_t first_target; + namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; using FloatPtr = hwy::AlignedFreeUniquePtr; @@ -71,8 +78,9 @@ MatStoragePtr GenerateMat(const Extents2D extents, const float scale = SfpStream::kMax / (mat->NumElements()); pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { for (size_t c = 0; c < extents.cols; c++) { - content[r * extents.cols + c] = - static_cast(r * extents.cols + c) * scale; + float f = static_cast(r * extents.cols + c) * scale; + if ((r + c) & 1) f = -f; // Also generate some negative values. + content[r * extents.cols + c] = f; } }); @@ -92,8 +100,9 @@ MatStoragePtr GenerateTransposedMat(const Extents2D extents, const float scale = SfpStream::kMax / (mat->NumElements()); pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { for (size_t c = 0; c < extents.cols; c++) { - content[r * extents.cols + c] = - static_cast(c * extents.rows + r) * scale; + float f = static_cast(c * extents.rows + r) * scale; + if ((r + c) & 1) f = -f; // Also generate some negative values. + content[r * extents.cols + c] = f; } }); @@ -104,16 +113,28 @@ MatStoragePtr GenerateTransposedMat(const Extents2D extents, } // Returns 1-norm, used for estimating tolerable numerical differences. -double MaxColAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) { - double max_col_abs_sum = 0.0; - for (size_t c = 0; c < extents.cols; c++) { - double col_abs_sum = 0.0; - for (size_t r = 0; r < extents.rows; r++) { - col_abs_sum += hwy::ScalarAbs(a[r * extents.cols + c]); +double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) { + double max_row_abs_sum = 0.0; + for (size_t r = 0; r < extents.rows; r++) { + const float* row = a + r * extents.cols; + double row_abs_sum = 0.0; + for (size_t c = 0; c < extents.cols; c++) { + row_abs_sum += hwy::ScalarAbs(row[c]); } - max_col_abs_sum = HWY_MAX(max_col_abs_sum, col_abs_sum); + max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum); } - return max_col_abs_sum; + return max_row_abs_sum; +} + +// Returns the maximum absolute value of `a`. +float MaxAbs(const float* HWY_RESTRICT a, const Extents2D& extents) { + float max_abs = 0.0f; + for (size_t c = 0; c < extents.cols; c++) { + for (size_t r = 0; r < extents.rows; r++) { + max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(a[r * extents.cols + c])); + } + } + return max_abs; } // B is already transposed. @@ -132,12 +153,25 @@ void AssertClose(const ConstMat& A, const ConstMat& B, DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a); DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b); - const double norm = MaxColAbsSum(a.get(), A.Extents()) * - MaxColAbsSum(b_trans.get(), B.Extents()); - // Dot(float,BF16) rounds both to BF16. - using RefType = hwy::If() && IsF32(), float, BF16>; - const double epsilon = hwy::ConvertScalarTo(hwy::Epsilon()); - const double tolerance = 200.0 * norm * epsilon; + // MatMul rounds inputs to BF16, so error is proportional to the max input + // magnitude, but also to f32 accumulation of rows in A and B. + const double norm = MaxRowAbsSum(a.get(), A.Extents()) * + MaxRowAbsSum(b_trans.get(), B.Extents()); + const float max_abs = + MaxAbs(a.get(), A.Extents()) * MaxAbs(b_trans.get(), B.Extents()); + const double eps_bf16 = hwy::ConvertScalarTo(hwy::Epsilon()); + const double eps_f32 = hwy::ConvertScalarTo(hwy::Epsilon()); + double tolerance = 8 * norm * eps_f32; + // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the + // tolerance there. 
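+  // Rough magnitudes, assuming 256-element rows with |values| <= 1: norm is at
+  // most 256 * 256, so the f32 term is about 8 * 65536 * 1.2E-7 ~= 0.06, and
+  // the extra bf16 rounding term for f32 inputs adds a few hundredths more.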
+ if (IsF32() && IsF32()) { + tolerance += 4 * max_abs * eps_bf16; + } + EXPECT_GE(tolerance, 1E-4); + if (tolerance > 4.0) { + fprintf(stderr, "WARN: high tolerance %f norm %f maxabs %f\n", tolerance, + norm, max_abs); + } for (size_t r = 0; r < A.extents.rows; r++) { const float* expected_row = C_slow.Row(r); @@ -148,10 +182,11 @@ void AssertClose(const ConstMat& A, const ConstMat& B, if (!(expected_value - tolerance <= actual_value && actual_value <= expected_value + tolerance)) { - fprintf( - stderr, - "(%zu,%zu): expected %f, actual %f, norm %f eps %E tolerance %f\n", - r, c, expected_value, actual_value, norm, epsilon, tolerance); + fprintf(stderr, + "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " + "tolerance %f\n", + r, c, expected_value, actual_value, norm, max_abs, tolerance); + return; } } } @@ -171,20 +206,31 @@ HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, const hn::ScalableTag df; // lane type is ignored const PackedSpan b_span = MakeSpan(B.ptr, B.ofs + B.extents.Area()); - const Extents2D C_extents(A.extents.rows, C.Cols()); + const IndexRange all_rows_c(0, A.Extents().rows); + const IndexRange all_cols_c(0, C.Cols()); - StaticPartitionRowsAndCols( - env.Pools(), C_extents, sizeof(MatTB), - [&](const Range2D& C_range, const TaskLocation& loc) { - loc.cluster.Run( - C_range.rows.begin(), C_range.rows.end(), - [&](const uint64_t row, size_t /*thread*/) { - float* HWY_RESTRICT C_row = C.Row(row); - for (size_t row_b_col_c : C_range.cols) { - const float add = add_row ? add_row[row_b_col_c] : 0.0f; - C_row[row_b_col_c] = - add + scale * Dot(df, b_span, row_b_col_c * B.extents.cols, - A.ptr + A.Row(row), A.extents.cols); + NestedPools& pools = env.Pools(); + hwy::ThreadPool& all_packages = pools.AllPackages(); + const IndexRangePartition get_row_c = + StaticPartition(all_rows_c, all_packages.NumWorkers(), 1); + ParallelizeOneRange( + get_row_c, all_packages, + [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR { + hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx); + const size_t multiple = Allocator::Alignment() / sizeof(MatTB); + const IndexRangePartition get_col_c = + StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); + ParallelizeOneRange( + get_col_c, all_clusters, + [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR { + for (size_t r : rows_c) { + float* HWY_RESTRICT C_row = C.Row(r); + for (size_t c : cols_c) { + const float add = add_row ? add_row[c] : 0.0f; + C_row[c] = + add + scale * Dot(df, b_span, c * B.extents.cols, + A.ptr + A.Row(r), A.extents.cols); + } } }); }); @@ -250,6 +296,40 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, AssertClose(A, B, C_slow, C); } +using F32 = float; +using SFP = SfpStream; + +// Sweep batch_size for a single input type and Highway target, to verify the +// row partitioning. +void TestBatchSizes() { + if (first_target == 0) first_target = HWY_TARGET; + if (HWY_TARGET != first_target) return; + + for (size_t max_packages : {1, 2}) { + const size_t max_threads = 0; // no limit + NestedPools pools(max_threads, Tristate::kDefault, + BoundedSlice(0, max_packages)); +#if GEMMA_DISABLE_TOPOLOGY + if (max_packages == 2) break; // we only have one package +#else + // If less than the limit, we have already tested all num_packages. 
+ if (pools.Topology().FullTopology().packages.size() < max_packages) break; +#endif + fprintf(stderr, "TestBatchSizes %zu: %s %s\n", max_packages, + pools.TopologyString(), pools.PinString()); + + Tristate use_spinning = Tristate::kDefault; + pools.MaybeStartSpinning(use_spinning); + Allocator::Init(pools.Topology()); + MatMulEnv env(pools); + + for (size_t batch_size = 1; batch_size <= 3 * kRegRows; ++batch_size) { + TestMatMul(batch_size, 256, 256, /*add=*/false, env); + } + pools.MaybeStopSpinning(use_spinning); + } +} + void TestAllMatMul() { // Skip EMU128 (10x slower than SSE4 for SFP) and older x86. if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 || @@ -257,32 +337,30 @@ void TestAllMatMul() { return; } - NestedPools pools(4, /*pin=*/Tristate::kDefault); + NestedPools pools(0); // no limits Tristate use_spinning = Tristate::kDefault; pools.MaybeStartSpinning(use_spinning); Allocator::Init(pools.Topology()); MatMulEnv env(pools); - using F32 = float; - using SFP = SfpStream; + // Sizes seen in gemma_test 2B. + TestMatMul(1, 2048, 512, /*add=*/false, env); + TestMatMul(1, 2048, 2048, /*add=*/false, env); + TestMatMul(1, 2048, 16384, /*add=*/false, env); + TestMatMul(1, 16384, 2048, /*add=*/false, env); + TestMatMul(1, 2048, 256000, /*add=*/false, env); + TestMatMul(5, 2048, 512, /*add=*/false, env); + TestMatMul(5, 2048, 2048, /*add=*/false, env); + TestMatMul(5, 2048, 16384, /*add=*/false, env); + TestMatMul(5, 16384, 2048, /*add=*/false, env); - // large-scale test: batch_size=128 is better than 64 or 256 for SKX. - // TestMatMul(128, 24576, 3072, /*add=*/false, env); - // TestMatMul(128, 3072, 24576, /*add=*/false, env); - TestMatMul(1, 24576, 3072, /*add=*/false, env); - TestMatMul(1, 3072, 24576, /*add=*/false, env); - TestMatMul(1, 24576, 3072, /*add=*/false, env); - TestMatMul(1, 3072, 24576, /*add=*/false, env); - - // medium-sized square test - temporarily disabled for faster testing. - if constexpr (false) { - TestMatMul(512, 512, 512, /*add=*/false, env); - TestMatMul(512, 512, 512, /*add=*/true, env); - TestMatMul(512, 512, 512, /*add=*/false, env); - TestMatMul(512, 512, 512, /*add=*/true, env); - TestMatMul(512, 512, 512, /*add=*/false, env); - TestMatMul(512, 512, 512, /*add=*/true, env); - } + // medium-sized square + TestMatMul(512, 512, 512, /*add=*/false, env); + TestMatMul(512, 512, 512, /*add=*/true, env); + TestMatMul(512, 512, 512, /*add=*/false, env); + TestMatMul(512, 512, 512, /*add=*/true, env); + TestMatMul(512, 512, 512, /*add=*/false, env); + TestMatMul(512, 512, 512, /*add=*/true, env); // minimal non-square test. kColsARowsB must be at least 2 vectors. TestMatMul(35, 128, 32, /*add=*/false, env); @@ -325,8 +403,10 @@ HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace gcpp { -HWY_BEFORE_TEST(MatmulTest); -HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllMatMul); +int64_t first_target = 0; // none run yet +HWY_BEFORE_TEST(MatMulTest); +HWY_EXPORT_AND_TEST_P(MatMulTest, TestBatchSizes); +HWY_EXPORT_AND_TEST_P(MatMulTest, TestAllMatMul); HWY_AFTER_TEST(); } // namespace gcpp diff --git a/ops/matmul_unit_test.cc b/ops/matmul_unit_test.cc new file mode 100644 index 0000000..f8752b8 --- /dev/null +++ b/ops/matmul_unit_test.cc @@ -0,0 +1,17 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TODO: Tests of individual MatMul components. +int main() { return 0; } \ No newline at end of file diff --git a/util/allocator.cc b/util/allocator.cc index dd9943b..8f41f00 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -103,7 +103,7 @@ size_t CountBusyPages(size_t num_pages, size_t node, void** pages, // which means it would have to be called before pages are faulted in, but // `aligned_allocator.h` modifies the first bytes for its bookkeeping. // May overwrite some of the memory with zeros. -static void BindMemory(void* ptr, size_t bytes, size_t node) { +void BindMemory(void* ptr, size_t bytes, size_t node) { constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough" // Avoid mbind because it does not report why it failed, which is most likely // because pages are busy, in which case we want to know which. @@ -159,24 +159,7 @@ static void BindMemory(void* ptr, size_t bytes, size_t node) { #else // TODO: support other OSes. -static void BindMemory(void*, size_t, size_t) {} +void BindMemory(void*, size_t, size_t) {} #endif // GEMMA_NUMA && HWY_OS_LINUX -void BindTensor(NestedPools& nested, const Extents2D& extents, - size_t bytes_per_col, void* ptr) { - if (!Allocator::UseNUMA()) return; - uint8_t* p8 = static_cast(ptr); - const size_t bytes_per_row = extents.cols * bytes_per_col; - StaticPartitionRowsAndCols( - nested, extents, bytes_per_col, - [&](const Range2D& r, const TaskLocation& loc) { - for (size_t row : r.rows) { - uint8_t* slice = - p8 + row * bytes_per_row + r.cols.begin() * bytes_per_col; - const size_t slice_size = r.cols.Num() * bytes_per_col; - BindMemory(slice, slice_size, loc.node); - } - }); -} - } // namespace gcpp diff --git a/util/allocator.h b/util/allocator.h index 08476e3..cf1e161 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -125,82 +125,8 @@ class Allocator { static size_t alignment_; }; -// For shorter arguments to the StaticPartitionRowsAndCols functor. -struct TaskLocation { - TaskLocation(size_t node, size_t package_idx, hwy::ThreadPool& cluster, - size_t worker_offset) - : node(node), - package_idx(package_idx), - cluster(cluster), - worker_offset(worker_offset) {} - size_t node; - size_t package_idx; - hwy::ThreadPool& cluster; - const size_t worker_offset; -}; - -// Used in MatMul and allocator.h. Defined here because it depends on -// Allocator::Alignment(). -template -void StaticPartitionRowsAndCols(NestedPools& nested, Extents2D extents, - size_t bytes_per_element, const Func& func) { - // Both rows and cols must be a multiple of the alignment to avoid - // touching remote pages. - const size_t multiple = Allocator::Alignment() / bytes_per_element; - - // Static partitioning of columns across packages. We assume that column - // sharding is more expensive, hence we distribute columns across packages, - // of which there are usually only one or two. For MatMul, the final result is - // the sum of each package's partial dot products. 
- hwy::ThreadPool& all_packages = nested.AllPackages(); - const size_t num_packages = all_packages.NumWorkers(); - const size_t cols_per_package = - hwy::RoundUpTo(hwy::DivCeil(extents.cols, num_packages), multiple); - const size_t col_tasks = hwy::DivCeil(extents.cols, cols_per_package); - HWY_ASSERT(col_tasks <= num_packages); - all_packages.Run( - 0, col_tasks, [&](uint64_t package_idx, size_t package_thread) { - HWY_ASSERT(package_idx == package_thread); // one task per worker - const size_t col_begin = package_idx * cols_per_package; - const Range1D col_range = - MakeRange1D(col_begin, extents.cols, cols_per_package); - - // Static partitioning of rows across the package's clusters. We assume - // that row sharding is cheaper. In MatMul, results can indeed be - // computed independently for each row of B. - hwy::ThreadPool& all_clusters = nested.AllClusters(package_idx); - const size_t num_clusters = all_clusters.NumWorkers(); - const size_t rows_per_cluster = - hwy::RoundUpTo(hwy::DivCeil(extents.rows, num_clusters), multiple); - const size_t row_tasks = hwy::DivCeil(extents.rows, rows_per_cluster); - HWY_ASSERT(row_tasks <= num_clusters); - all_clusters.Run( - 0, row_tasks, [&](uint64_t cluster_idx, size_t cluster_thread) { - HWY_ASSERT(cluster_idx == cluster_thread); // one task per worker - - // For binding to NUMA node. - const size_t node = nested.Node(package_idx, cluster_idx); - // Older CPUs that predate chiplets typically have only one - // cluster, so callers should also parallelize using this - // per-cluster pool. - hwy::ThreadPool& cluster = - nested.Cluster(package_idx, cluster_idx); - // This plus the worker from `cluster->Run` is the TLS index. - const size_t worker_offset = - nested.WorkerOffset(package_idx, cluster_idx); - - const size_t row_begin = cluster_idx * rows_per_cluster; - const Range1D row_range = - MakeRange1D(row_begin, extents.rows, rows_per_cluster); - - func(Range2D(row_range, col_range), - TaskLocation(node, package_idx, cluster, worker_offset)); - }); - }); -} - -void BindTensor(NestedPools& nested, size_t rows, size_t cols, - size_t bytes_per_col, void* ptr); +// For future NUMA support. TODO: use. +void BindMemory(void* ptr, size_t bytes, size_t node); } // namespace gcpp diff --git a/util/app.h b/util/app.h index 5128a38..8736ecd 100644 --- a/util/app.h +++ b/util/app.h @@ -27,6 +27,7 @@ #include "compression/io.h" // Path #include "gemma/common.h" #include "gemma/gemma.h" // For CreateGemma +#include "ops/matmul.h" #include "util/args.h" #include "util/basics.h" // Tristate #include "util/threading.h" @@ -115,6 +116,7 @@ class AppArgs : public ArgsBase { } }; +// Callers must call Allocator::Init(pools.Topology()) after this. static inline NestedPools CreatePools(const AppArgs& app) { return NestedPools(app.max_threads, app.pin, BoundedSlice(app.skip_packages, app.max_packages), diff --git a/util/basics.h b/util/basics.h index cfe2204..bdec099 100644 --- a/util/basics.h +++ b/util/basics.h @@ -21,7 +21,7 @@ #include #include "hwy/aligned_allocator.h" -#include "hwy/base.h" // HWY_IS_MSAN +#include "hwy/base.h" // IWYU pragma: end_exports #if HWY_IS_MSAN @@ -60,7 +60,7 @@ struct TokenAndProb { float prob; }; -// Entire size of a 2D array. By contrast, Range2D is a subrange. +// Entire size of a 2D array. struct Extents2D { Extents2D() : rows(0), cols(0) {} Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) { @@ -74,11 +74,13 @@ struct Extents2D { size_t cols; }; -// Range2D consists of two Range1D. 
-struct Range1D { - Range1D(size_t begin, size_t end) : begin_(begin), end_(end) { +struct IndexRange { + IndexRange(size_t begin, size_t end) : begin_(begin), end_(end) { HWY_DASSERT(begin < end); } + IndexRange(const IndexRange& other) = default; + IndexRange& operator=(const IndexRange& other) = default; + size_t Num() const { return end_ - begin_; } // Enable range-based for loops. @@ -101,22 +103,15 @@ struct Range1D { Iterator begin() const { return Iterator(begin_); } Iterator end() const { return Iterator(end_); } - const size_t begin_; - const size_t end_; + size_t begin_; + size_t end_; }; -static inline Range1D MakeRange1D(size_t begin, size_t end, size_t max_size) { - return Range1D(begin, HWY_MIN(begin + max_size, end)); +static inline IndexRange MakeIndexRange(size_t begin, size_t end, + size_t max_size) { + return IndexRange(begin, HWY_MIN(begin + max_size, end)); } -// In MatMul, the two axes are used independently, hence we do not define -// Range2D as a top-left and extents. -struct Range2D { - Range2D(Range1D rows, Range1D cols) : rows(rows), cols(cols) {} - const Range1D rows; - const Range1D cols; -}; - // Lightweight version of `MatPtr` used for the C argument of `MatMul`, because // it is always float and does not support compressed T, but does support an // arbitrary stride >= cols. @@ -125,6 +120,10 @@ class RowPtr { public: RowPtr(T* HWY_RESTRICT row0, size_t cols) : row0_(row0), cols_(cols), stride_(cols) {} + RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride) + : row0_(row0), cols_(cols), stride_(stride) { + HWY_DASSERT(stride >= cols); + } T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } size_t Cols() const { return cols_; } @@ -207,6 +206,7 @@ struct ConstMat { } const Extents2D& Extents() const { return extents; } + size_t Stride() const { return extents.cols; } // Shrinks the row-extent of this matrix view, i.e. reduces the view to a // subrange of the original rows starting at row 0. diff --git a/util/threading.cc b/util/threading.cc index b4bb84e..3c0ff0d 100644 --- a/util/threading.cc +++ b/util/threading.cc @@ -39,31 +39,106 @@ static void SortByDescendingSize(std::vector& groups) { [](const T& a, const T& b) { return a.Size() > b.Size(); }); } +// Singleton, holds the original process affinity and the pinning status. +class Pinning { + static bool InContainer() { + return false; } + + public: + // Returns set of LPs available for use. Subsequent calls return the same + // set as the first, because pinning overwrites the main thread's affinity. + // Thread-hostile, not called concurrently. + LPS EnabledLPs() { + if (original_affinity_.Any()) return original_affinity_; + + // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl. + LPS enabled_lps; + if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) { + const size_t num_lps = hwy::TotalLogicalProcessors(); + fprintf( + stderr, + "Warning, unknown OS affinity, considering all %zu LPs enabled\n.", + num_lps); + for (size_t lp = 0; lp < num_lps; ++lp) { + enabled_lps.Set(lp); + } + } + + // Without threading support, only keep the first enabled LP; it might still + // make sense to pin the main thread to avoid migrations. 
+ if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) { + HWY_ASSERT(enabled_lps.Any()); + const size_t lp = enabled_lps.First(); + enabled_lps = LPS(); + enabled_lps.Set(lp); + fprintf(stderr, + "Warning, threads not supported, using only the main thread\n."); + } + + original_affinity_ = enabled_lps; + return enabled_lps; + } + + void SetPolicy(Tristate pin) { + if (pin == Tristate::kDefault) { + // Pinning is unreliable inside containers because the hypervisor might + // periodically change our affinity mask, or other processes might also + // pin themselves to the same LPs. + pin = InContainer() ? Tristate::kFalse : Tristate::kTrue; + } + want_pin_ = (pin == Tristate::kTrue); + any_error_.clear(); + } + + // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`, + // and sets `any_error_` if any fails. + void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) { + if (HWY_UNLIKELY(!want_pin_)) return; + + const std::vector lps = cluster.LPVector(); + HWY_ASSERT(pool->NumWorkers() <= lps.size()); + pool->Run( + 0, pool->NumWorkers(), + [this, &pool, &lps](uint64_t task, size_t thread) { + HWY_ASSERT(task == thread); // each worker has one task + if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) { + fprintf(stderr, + "Pinning failed for task %zu of %zu to %zu (size %zu)\n", + task, pool->NumWorkers(), lps[task], lps.size()); + (void)any_error_.test_and_set(); + } + }); + } + + // Called ONCE after all MaybePin because it invalidates the error status. + bool AllPinned(const char** pin_string) { + // If !want_pin_, MaybePin will return without setting any_error_, but in + // that case we still want to return false to avoid spinning. + // .test() was only added in C++20, so we use .test_and_set() instead. + const bool all_pinned = want_pin_ && !any_error_.test_and_set(); + *pin_string = all_pinned ? "pinned" + : want_pin_ ? "pinning failed" + : "pinning skipped"; + return all_pinned; + } + + private: + std::atomic_flag any_error_ = ATOMIC_FLAG_INIT; + bool want_pin_; // set in SetPolicy + LPS original_affinity_; +}; // Pinning + +// Singleton saves global affinity across all BoundedTopology instances because +// pinning overwrites it. +static Pinning& GetPinning() { + static Pinning pinning; + return pinning; +} + BoundedTopology::BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice, BoundedSlice lp_slice) { - // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl. - LPS enabled_lps; - if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) { - const size_t num_lps = hwy::TotalLogicalProcessors(); - fprintf(stderr, - "Warning, unknown OS affinity, considering all %zu LPs enabled\n.", - num_lps); - for (size_t lp = 0; lp < num_lps; ++lp) { - enabled_lps.Set(lp); - } - } - - // Without threading support, only keep the first enabled LP; it might still - // make sense to pin the main thread to avoid migrations. - if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) { - HWY_ASSERT(enabled_lps.Any()); - const size_t lp = enabled_lps.First(); - enabled_lps = LPS(); - enabled_lps.Set(lp); - fprintf(stderr, - "Warning, threads not supported, using only the main thread\n."); - } + const LPS enabled_lps = GetPinning().EnabledLPs(); #if !GEMMA_DISABLE_TOPOLOGY if (HWY_LIKELY(!topology_.packages.empty())) { @@ -110,19 +185,33 @@ BoundedTopology::Cluster::Cluster(const LPS& enabled_lps, AddLP(lp); - // Set `node` once, and ensure subsequent nodes match - we assume there - // is only one NUMA node per cluster. 
+ // Set fields once, and ensure subsequent LPs match - we assume there + // is only one NUMA node per cluster, with the same L2/L3 size. const size_t lp_node = static_cast(all_lps[lp].node); if (is_first_lp) { is_first_lp = false; node_ = lp_node; + private_kib_ = tcluster.private_kib; + shared_kib_ = tcluster.shared_kib; } else { static bool warned = false; - if (lp_node != node_ && !warned) { - warned = true; - fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n", - lp, lp_node, node_); - } + if (HWY_LIKELY(!warned)) { + if (HWY_UNLIKELY(lp_node != node_)) { + warned = true; + fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n", + lp, lp_node, node_); + } + if (HWY_UNLIKELY(private_kib_ != tcluster.private_kib)) { + warned = true; + fprintf(stderr, "WARNING: lp %zu private_kib %zu != cluster %zu.\n", + lp, private_kib_, tcluster.private_kib); + } + if (HWY_UNLIKELY(shared_kib_ != tcluster.shared_kib)) { + warned = true; + fprintf(stderr, "WARNING: lp %zu shared_kib %zu != cluster %zu.\n", + lp, shared_kib_, tcluster.shared_kib); + } + } // !warned } }); } @@ -141,6 +230,7 @@ BoundedTopology::Package::Package(const LPS& enabled_lps, "cluster", tpackage.clusters.size(), [&](size_t cluster_idx) { const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx]; Cluster cluster(enabled_lps, topology.lps, tcluster); + // Skip if empty, i.e. too few `enabled_lps`. if (HWY_LIKELY(cluster.Size() != 0)) { clusters.push_back(std::move(cluster)); @@ -267,56 +357,6 @@ static PoolPtr MakePool(size_t num_workers) { return std::make_unique(num_threads); } -static bool InContainer() { - return false;} - -class NestedPools::Pinning { - public: - Pinning(Tristate pin, const BoundedTopology& topology) { - if (pin == Tristate::kDefault) { - // Pinning is unreliable inside containers because the hypervisor might - // periodically change our affinity mask, or other processes might also - // pin themselves to the same LPs. - pin = InContainer() ? Tristate::kFalse : Tristate::kTrue; - } - want_pin_ = (pin == Tristate::kTrue); - } - - // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`, - // and sets `any_error_` if any fails. - void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) { - if (HWY_UNLIKELY(!want_pin_)) return; - - const std::vector lps = cluster.LPVector(); - HWY_ASSERT(pool->NumWorkers() <= lps.size()); - pool->Run( - 0, pool->NumWorkers(), - [this, &pool, &lps](uint64_t task, size_t thread) { - HWY_ASSERT(task == thread); // each worker has one task - if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) { - fprintf(stderr, - "Pinning failed for task %zu of %zu to %zu (size %zu)\n", - task, pool->NumWorkers(), lps[task], lps.size()); - (void)any_error_.test_and_set(); - } - }); - } - - bool WantPin() const { return want_pin_; } - - // Called ONCE after all MaybePin because it invalidates the error status. - bool AllPinned() { - // If !want_pin_, MaybePin will return without setting any_error_, but in - // that case we still want to return false to avoid spinning. - // .test() was only added in C++20, so we use .test_and_set() instead. - return want_pin_ && !any_error_.test_and_set(); - } - - private: - std::atomic_flag any_error_ = ATOMIC_FLAG_INIT; - bool want_pin_; // set in ctor -}; // Pinning - // Used to divide max_threads and max_workers_per_package across packages and // clusters. Ensures small upper bounds are respected. 
static size_t DivideMaxAcross(const size_t max, const size_t instances) { @@ -333,7 +373,7 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin, BoundedSlice package_slice, BoundedSlice cluster_slice, BoundedSlice lp_slice) : topology_(package_slice, cluster_slice, lp_slice) { - Pinning pinning(pin, topology_); + GetPinning().SetPolicy(pin); packages_.resize(topology_.NumPackages()); all_packages_ = MakePool(packages_.size()); const size_t max_workers_per_package = @@ -344,14 +384,11 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin, all_packages_->Run( 0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) { HWY_ASSERT(package_idx == thread); // each thread has one task - packages_[package_idx] = Package( - topology_, package_idx, max_workers_per_package, pinning, lp_slice); + packages_[package_idx] = + Package(topology_, package_idx, max_workers_per_package, lp_slice); }); - all_pinned_ = pinning.AllPinned(); - pin_string_ = all_pinned_ ? "pinned" - : pinning.WantPin() ? "pinning failed" - : "pinning skipped"; + all_pinned_ = GetPinning().AllPinned(&pin_string_); // For mapping package/cluster/thread to noncontiguous TLS indices, in case // cluster/thread counts differ. @@ -368,14 +405,9 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin, HWY_ASSERT(max_workers_per_cluster_ <= 256); } -// `max_or_zero` == 0 means no limit. -static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) { - return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero); -} - NestedPools::Package::Package(const BoundedTopology& topology, size_t package_idx, - size_t max_workers_per_package, Pinning& pinning, + size_t max_workers_per_package, BoundedSlice lp_slice) { // Pre-allocate because elements are set concurrently. clusters_.resize(topology.NumClusters(package_idx)); @@ -393,7 +425,7 @@ NestedPools::Package::Package(const BoundedTopology& topology, clusters_[cluster_idx] = MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster)); // Pin workers AND the calling thread from `all_clusters`. - pinning.MaybePin(cluster, clusters_[cluster_idx]); + GetPinning().MaybePin(cluster, clusters_[cluster_idx]); }); } diff --git a/util/threading.h b/util/threading.h index 6be1503..604882e 100644 --- a/util/threading.h +++ b/util/threading.h @@ -17,14 +17,17 @@ #define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #include +#include #include // std::unique_ptr #include +// IWYU pragma: begin_exports #include "util/basics.h" // Tristate #include "hwy/base.h" // HWY_ASSERT #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/topology.h" +// IWYU pragma: end_exports #ifndef GEMMA_DISABLE_TOPOLOGY #define GEMMA_DISABLE_TOPOLOGY 0 @@ -32,6 +35,15 @@ namespace gcpp { +static inline size_t SaturatingSub(size_t a, size_t b) { + return a - HWY_MIN(a, b); +} + +// `max_or_zero` == 0 means no limit. +static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) { + return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero); +} + // A slice of a 1D integer range such as the indices of packages or clusters. // This allows assigning them to multiple instances of our binary. class BoundedSlice { @@ -86,6 +98,7 @@ using PoolPtr = std::unique_ptr; // back to a single package and cluster. class BoundedTopology { public: + // Thread-hostile, typically called from main thread. 
BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice, BoundedSlice lp_slice); @@ -112,6 +125,8 @@ class BoundedTopology { } size_t Node() const { return node_; } + size_t PrivateKiB() const { return private_kib_; } + size_t SharedKiB() const { return shared_kib_; } private: void AddLP(size_t lp) { @@ -126,6 +141,10 @@ class BoundedTopology { size_t num_workers_ = 0; // NUMA node, set from hwy::Topology::LP::node. size_t node_ = 0; + // L2 cache size in KiB, or 0 if unknown. + size_t private_kib_ = 0; + // L3 cache size in KiB, or 0 if unknown. + size_t shared_kib_ = 0; }; // Cluster size_t NumClusters(size_t package_idx) const { @@ -145,6 +164,10 @@ class BoundedTopology { return package.clusters[cluster_idx]; } +#if !GEMMA_DISABLE_TOPOLOGY + const hwy::Topology& FullTopology() const { return topology_; } +#endif + private: struct Package { // Topology is unknown, rely on OS affinity and user-specified slice. @@ -257,6 +280,8 @@ class NestedPools { // Returns the first of `cluster.NumWorkers()` TLS indices, to which callers // add the worker index given by `cluster.Run`. size_t WorkerOffset(size_t package_idx, size_t cluster_idx) const { + HWY_DASSERT(package_idx < packages_.size()); + HWY_DASSERT(cluster_idx < packages_[package_idx].NumClusters()); return (package_idx * max_clusters_per_package_ + cluster_idx) * max_workers_per_cluster_; } @@ -267,26 +292,25 @@ class NestedPools { const char* TopologyString() const { return topology_.TopologyString(); } const char* PinString() const { return pin_string_; } - // Returns a single pool on the first package: either one thread per cluster + // Returns a single pool on the given package: either one thread per cluster // if there is more than one, which maximizes available memory bandwidth, or // the first cluster, which is typically the whole package. For use by callers - // that only parallelize over a 1D range, as opposed to the nested - // parallelism of `StaticPartitionRowsAndCols`. - hwy::ThreadPool& Pool() { + // that only have a single parallel-for. + hwy::ThreadPool& Pool(size_t package_idx = 0) { // Only one cluster: use its pool, typically a whole socket. - if (AllClusters(0).NumWorkers() == 1) return Cluster(0, 0); - return AllClusters(0); + if (AllClusters(package_idx).NumWorkers() == 1) { + return Cluster(package_idx, 0); + } + // One worker per cluster to maximize bandwidth availability. + return AllClusters(package_idx); } private: - class Pinning; - class Package { public: Package() = default; // for vector Package(const BoundedTopology& topology, size_t package_idx, - size_t max_workers_per_package, Pinning& pinning, - BoundedSlice lp_slice); + size_t max_workers_per_package, BoundedSlice lp_slice); size_t NumClusters() const { return clusters_.size(); } size_t MaxWorkersPerCluster() const { @@ -330,11 +354,134 @@ class NestedPools { std::vector packages_; PoolPtr all_packages_; - // For TLS indices. + // For TLS indices. One might think this belongs in BoundedTopology, but it + // depends on max_threads, which is passed to the NestedPools constructor. size_t max_clusters_per_package_ = 0; size_t max_workers_per_cluster_ = 0; }; +// Splits `range` into subranges of size `task_size`, except for the last, +// which receives the remainder. Used with the `ParallelizeOneRange` etc. +// functions below. 
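+// For example, IndexRangePartition(IndexRange(0, 100), /*task_size=*/32)
+// yields NumTasks() == 4 with Range(i) == [0,32), [32,64), [64,96), [96,100).
+// Typical use (sketch; `all_rows` and `pool` are placeholders):
+//   const IndexRangePartition rows = StaticPartition(all_rows, pool.NumWorkers(), 1);
+//   ParallelizeOneRange(rows, pool, [&](const IndexRange& r, size_t thread) { ... });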
+class IndexRangePartition { + public: + IndexRangePartition(const IndexRange& range, const size_t task_size) + : range_(range), task_size_(task_size) { + const size_t num = range.Num(); + HWY_DASSERT(task_size_ != 0); + num_tasks_ = hwy::DivCeil(num, task_size_); + HWY_DASSERT(num_tasks_ != 0); + if constexpr (HWY_IS_DEBUG_BUILD) { + const size_t handled = num_tasks_ * task_size_; + // The last task may extend beyond items, but at most by (task_size_ - 1). + HWY_DASSERT(num <= handled && handled < num + task_size_); + } + } + + size_t TaskSize() const { return task_size_; } + size_t NumTasks() const { return num_tasks_; } + + IndexRange Range(size_t task_idx) const { + HWY_DASSERT(task_idx < NumTasks()); + return MakeIndexRange(range_.begin() + task_idx * task_size_, range_.end(), + task_size_); + } + + private: + IndexRange range_; + size_t task_size_; + size_t num_tasks_; +}; + +// Starts with `max_size` and rounds DOWN to a multiple of `size_multiple` +// unless that would be zero. It is the caller's responsibility to choose +// `size_multiple` to avoid two heavily imbalanced tasks. +// Use when the number of tasks does not matter, but each must fit into caches. +static inline IndexRangePartition MaxSizePartition(const IndexRange& range, + const size_t max_size, + const size_t size_multiple) { + HWY_DASSERT(size_multiple != 0); + size_t size = HWY_MIN(range.Num(), max_size); + if (size > size_multiple) size = hwy::RoundDownTo(size, size_multiple); + return IndexRangePartition(range, size); +} + +// Up to `max_tasks` tasks, each rounded UP to `size_multiple`, unless that +// would be more than the range. Use when the number of tasks is known, e.g. +// one per ThreadPool worker. +static inline IndexRangePartition StaticPartition(const IndexRange& range, + const size_t max_tasks, + const size_t size_multiple) { + HWY_DASSERT(max_tasks != 0); + size_t size = + hwy::RoundUpTo(hwy::DivCeil(range.Num(), max_tasks), size_multiple); + size = HWY_MIN(size, range.Num()); + return IndexRangePartition(range, size); +} + +// Parallel-for over a single range. This takes care of translating the task +// index to a range. +template +void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool, + const Func& func) { + const size_t num_tasks = get1.NumTasks(); + pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) { + const IndexRange range1 = get1.Range(task); + func(range1, thread); + }); +} + +// Parallel-for over the Cartesian product of the two sets of ranges. This +// combines their indices into a single 'task' so they can be executed by one +// `pool.Run`, which increases the amount of work available to workers and +// reduces fork-join overhead vs. nested parallel-for loops. Calls `func` with +// the two ranges and the thread index within `pool`. +template +void ParallelizeTwoRanges(const IndexRangePartition& get1, + const IndexRangePartition& get2, + hwy::ThreadPool& pool, const Func& func) { + const hwy::Divisor div1(static_cast(get1.NumTasks())); + + const size_t num_tasks = get1.NumTasks() * get2.NumTasks(); + pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) { + HWY_DASSERT(task < (uint64_t{1} << 32)); + const size_t idx2 = div1.Divide(static_cast(task)); + const size_t idx1 = div1.Remainder(static_cast(task)); + HWY_DASSERT(idx1 < get1.NumTasks()); + HWY_DASSERT(idx2 < get2.NumTasks()); + const IndexRange range1 = get1.Range(idx1); + const IndexRange range2 = get2.Range(idx2); + func(range1, range2, thread); + }); +} + +// As above, for three ranges. 
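+// Illustrative decomposition: with NumTasks() of 3, 2 and 4, there are 24 flat
+// tasks; task 17 maps to idx3 = 17 / 6 = 2, then 17 % 6 = 5 gives idx2 = 1 and
+// idx1 = 2, so each (range1, range2, range3) combination is visited once.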
+template +void ParallelizeThreeRanges(const IndexRangePartition& get1, + const IndexRangePartition& get2, + const IndexRangePartition& get3, + hwy::ThreadPool& pool, const Func& func) { + const hwy::Divisor div1(static_cast(get1.NumTasks())); + const size_t num12 = get1.NumTasks() * get2.NumTasks(); + const hwy::Divisor div12(static_cast(num12)); + + const size_t num_tasks = num12 * get3.NumTasks(); + pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) { + HWY_DASSERT(task < (uint64_t{1} << 32)); + const size_t idx3 = div12.Divide(static_cast(task)); + const size_t task12 = div12.Remainder(static_cast(task)); + const size_t idx2 = div1.Divide(static_cast(task12)); + const size_t idx1 = div1.Remainder(static_cast(task12)); + HWY_DASSERT(idx1 < get1.NumTasks()); + HWY_DASSERT(idx2 < get2.NumTasks()); + HWY_DASSERT(idx3 < get3.NumTasks()); + const IndexRange range1 = get1.Range(idx1); + const IndexRange range2 = get2.Range(idx2); + const IndexRange range3 = get3.Range(idx3); + func(range1, range2, range3, thread); + }); +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ diff --git a/util/threading_test.cc b/util/threading_test.cc index 7f01f2c..2190e7e 100644 --- a/util/threading_test.cc +++ b/util/threading_test.cc @@ -23,6 +23,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "hwy/base.h" // HWY_ASSERT +#include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { namespace { @@ -111,5 +112,229 @@ TEST(ThreadingTest, TestBoundedTopology) { } } +TEST(ThreadingTest, TestMaxSizePartition) { + const IndexRange range(0, 100); + // Round down + { + const IndexRangePartition partition = MaxSizePartition(range, 55, 32); + HWY_ASSERT(partition.TaskSize() == 32); + HWY_ASSERT(partition.NumTasks() == 4); + } + // Huge `max_size`: single task + { + const IndexRangePartition partition = MaxSizePartition(range, 9999, 1); + HWY_ASSERT(partition.TaskSize() == 100); + HWY_ASSERT(partition.NumTasks() == 1); + } + // Huge `max_size`: `size_multiple` is still respected + { + const IndexRangePartition partition = MaxSizePartition(range, 9999, 64); + HWY_ASSERT(partition.TaskSize() == 64); + HWY_ASSERT(partition.NumTasks() == 2); + } + // `size_multiple` larger than range: ignore multiple + { + const IndexRangePartition partition = MaxSizePartition(range, 55, 128); + HWY_ASSERT(partition.TaskSize() == 55); + HWY_ASSERT(partition.NumTasks() == 2); + } + // Small `max_size`: small tasks + { + const IndexRangePartition partition = MaxSizePartition(range, 2, 1); + HWY_ASSERT(partition.TaskSize() == 2); + HWY_ASSERT(partition.NumTasks() == 50); + } + // Large `max_size`: two tasks with lots of overhang + { + const IndexRangePartition partition = MaxSizePartition(range, 98, 1); + HWY_ASSERT(partition.TaskSize() == 98); + HWY_ASSERT(partition.NumTasks() == 2); + } + // `size_multiple` almost as large as a different, smaller range: imbalanced + { + const IndexRangePartition partition = + MaxSizePartition(IndexRange(0, 6), 6, 4); + HWY_ASSERT(partition.TaskSize() == 4); + HWY_ASSERT(partition.NumTasks() == 2); + } +} + +TEST(ThreadingTest, TestStaticPartition) { + const IndexRange range(0, 100); + // Round up + { + const IndexRangePartition partition = StaticPartition(range, 2, 64); + HWY_ASSERT(partition.TaskSize() == 64); + HWY_ASSERT(partition.NumTasks() == 2); + } + // No `size_multiple`: division still rounds up + { + const IndexRangePartition partition = StaticPartition(range, 3, 1); + HWY_ASSERT(partition.TaskSize() == 34); + HWY_ASSERT(partition.NumTasks() == 3); + 
+  }
+  // Huge `max_tasks`: one each
+  {
+    const IndexRangePartition partition = StaticPartition(range, 9999, 1);
+    HWY_ASSERT(partition.TaskSize() == 1);
+    HWY_ASSERT(partition.NumTasks() == 100);
+  }
+  // `size_multiple` larger than range: single task
+  {
+    const IndexRangePartition partition = StaticPartition(range, 2, 128);
+    HWY_ASSERT(partition.TaskSize() == 100);
+    HWY_ASSERT(partition.NumTasks() == 1);
+  }
+  // `max_tasks` = 1: single task, even if rounding up would exceed the range
+  {
+    const IndexRangePartition partition = StaticPartition(range, 1, 8);
+    HWY_ASSERT(partition.TaskSize() == 100);
+    HWY_ASSERT(partition.NumTasks() == 1);
+  }
+}
+
+TEST(ThreadingTest, TestParallelizeOneRange) {
+  const IndexRange range(0, 10);
+  const IndexRangePartition partition = StaticPartition(range, 2, 4);
+  hwy::ThreadPool null_pool(0);
+  size_t calls = 0;
+  ParallelizeOneRange(partition, null_pool,
+                      [&](const IndexRange& range, size_t) {
+                        if (++calls == 1) {
+                          HWY_ASSERT(range.begin() == 0 && range.end() == 8);
+                        } else {
+                          HWY_ASSERT(range.begin() == 8 && range.end() == 10);
+                        }
+                      });
+  HWY_ASSERT(calls == 2);
+}
+
+TEST(ThreadingTest, TestParallelizeTwoRanges) {
+  const IndexRangePartition partition1 =
+      StaticPartition(IndexRange(0, 10), 2, 4);
+  const IndexRangePartition partition2 =
+      MaxSizePartition(IndexRange(128, 256), 32, 32);
+  HWY_ASSERT(partition2.NumTasks() == 4);
+  hwy::ThreadPool null_pool(0);
+  {
+    size_t calls = 0;
+    ParallelizeTwoRanges(
+        partition1, partition2, null_pool,
+        [&](const IndexRange& range1, const IndexRange& range2, size_t) {
+          ++calls;
+          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
+          HWY_ASSERT(range2.begin() % 32 == 0);
+          HWY_ASSERT(range2.Num() % 32 == 0);
+        });
+    HWY_ASSERT(calls == 2 * 4);
+  }
+
+  // Also swap order to test Remainder() logic.
+  {
+    size_t calls = 0;
+    ParallelizeTwoRanges(
+        partition2, partition1, null_pool,
+        [&](const IndexRange& range2, const IndexRange& range1, size_t) {
+          ++calls;
+          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
+          HWY_ASSERT(range2.begin() % 32 == 0);
+          HWY_ASSERT(range2.Num() % 32 == 0);
+        });
+    HWY_ASSERT(calls == 2 * 4);
+  }
+}
+
+TEST(ThreadingTest, TestParallelizeThreeRanges) {
+  // Named according to number of tasks.
+  const IndexRangePartition partition3 =
+      StaticPartition(IndexRange(0, 8), 3, 1);  // [0, 3) [3, 6) [6, 8)
+  HWY_ASSERT(partition3.NumTasks() == 3);
+  const IndexRangePartition partition2 =
+      MaxSizePartition(IndexRange(10, 30), 10, 10);  // [10, 20), [20, 30)
+  HWY_ASSERT(partition2.NumTasks() == 2);
+  const IndexRangePartition partition4 =
+      MaxSizePartition(IndexRange(100, 500), 100, 100);  // 100, 200, 300, 400
+  HWY_ASSERT(partition4.NumTasks() == 4);
+
+  const auto check_ranges = [&](const IndexRange& range3,
+                                const IndexRange& range2,
+                                const IndexRange& range4) {
+    HWY_ASSERT(range3.begin() == 0 || range3.begin() == 3 ||
+               range3.begin() == 6);
+    HWY_ASSERT(range2.begin() == 10 || range2.begin() == 20);
+    HWY_ASSERT(range4.begin() % 100 == 0);
+  };
+
+  hwy::ThreadPool null_pool(0);
+  // All 6 permutations of the three ranges to test the Remainder() logic:
+  // 3, 2, 4
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition3, partition2, partition4, null_pool,
+        [&](IndexRange range3, IndexRange range2, IndexRange range4, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+  // 3, 4, 2
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition3, partition4, partition2, null_pool,
+        [&](IndexRange range3, IndexRange range4, IndexRange range2, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+
+  // 4, 2, 3
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition4, partition2, partition3, null_pool,
+        [&](IndexRange range4, IndexRange range2, IndexRange range3, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+  // 4, 3, 2
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition4, partition3, partition2, null_pool,
+        [&](IndexRange range4, IndexRange range3, IndexRange range2, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+
+  // 2, 3, 4
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition2, partition3, partition4, null_pool,
+        [&](IndexRange range2, IndexRange range3, IndexRange range4, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+  // 2, 4, 3
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition2, partition4, partition3, null_pool,
+        [&](IndexRange range2, IndexRange range4, IndexRange range3, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+}
+
 }  // namespace
 }  // namespace gcpp
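
Usage sketch (editor's note, not part of the patch): the snippet below shows how the new
partitioning helpers and ParallelizeThreeRanges are intended to compose, e.g. for tiling a
three-level loop nest. The names IndexRange, StaticPartition, MaxSizePartition,
ParallelizeThreeRanges and the include paths come from this patch; the tile sizes, pool size,
main() wrapper and printf body are illustrative assumptions only.

  #include <cstddef>
  #include <cstdio>

  #include "hwy/contrib/thread_pool/thread_pool.h"
  #include "util/threading.h"  // IndexRange, *Partition, ParallelizeThreeRanges

  int main() {
    using gcpp::IndexRange;
    using gcpp::IndexRangePartition;

    hwy::ThreadPool pool(4);  // four worker threads (illustrative)

    // Rows: at most 8 equal-ish static tasks, no size multiple.
    const IndexRangePartition rows =
        gcpp::StaticPartition(IndexRange(0, 1024), 8, 1);
    // Cols: tiles of at most 256, kept a multiple of 64.
    const IndexRangePartition cols =
        gcpp::MaxSizePartition(IndexRange(0, 2048), 256, 64);
    // Depth: tiles of at most 512, kept a multiple of 128.
    const IndexRangePartition depth =
        gcpp::MaxSizePartition(IndexRange(0, 4096), 512, 128);

    // One pool task per (row, col, depth) tile; the helper flattens the 3D
    // task index and recovers the three tile indices via hwy::Divisor, so the
    // lambda only receives ready-made IndexRanges plus the thread index.
    gcpp::ParallelizeThreeRanges(
        rows, cols, depth, pool,
        [](const IndexRange& r, const IndexRange& c, const IndexRange& d,
           size_t thread) {
          std::printf("thread %zu: rows [%zu, %zu) cols [%zu, %zu) depth [%zu, %zu)\n",
                      thread, static_cast<size_t>(r.begin()),
                      static_cast<size_t>(r.end()), static_cast<size_t>(c.begin()),
                      static_cast<size_t>(c.end()), static_cast<size_t>(d.begin()),
                      static_cast<size_t>(d.end()));
        });
    return 0;
  }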