Threading/infra improvements.

* Add Parallelize*Range helpers and partitioning helpers
* Refactor the Pinning class and store the original affinity (required to construct another NestedPools after pinning has happened)

Compress:
* prevent Compress from printing stats in tests
* zero-pad tensors

Matmul:
* add matmul_unit_test (TODO) and bench_matmul
* matmul_test: change the norm to row vectors (that is what is accumulated) and include the BF16 rounding error
* Prepare for L2/L3 cache-size retrieval
PiperOrigin-RevId: 700603811
Jan Wassenberg 2024-11-27 01:11:20 -08:00 committed by Copybara-Service
parent 109a4d9f85
commit f74d496879
20 changed files with 1001 additions and 294 deletions
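
The new Parallelize*Range and partitioning helpers are exercised in the MatMulSlow rewrite below. As an illustrative sketch only (signatures inferred from that usage; `pool` and `num_rows` are placeholders, not part of this commit), parallelizing a row loop over one pool looks roughly like:

  // Sketch, assuming the StaticPartition/ParallelizeOneRange helpers added by
  // this commit, used exactly as in the MatMulSlow changes further down.
  const IndexRange all_rows(0, num_rows);
  const IndexRangePartition ranges =
      StaticPartition(all_rows, pool.NumWorkers(), /*multiple=*/1);
  ParallelizeOneRange(ranges, pool,
                      [&](const IndexRange& rows, size_t /*thread*/) {
                        for (size_t r : rows) {
                          // ... process row r ...
                        }
                      });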

View File

@@ -71,6 +71,7 @@ cc_test(
         "@googletest//:gtest_main",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
+        "@highway//:thread_pool",
     ],
 )
@@ -166,6 +167,26 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "matmul_unit_test",
+    size = "small",
+    timeout = "long",
+    srcs = ["ops/matmul_unit_test.cc"],
+    local_defines = ["HWY_IS_TEST"],
+    # for test_suite.
+    tags = ["hwy_ops_test"],
+    deps = [
+        ":allocator",
+        ":basics",
+        ":ops",
+        ":test_util",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "//compression:compress",
+        "@highway//:hwy",
+        "@highway//:hwy_test_util",
+    ],
+)
+
 cc_test(
     name = "matmul_test",
     size = "small",
@@ -178,7 +199,28 @@ cc_test(
         ":allocator",
         ":basics",
         ":ops",
-        ":test_util",
+        ":threading",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "//compression:compress",
+        "@highway//:hwy",
+        "@highway//:hwy_test_util",
+        "@highway//:nanobenchmark",
+        "@highway//:thread_pool",
+    ],
+)
+
+cc_test(
+    name = "bench_matmul",
+    size = "small",
+    timeout = "long",
+    srcs = ["ops/bench_matmul.cc"],
+    local_defines = ["HWY_IS_TEST"],
+    # for test_suite.
+    tags = ["hwy_ops_test"],
+    deps = [
+        ":allocator",
+        ":basics",
+        ":ops",
         ":threading",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
@@ -343,6 +385,7 @@ cc_library(
         ":basics",
         ":common",
         ":gemma_lib",
+        ":ops",
         ":threading",
         "//compression:io",
         "@highway//:hwy",
@@ -624,6 +667,7 @@ cc_test(
         "mem": "28g",
     },
     deps = [
+        ":allocator",
         ":backprop",
         ":basics",
         ":common",

View File

@@ -163,9 +163,11 @@ set(GEMMA_TEST_FILES
   compression/sfp_test.cc
   evals/gemma_test.cc
   gemma/tensor_index_test.cc
+  ops/bench_matmul.cc
   ops/dot_test.cc
   ops/gemma_matvec_test.cc
   ops/matmul_test.cc
+  ops/matmul_unit_test.cc
   ops/ops_test.cc
   paligemma/image_test.cc
   paligemma/paligemma_test.cc

View File

@@ -33,6 +33,7 @@
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "util/allocator.h"
 #include "util/basics.h"
 #include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -42,6 +43,7 @@ namespace gcpp {
 TEST(OptimizeTest, GradientDescent) {
   NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
                     BoundedSlice(0, 1));
+  Allocator::Init(pools.Topology());
   hwy::ThreadPool& pool = pools.Pool();
   std::mt19937 gen(42);

View File

@@ -51,6 +51,12 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
+#ifdef HWY_IS_TEST
+static constexpr bool kIsTest = true;
+#else
+static constexpr bool kIsTest = false;
+#endif
+
 // Enables generic code independent of compression type.
 template <typename T>  // primary, must specialize
 struct CompressTraits {};
@@ -438,7 +444,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
     }
   }
 
-  const bool want_bench = num > 1024 * 1024 || COMPRESS_STATS;
+  const bool want_bench = COMPRESS_STATS || !kIsTest;
   const double t0 = want_bench ? hwy::platform::Now() : 0.0;
 
   using Traits = CompressTraits<Packed>;

View File

@@ -38,6 +38,7 @@
 #include "util/basics.h"
 // IWYU pragma: end_exports
 #include "util/allocator.h"
+#include "hwy/per_target.h"
 
 #if COMPRESS_STATS
 #include "compression/distortion.h"
 #include "hwy/stats.h"
@@ -360,7 +361,11 @@ class MatStorageT : public MatPtrT<MatT> {
     } else {
       this->num_elements_ = num_elements;
     }
-    data_ = Allocator::Alloc<MatT>(num_elements);
+    // Pad to allow overrunning the last row by 2 BF16 vectors, hence at most
+    // `2 * VectorBytes / sizeof(BF16)` elements of MatT.
+    const size_t padding = hwy::VectorBytes();
+    data_ = Allocator::Alloc<MatT>(num_elements + padding);
+    hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT));
     this->ptr_ = data_.get();
   }
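
For reference, the arithmetic behind the new padding comment (my own sketch, not part of the diff):

  // padding is hwy::VectorBytes() *elements* of MatT, i.e.
  //   padding_bytes = hwy::VectorBytes() * sizeof(MatT)
  // For MatT = BF16 (sizeof == 2) that is 2 * hwy::VectorBytes(), i.e. two full
  // BF16 vectors, so decode loops may read up to two vectors past the last row
  // without touching unowned memory; the ZeroBytes call keeps that tail
  // deterministic.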

View File

@@ -56,6 +56,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
                    const AppArgs& app)
     : pools_(CreatePools(app)) {
+  Allocator::Init(pools_.Topology());
   InferenceArgs mutable_inference = inference;
   AbortIfInvalidArgs(mutable_inference);
   LoaderArgs mutable_loader = loader;

View File

@@ -59,6 +59,7 @@ int main(int argc, char** argv) {
   // Instantiate model and KV Cache
   gcpp::NestedPools pools = gcpp::CreatePools(app);
+  gcpp::Allocator::Init(pools.Topology());
   gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
   gcpp::KVCache kv_cache =
       gcpp::KVCache::Create(model.GetModelConfig(),

ops/bench_matmul.cc (new file, 214 lines)
View File

@@ -0,0 +1,214 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Benchmark of large MatMul instances for which the MatMulSlow would be too
// slow. This lacks a reference and is only useful for performance measurement.
#include "hwy/detect_compiler_arch.h"
#ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
// double-precision support.
#if HWY_ARCH_ARM_V7
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
#else
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#endif
#include <stddef.h>
#include <stdio.h>
#include <memory>
#include "compression/compress.h"
#include "compression/shared.h"
#include "ops/matmul.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/bench_matmul.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/matmul-inl.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
// For running BenchAllMatMul only once. Defined within HWY_ONCE.
extern int64_t first_target;

namespace HWY_NAMESPACE {

using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;

template <typename MatT>
using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;

// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT>
MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
                                hwy::ThreadPool& pool) {
  gcpp::CompressWorkingSet ws;
  auto mat =
      std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
  FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
  HWY_ASSERT(content);
  const float scale = SfpStream::kMax / (mat->NumElements());
  pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
    for (size_t c = 0; c < extents.cols; c++) {
      float f = static_cast<float>(r * extents.cols + c) * scale;
      if ((r + c) & 1) f = -f;  // Also generate some negative values.
      content[r * extents.cols + c] = f;
    }
  });

  CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
  mat->set_scale(0.6f);  // Arbitrary value, different from 1.
  return mat;
}

// extents describes the transposed matrix.
template <typename MatT>
MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
                                          hwy::ThreadPool& pool) {
  gcpp::CompressWorkingSet ws;
  auto mat =
      std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
  FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
  const float scale = SfpStream::kMax / (mat->NumElements());
  pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
    for (size_t c = 0; c < extents.cols; c++) {
      float f = static_cast<float>(c * extents.rows + r) * scale;
      if ((r + c) & 1) f = -f;  // Also generate some negative values.
      content[r * extents.cols + c] = f;
    }
  });

  CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
  // Arbitrary value, different from 1, must match GenerateMat.
  mat->set_scale(0.6f);
  return mat;
}

void PrintSpeed(const char* algo, const Extents2D& A_extents,
                const Extents2D& B_extents, double elapsed) {
  const size_t num_b = B_extents.Area();
  // 2x because of FMA.
  fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
          elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
}

// Generates inputs and prints observed throughput of MatMul.
template <typename MatTA, typename MatTB = MatTA>
void BenchMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                 MatMulEnv& env) {
  hwy::ThreadPool& pool = env.Pool();
  fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
          rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
          TypeName<MatTB>());
  const Extents2D A_extents(rows_ac, cols_a_rows_b);
  const Extents2D B_extents(cols_bc, cols_a_rows_b);  // already transposed
  const Extents2D C_extents(rows_ac, cols_bc);

  MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
  MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
  RowVectorBatch<float> c_slow_batch(C_extents);
  RowVectorBatch<float> c_batch(C_extents);
  HWY_ASSERT(a && b_trans);

  std::unique_ptr<MatStorageT<float>> add_storage;
  if (add) {
    add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
    HWY_ASSERT(add_storage);
    add_storage->set_scale(1.0f);
  }

  const auto A = ConstMatFromWeights(*a);
  const auto B = ConstMatFromWeights(*b_trans);
  const float* add_row = add ? add_storage->data_scale1() : nullptr;
  const RowPtrF C = RowPtrFromBatch(c_batch);

  double min_elapsed = hwy::HighestValue<double>();
  for (int rep = 0; rep < 3; ++rep) {
    const double start_tiled = hwy::platform::Now();
    MatMul(A, B, add_row, env, C);
    min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
  }
  PrintSpeed("MatMul", A_extents, B_extents, min_elapsed);
}

using F32 = float;
using SFP = SfpStream;

void BenchAllMatMul() {
  if (first_target == 0) first_target = HWY_TARGET;
  if (HWY_TARGET != first_target) return;

  for (size_t max_packages : {1, 2}) {
    const size_t max_threads = 0;  // no limit
    NestedPools pools(max_threads, Tristate::kDefault,
                      BoundedSlice(0, max_packages));
#if GEMMA_DISABLE_TOPOLOGY
    if (max_packages == 2) break;  // we only have one package
#else
    // If less than the limit, we have already tested all num_packages.
    if (pools.Topology().FullTopology().packages.size() < max_packages) break;
#endif
    fprintf(stderr, "BenchAllMatMul %zu: %s %s\n", max_packages,
            pools.TopologyString(), pools.PinString());
    Tristate use_spinning = Tristate::kDefault;
    pools.MaybeStartSpinning(use_spinning);
    Allocator::Init(pools.Topology());
    MatMulEnv env(pools);

    for (size_t batch_size : {1, /*4, 64,*/ 128}) {
      BenchMatMul<F32, F32>(batch_size, 24576, 3072, /*add=*/false, env);
      BenchMatMul<F32, F32>(batch_size, 3072, 24576, /*add=*/false, env);
      BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, /*add=*/false, env);
      BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, /*add=*/false, env);
      BenchMatMul<F32, SFP>(batch_size, 24576, 3072, /*add=*/false, env);
      BenchMatMul<F32, SFP>(batch_size, 3072, 24576, /*add=*/false, env);
    }

    pools.MaybeStopSpinning(use_spinning);
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace gcpp {
int64_t first_target = 0;  // none run yet
HWY_BEFORE_TEST(BenchMatMul);
HWY_EXPORT_AND_TEST_P(BenchMatMul, BenchAllMatMul);
HWY_AFTER_TEST();
}  // namespace gcpp
#endif

View File

@@ -1109,6 +1109,7 @@ void TestAllDot() {
   const size_t num = 24 * 1024;
   NestedPools pools(kMaxWorkers - 1, /*pin=*/Tristate::kDefault,
                     BoundedSlice(0, 1), BoundedSlice(0, 1));
+  Allocator::Init(pools.Topology());
   RowVectorBatch<float> a(Extents2D(kMaxWorkers, num));
   RowVectorBatch<float> b(Extents2D(kMaxWorkers, num));
   RowVectorBatch<double> bufs(Extents2D(kMaxWorkers, num));

View File

@@ -38,22 +38,6 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
-// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
-// loads, we reuse the same A row for several B columns, which are also loaded
-// once for several rows of C. Thus we produce one 'tile' of C at a time of
-// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are
-// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4
-// enables the `StoreInterleaved4` transpose in `AddHorizontalSums`. We assume
-// and verify that `C.cols % kRegCols == 0`.
-constexpr size_t kRegCols = 4;
-
-// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because
-// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile.
-// In general, `batch_size` (C rows) is not a multiple of `kRegRows`. Thus
-// functions that load or store a tile are parameterized on `kNumRows`, which is
-// generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
-constexpr size_t kRegRows = kRegCols;
-
 // Loads two vectors at a time with element type hn::TFromD<DR> from a row of
 // transposed B. Called in a loop over col_ab. No bounds checking because
 // `kRow` is from B columns, which we checked is a multiple of `kRegCols`.

View File

@@ -21,6 +21,7 @@
 // IWYU pragma: begin_exports
 #include "util/basics.h"
 #include "util/threading.h"
+#include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 // IWYU pragma: end_exports
 
@@ -28,6 +29,40 @@
 namespace gcpp {
 
+// TODO: remove deprecated typedef.
+using Range1D = IndexRange;
+
+// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
+// loads, we reuse the same A row for several B columns, which are also loaded
+// once for several rows of C. Thus we produce one 'tile' of C at a time of
+// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are
+// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4
+// enables the `StoreInterleaved4` transpose in `StoreHorizontalSums`. We assume
+// and verify that `C.Cols() % kRegCols == 0`.
+constexpr size_t kRegCols = 4;
+
+// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because
+// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile.
+// In general, `batch_size` (A/C rows) is not a multiple of `kRegRows`. Thus
+// functions that load or store a tile are parameterized on `kRowsPerTile`:
+// usually `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
+constexpr size_t kRegRows = kRegCols;
+
+struct CacheSizes {
+  CacheSizes() = default;
+  CacheSizes(const BoundedTopology::Cluster& cluster) {
+    // Assumes each package and cluster has the same cache sizes, and uses
+    // reasonable defaults if unknown.
+    l1_bytes = 32 * 1024;  // typical size, rarely changes
+    l2_bytes = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) * 1024;
+    l3_bytes = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) * 1024;
+  }
+
+  size_t l1_bytes;
+  size_t l2_bytes;
+  size_t l3_bytes;
+};
+
 // Allocations and threads, shared across MatMul calls.
 class MatMulEnv {
  public:
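
A quick sanity check of the tile comment (my arithmetic, not from the diff): with kRegRows == kRegCols == 4, one K-step of a 4x4 tile performs 16 FMAs while loading only 4 + 4 = 8 vectors, and the 16 accumulator vectors plus a few operand registers fit within the 32 vector registers of NEON/SVE/AVX-512. The new CacheSizes struct could then feed an L2-aware choice of how many rows of B to keep resident; a hypothetical sketch (this helper and its parameters are illustrations, not this commit's API):

  size_t RowsOfBPerL2(const BoundedTopology::Cluster& cluster, size_t cols_b) {
    const CacheSizes caches(cluster);
    // Reserve half of L2 for rows of B, the rest for A/C traffic.
    const size_t bytes_per_row = cols_b * sizeof(BF16);
    return HWY_MAX(size_t{1}, caches.l2_bytes / 2 / bytes_per_row);
  }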

View File

@@ -13,6 +13,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+// End to end test of MatMul, comparing against a reference implementation.
+
+#include "hwy/detect_compiler_arch.h"
 #ifndef HWY_DISABLED_TARGETS
 // Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
 // double-precision support.
@@ -23,14 +26,14 @@
 #endif
 #endif
 
-#include "ops/matmul.h"
-
 #include <stddef.h>
 #include <stdio.h>
 
 #include <memory>
 
 #include "compression/compress.h"
+#include "compression/shared.h"
+#include "ops/matmul.h"
 #include "util/allocator.h"
 #include "util/basics.h"
 #include "util/threading.h"
@@ -52,7 +55,11 @@
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
+// For running TestBatchSizes only once. Defined within HWY_ONCE.
+extern int64_t first_target;
+
 namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
 
 using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
@@ -71,8 +78,9 @@ MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
   const float scale = SfpStream::kMax / (mat->NumElements());
   pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
     for (size_t c = 0; c < extents.cols; c++) {
-      content[r * extents.cols + c] =
-          static_cast<float>(r * extents.cols + c) * scale;
+      float f = static_cast<float>(r * extents.cols + c) * scale;
+      if ((r + c) & 1) f = -f;  // Also generate some negative values.
+      content[r * extents.cols + c] = f;
     }
   });
@@ -92,8 +100,9 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
   const float scale = SfpStream::kMax / (mat->NumElements());
   pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
     for (size_t c = 0; c < extents.cols; c++) {
-      content[r * extents.cols + c] =
-          static_cast<float>(c * extents.rows + r) * scale;
+      float f = static_cast<float>(c * extents.rows + r) * scale;
+      if ((r + c) & 1) f = -f;  // Also generate some negative values.
+      content[r * extents.cols + c] = f;
     }
   });
@@ -104,16 +113,28 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
 }
 
 // Returns 1-norm, used for estimating tolerable numerical differences.
-double MaxColAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
-  double max_col_abs_sum = 0.0;
-  for (size_t c = 0; c < extents.cols; c++) {
-    double col_abs_sum = 0.0;
-    for (size_t r = 0; r < extents.rows; r++) {
-      col_abs_sum += hwy::ScalarAbs(a[r * extents.cols + c]);
-    }
-    max_col_abs_sum = HWY_MAX(max_col_abs_sum, col_abs_sum);
-  }
-  return max_col_abs_sum;
+double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
+  double max_row_abs_sum = 0.0;
+  for (size_t r = 0; r < extents.rows; r++) {
+    const float* row = a + r * extents.cols;
+    double row_abs_sum = 0.0;
+    for (size_t c = 0; c < extents.cols; c++) {
+      row_abs_sum += hwy::ScalarAbs(row[c]);
+    }
+    max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum);
+  }
+  return max_row_abs_sum;
+}
+
+// Returns the maximum absolute value of `a`.
+float MaxAbs(const float* HWY_RESTRICT a, const Extents2D& extents) {
+  float max_abs = 0.0f;
+  for (size_t c = 0; c < extents.cols; c++) {
+    for (size_t r = 0; r < extents.rows; r++) {
+      max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(a[r * extents.cols + c]));
+    }
+  }
+  return max_abs;
 }
 
 // B is already transposed.
@@ -132,12 +153,25 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
   DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a);
   DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b);
 
-  const double norm = MaxColAbsSum(a.get(), A.Extents()) *
-                      MaxColAbsSum(b_trans.get(), B.Extents());
-  // Dot(float,BF16) rounds both to BF16.
-  using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
-  const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
-  const double tolerance = 200.0 * norm * epsilon;
+  // MatMul rounds inputs to BF16, so error is proportional to the max input
+  // magnitude, but also to f32 accumulation of rows in A and B.
+  const double norm = MaxRowAbsSum(a.get(), A.Extents()) *
+                      MaxRowAbsSum(b_trans.get(), B.Extents());
+  const float max_abs =
+      MaxAbs(a.get(), A.Extents()) * MaxAbs(b_trans.get(), B.Extents());
+  const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
+  const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
+  double tolerance = 8 * norm * eps_f32;
+  // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
+  // tolerance there.
+  if (IsF32<MatTA>() && IsF32<MatTB>()) {
+    tolerance += 4 * max_abs * eps_bf16;
+  }
+  EXPECT_GE(tolerance, 1E-4);
+  if (tolerance > 4.0) {
+    fprintf(stderr, "WARN: high tolerance %f norm %f maxabs %f\n", tolerance,
+            norm, max_abs);
+  }
 
   for (size_t r = 0; r < A.extents.rows; r++) {
     const float* expected_row = C_slow.Row(r);
@@ -148,10 +182,11 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
       if (!(expected_value - tolerance <= actual_value &&
             actual_value <= expected_value + tolerance)) {
-        fprintf(
-            stderr,
-            "(%zu,%zu): expected %f, actual %f, norm %f eps %E tolerance %f\n",
-            r, c, expected_value, actual_value, norm, epsilon, tolerance);
+        fprintf(stderr,
+                "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
+                "tolerance %f\n",
+                r, c, expected_value, actual_value, norm, max_abs, tolerance);
+        return;
       }
     }
   }
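
For intuition about the magnitudes (my own numbers, not from the diff): hwy::Epsilon<float>() is roughly 1.2e-7 and hwy::Epsilon<BF16>() roughly 7.8e-3, so

  tolerance ≈ 8 * norm * 1.2e-7   (+ 4 * max_abs * 7.8e-3 if both inputs are F32)
            ≈ 1e-6 * norm         (+ 0.03 * max_abs)

which is the range the EXPECT_GE(tolerance, 1E-4) lower bound and the 4.0 warning threshold are bracketing.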
@@ -171,20 +206,31 @@ HWY_INLINE void MatMulSlow(const ConstMat<MatTA> A, const ConstMat<MatTB> B,
   const hn::ScalableTag<float> df;  // lane type is ignored
   const PackedSpan<const MatTB> b_span =
       MakeSpan(B.ptr, B.ofs + B.extents.Area());
-  const Extents2D C_extents(A.extents.rows, C.Cols());
-
-  StaticPartitionRowsAndCols(
-      env.Pools(), C_extents, sizeof(MatTB),
-      [&](const Range2D& C_range, const TaskLocation& loc) {
-        loc.cluster.Run(
-            C_range.rows.begin(), C_range.rows.end(),
-            [&](const uint64_t row, size_t /*thread*/) {
-              float* HWY_RESTRICT C_row = C.Row(row);
-              for (size_t row_b_col_c : C_range.cols) {
-                const float add = add_row ? add_row[row_b_col_c] : 0.0f;
-                C_row[row_b_col_c] =
-                    add + scale * Dot(df, b_span, row_b_col_c * B.extents.cols,
-                                      A.ptr + A.Row(row), A.extents.cols);
-              }
-            });
-      });
+  const IndexRange all_rows_c(0, A.Extents().rows);
+  const IndexRange all_cols_c(0, C.Cols());
+
+  NestedPools& pools = env.Pools();
+  hwy::ThreadPool& all_packages = pools.AllPackages();
+  const IndexRangePartition get_row_c =
+      StaticPartition(all_rows_c, all_packages.NumWorkers(), 1);
+  ParallelizeOneRange(
+      get_row_c, all_packages,
+      [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
+        hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
+        const size_t multiple = Allocator::Alignment() / sizeof(MatTB);
+        const IndexRangePartition get_col_c =
+            StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
+        ParallelizeOneRange(
+            get_col_c, all_clusters,
+            [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
+              for (size_t r : rows_c) {
+                float* HWY_RESTRICT C_row = C.Row(r);
+                for (size_t c : cols_c) {
+                  const float add = add_row ? add_row[c] : 0.0f;
+                  C_row[c] =
+                      add + scale * Dot(df, b_span, c * B.extents.cols,
+                                        A.ptr + A.Row(r), A.extents.cols);
+                }
+              }
+            });
+      });
@@ -250,6 +296,40 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   AssertClose(A, B, C_slow, C);
 }
 
+using F32 = float;
+using SFP = SfpStream;
+
+// Sweep batch_size for a single input type and Highway target, to verify the
+// row partitioning.
+void TestBatchSizes() {
+  if (first_target == 0) first_target = HWY_TARGET;
+  if (HWY_TARGET != first_target) return;
+
+  for (size_t max_packages : {1, 2}) {
+    const size_t max_threads = 0;  // no limit
+    NestedPools pools(max_threads, Tristate::kDefault,
+                      BoundedSlice(0, max_packages));
+#if GEMMA_DISABLE_TOPOLOGY
+    if (max_packages == 2) break;  // we only have one package
+#else
+    // If less than the limit, we have already tested all num_packages.
+    if (pools.Topology().FullTopology().packages.size() < max_packages) break;
+#endif
+    fprintf(stderr, "TestBatchSizes %zu: %s %s\n", max_packages,
+            pools.TopologyString(), pools.PinString());
+    Tristate use_spinning = Tristate::kDefault;
+    pools.MaybeStartSpinning(use_spinning);
+    Allocator::Init(pools.Topology());
+    MatMulEnv env(pools);
+
+    for (size_t batch_size = 1; batch_size <= 3 * kRegRows; ++batch_size) {
+      TestMatMul<F32, F32>(batch_size, 256, 256, /*add=*/false, env);
+    }
+    pools.MaybeStopSpinning(use_spinning);
+  }
+}
+
 void TestAllMatMul() {
   // Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
   if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
@@ -257,32 +337,30 @@ void TestAllMatMul() {
     return;
   }
 
-  NestedPools pools(4, /*pin=*/Tristate::kDefault);
+  NestedPools pools(0);  // no limits
   Tristate use_spinning = Tristate::kDefault;
   pools.MaybeStartSpinning(use_spinning);
   Allocator::Init(pools.Topology());
   MatMulEnv env(pools);
 
-  using F32 = float;
-  using SFP = SfpStream;
-
-  // large-scale test: batch_size=128 is better than 64 or 256 for SKX.
-  // TestMatMul<F32, SFP>(128, 24576, 3072, /*add=*/false, env);
-  // TestMatMul<F32, SFP>(128, 3072, 24576, /*add=*/false, env);
-  TestMatMul<F32, F32>(1, 24576, 3072, /*add=*/false, env);
-  TestMatMul<F32, F32>(1, 3072, 24576, /*add=*/false, env);
-  TestMatMul<F32, SFP>(1, 24576, 3072, /*add=*/false, env);
-  TestMatMul<F32, SFP>(1, 3072, 24576, /*add=*/false, env);
-
-  // medium-sized square test - temporarily disabled for faster testing.
-  if constexpr (false) {
-    TestMatMul<F32>(512, 512, 512, /*add=*/false, env);
-    TestMatMul<BF16>(512, 512, 512, /*add=*/true, env);
-    TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env);
-    TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env);
-    TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env);
-    TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env);
-  }
+  // Sizes seen in gemma_test 2B.
+  TestMatMul<F32>(1, 2048, 512, /*add=*/false, env);
+  TestMatMul<F32>(1, 2048, 2048, /*add=*/false, env);
+  TestMatMul<F32>(1, 2048, 16384, /*add=*/false, env);
+  TestMatMul<F32>(1, 16384, 2048, /*add=*/false, env);
+  TestMatMul<F32>(1, 2048, 256000, /*add=*/false, env);
+  TestMatMul<F32>(5, 2048, 512, /*add=*/false, env);
+  TestMatMul<F32>(5, 2048, 2048, /*add=*/false, env);
+  TestMatMul<F32>(5, 2048, 16384, /*add=*/false, env);
+  TestMatMul<F32>(5, 16384, 2048, /*add=*/false, env);
+
+  // medium-sized square
+  TestMatMul<F32>(512, 512, 512, /*add=*/false, env);
+  TestMatMul<BF16>(512, 512, 512, /*add=*/true, env);
+  TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env);
+  TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env);
+  TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env);
+  TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env);
 
   // minimal non-square test. kColsARowsB must be at least 2 vectors.
   TestMatMul<F32>(35, 128, 32, /*add=*/false, env);
@@ -325,8 +403,10 @@ HWY_AFTER_NAMESPACE();
 #if HWY_ONCE
 namespace gcpp {
-HWY_BEFORE_TEST(MatmulTest);
-HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllMatMul);
+int64_t first_target = 0;  // none run yet
+HWY_BEFORE_TEST(MatMulTest);
+HWY_EXPORT_AND_TEST_P(MatMulTest, TestBatchSizes);
+HWY_EXPORT_AND_TEST_P(MatMulTest, TestAllMatMul);
 HWY_AFTER_TEST();
 }  // namespace gcpp

ops/matmul_unit_test.cc (new file, 17 lines)
View File

@@ -0,0 +1,17 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO: Tests of individual MatMul components.
int main() { return 0; }

View File

@@ -103,7 +103,7 @@ size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
 // which means it would have to be called before pages are faulted in, but
 // `aligned_allocator.h` modifies the first bytes for its bookkeeping.
 // May overwrite some of the memory with zeros.
-static void BindMemory(void* ptr, size_t bytes, size_t node) {
+void BindMemory(void* ptr, size_t bytes, size_t node) {
   constexpr size_t kMaxNodes = 1024;  // valid for x86/x64, and "enough"
   // Avoid mbind because it does not report why it failed, which is most likely
   // because pages are busy, in which case we want to know which.
@@ -159,24 +159,7 @@ static void BindMemory(void* ptr, size_t bytes, size_t node) {
 #else
 // TODO: support other OSes.
-static void BindMemory(void*, size_t, size_t) {}
+void BindMemory(void*, size_t, size_t) {}
 #endif  // GEMMA_NUMA && HWY_OS_LINUX
 
-void BindTensor(NestedPools& nested, const Extents2D& extents,
-                size_t bytes_per_col, void* ptr) {
-  if (!Allocator::UseNUMA()) return;
-
-  uint8_t* p8 = static_cast<uint8_t*>(ptr);
-  const size_t bytes_per_row = extents.cols * bytes_per_col;
-  StaticPartitionRowsAndCols(
-      nested, extents, bytes_per_col,
-      [&](const Range2D& r, const TaskLocation& loc) {
-        for (size_t row : r.rows) {
-          uint8_t* slice =
-              p8 + row * bytes_per_row + r.cols.begin() * bytes_per_col;
-          const size_t slice_size = r.cols.Num() * bytes_per_col;
-          BindMemory(slice, slice_size, loc.node);
-        }
-      });
-}
-
 }  // namespace gcpp

View File

@@ -125,82 +125,8 @@ class Allocator {
   static size_t alignment_;
 };
 
-// For shorter arguments to the StaticPartitionRowsAndCols functor.
-struct TaskLocation {
-  TaskLocation(size_t node, size_t package_idx, hwy::ThreadPool& cluster,
-               size_t worker_offset)
-      : node(node),
-        package_idx(package_idx),
-        cluster(cluster),
-        worker_offset(worker_offset) {}
-
-  size_t node;
-  size_t package_idx;
-  hwy::ThreadPool& cluster;
-  const size_t worker_offset;
-};
-
-// Used in MatMul and allocator.h. Defined here because it depends on
-// Allocator::Alignment().
-template <class Func>
-void StaticPartitionRowsAndCols(NestedPools& nested, Extents2D extents,
-                                size_t bytes_per_element, const Func& func) {
-  // Both rows and cols must be a multiple of the alignment to avoid
-  // touching remote pages.
-  const size_t multiple = Allocator::Alignment() / bytes_per_element;
-
-  // Static partitioning of columns across packages. We assume that column
-  // sharding is more expensive, hence we distribute columns across packages,
-  // of which there are usually only one or two. For MatMul, the final result is
-  // the sum of each package's partial dot products.
-  hwy::ThreadPool& all_packages = nested.AllPackages();
-  const size_t num_packages = all_packages.NumWorkers();
-  const size_t cols_per_package =
-      hwy::RoundUpTo(hwy::DivCeil(extents.cols, num_packages), multiple);
-  const size_t col_tasks = hwy::DivCeil(extents.cols, cols_per_package);
-  HWY_ASSERT(col_tasks <= num_packages);
-  all_packages.Run(
-      0, col_tasks, [&](uint64_t package_idx, size_t package_thread) {
-        HWY_ASSERT(package_idx == package_thread);  // one task per worker
-        const size_t col_begin = package_idx * cols_per_package;
-        const Range1D col_range =
-            MakeRange1D(col_begin, extents.cols, cols_per_package);
-
-        // Static partitioning of rows across the package's clusters. We assume
-        // that row sharding is cheaper. In MatMul, results can indeed be
-        // computed independently for each row of B.
-        hwy::ThreadPool& all_clusters = nested.AllClusters(package_idx);
-        const size_t num_clusters = all_clusters.NumWorkers();
-        const size_t rows_per_cluster =
-            hwy::RoundUpTo(hwy::DivCeil(extents.rows, num_clusters), multiple);
-        const size_t row_tasks = hwy::DivCeil(extents.rows, rows_per_cluster);
-        HWY_ASSERT(row_tasks <= num_clusters);
-        all_clusters.Run(
-            0, row_tasks, [&](uint64_t cluster_idx, size_t cluster_thread) {
-              HWY_ASSERT(cluster_idx == cluster_thread);  // one task per worker
-              // For binding to NUMA node.
-              const size_t node = nested.Node(package_idx, cluster_idx);
-              // Older CPUs that predate chiplets typically have only one
-              // cluster, so callers should also parallelize using this
-              // per-cluster pool.
-              hwy::ThreadPool& cluster =
-                  nested.Cluster(package_idx, cluster_idx);
-              // This plus the worker from `cluster->Run` is the TLS index.
-              const size_t worker_offset =
-                  nested.WorkerOffset(package_idx, cluster_idx);
-
-              const size_t row_begin = cluster_idx * rows_per_cluster;
-              const Range1D row_range =
-                  MakeRange1D(row_begin, extents.rows, rows_per_cluster);
-
-              func(Range2D(row_range, col_range),
-                   TaskLocation(node, package_idx, cluster, worker_offset));
-            });
-      });
-}
-
-void BindTensor(NestedPools& nested, size_t rows, size_t cols,
-                size_t bytes_per_col, void* ptr);
+// For future NUMA support. TODO: use.
+void BindMemory(void* ptr, size_t bytes, size_t node);
 
 }  // namespace gcpp

View File

@@ -27,6 +27,7 @@
 #include "compression/io.h"  // Path
 #include "gemma/common.h"
 #include "gemma/gemma.h"  // For CreateGemma
+#include "ops/matmul.h"
 #include "util/args.h"
 #include "util/basics.h"  // Tristate
 #include "util/threading.h"
@@ -115,6 +116,7 @@ class AppArgs : public ArgsBase<AppArgs> {
   }
 };
 
+// Callers must call Allocator::Init(pools.Topology()) after this.
 static inline NestedPools CreatePools(const AppArgs& app) {
   return NestedPools(app.max_threads, app.pin,
                      BoundedSlice(app.skip_packages, app.max_packages),
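
Per the new comment, call sites now pair the two calls (as run.cc and GemmaEnv above do); minimal usage:

  gcpp::NestedPools pools = gcpp::CreatePools(app);
  gcpp::Allocator::Init(pools.Topology());  // as required by the comment above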

View File

@@ -21,7 +21,7 @@
 #include <stdint.h>
 
 #include "hwy/aligned_allocator.h"
-#include "hwy/base.h"  // HWY_IS_MSAN
+#include "hwy/base.h"
 // IWYU pragma: end_exports
 
 #if HWY_IS_MSAN
@@ -60,7 +60,7 @@ struct TokenAndProb {
   float prob;
 };
 
-// Entire size of a 2D array. By contrast, Range2D is a subrange.
+// Entire size of a 2D array.
 struct Extents2D {
   Extents2D() : rows(0), cols(0) {}
   Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
@@ -74,11 +74,13 @@ struct Extents2D {
   size_t cols;
 };
 
-// Range2D consists of two Range1D.
-struct Range1D {
-  Range1D(size_t begin, size_t end) : begin_(begin), end_(end) {
+struct IndexRange {
+  IndexRange(size_t begin, size_t end) : begin_(begin), end_(end) {
     HWY_DASSERT(begin < end);
   }
+  IndexRange(const IndexRange& other) = default;
+  IndexRange& operator=(const IndexRange& other) = default;
 
   size_t Num() const { return end_ - begin_; }
 
   // Enable range-based for loops.
@@ -101,22 +103,15 @@ struct Range1D {
   Iterator begin() const { return Iterator(begin_); }
   Iterator end() const { return Iterator(end_); }
 
-  const size_t begin_;
-  const size_t end_;
+  size_t begin_;
+  size_t end_;
 };
 
-static inline Range1D MakeRange1D(size_t begin, size_t end, size_t max_size) {
-  return Range1D(begin, HWY_MIN(begin + max_size, end));
+static inline IndexRange MakeIndexRange(size_t begin, size_t end,
+                                        size_t max_size) {
+  return IndexRange(begin, HWY_MIN(begin + max_size, end));
 }
 
-// In MatMul, the two axes are used independently, hence we do not define
-// Range2D as a top-left and extents.
-struct Range2D {
-  Range2D(Range1D rows, Range1D cols) : rows(rows), cols(cols) {}
-  const Range1D rows;
-  const Range1D cols;
-};
-
 // Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
 // it is always float and does not support compressed T, but does support an
 // arbitrary stride >= cols.
@@ -125,6 +120,10 @@ class RowPtr {
  public:
   RowPtr(T* HWY_RESTRICT row0, size_t cols)
       : row0_(row0), cols_(cols), stride_(cols) {}
+  RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
+      : row0_(row0), cols_(cols), stride_(stride) {
+    HWY_DASSERT(stride >= cols);
+  }
 
   T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
   size_t Cols() const { return cols_; }
@@ -207,6 +206,7 @@ struct ConstMat {
   }
 
   const Extents2D& Extents() const { return extents; }
+  size_t Stride() const { return extents.cols; }
 
   // Shrinks the row-extent of this matrix view, i.e. reduces the view to a
   // subrange of the original rows starting at row 0.
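
A short usage sketch of the renamed IndexRange (illustrative only; follows the range-based for support shown above):

  const IndexRange r = MakeIndexRange(/*begin=*/0, /*end=*/10, /*max_size=*/4);
  // r covers [0, 4): MakeIndexRange clamps begin + max_size to end.
  for (size_t i : r) {
    // i = 0, 1, 2, 3
  }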

View File

@@ -39,31 +39,106 @@ static void SortByDescendingSize(std::vector<T>& groups) {
                    [](const T& a, const T& b) { return a.Size() > b.Size(); });
 }
 
+// Singleton, holds the original process affinity and the pinning status.
+class Pinning {
+  static bool InContainer() {
+    return false;
+  }
+
+ public:
+  // Returns set of LPs available for use. Subsequent calls return the same
+  // set as the first, because pinning overwrites the main thread's affinity.
+  // Thread-hostile, not called concurrently.
+  LPS EnabledLPs() {
+    if (original_affinity_.Any()) return original_affinity_;
+
+    // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
+    LPS enabled_lps;
+    if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
+      const size_t num_lps = hwy::TotalLogicalProcessors();
+      fprintf(
+          stderr,
+          "Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
+          num_lps);
+      for (size_t lp = 0; lp < num_lps; ++lp) {
+        enabled_lps.Set(lp);
+      }
+    }
+
+    // Without threading support, only keep the first enabled LP; it might
+    // still make sense to pin the main thread to avoid migrations.
+    if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
+      HWY_ASSERT(enabled_lps.Any());
+      const size_t lp = enabled_lps.First();
+      enabled_lps = LPS();
+      enabled_lps.Set(lp);
+      fprintf(stderr,
+              "Warning, threads not supported, using only the main thread\n.");
+    }
+
+    original_affinity_ = enabled_lps;
+    return enabled_lps;
+  }
+
+  void SetPolicy(Tristate pin) {
+    if (pin == Tristate::kDefault) {
+      // Pinning is unreliable inside containers because the hypervisor might
+      // periodically change our affinity mask, or other processes might also
+      // pin themselves to the same LPs.
+      pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
+    }
+    want_pin_ = (pin == Tristate::kTrue);
+    any_error_.clear();
+  }
+
+  // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
+  // and sets `any_error_` if any fails.
+  void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) {
+    if (HWY_UNLIKELY(!want_pin_)) return;
+
+    const std::vector<size_t> lps = cluster.LPVector();
+    HWY_ASSERT(pool->NumWorkers() <= lps.size());
+    pool->Run(
+        0, pool->NumWorkers(),
+        [this, &pool, &lps](uint64_t task, size_t thread) {
+          HWY_ASSERT(task == thread);  // each worker has one task
+          if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
+            fprintf(stderr,
+                    "Pinning failed for task %zu of %zu to %zu (size %zu)\n",
+                    task, pool->NumWorkers(), lps[task], lps.size());
+            (void)any_error_.test_and_set();
+          }
+        });
+  }
+
+  // Called ONCE after all MaybePin because it invalidates the error status.
+  bool AllPinned(const char** pin_string) {
+    // If !want_pin_, MaybePin will return without setting any_error_, but in
+    // that case we still want to return false to avoid spinning.
+    // .test() was only added in C++20, so we use .test_and_set() instead.
+    const bool all_pinned = want_pin_ && !any_error_.test_and_set();
+    *pin_string = all_pinned  ? "pinned"
+                  : want_pin_ ? "pinning failed"
+                              : "pinning skipped";
+    return all_pinned;
+  }
+
+ private:
+  std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
+  bool want_pin_;  // set in SetPolicy
+  LPS original_affinity_;
+};  // Pinning
+
+// Singleton saves global affinity across all BoundedTopology instances because
+// pinning overwrites it.
+static Pinning& GetPinning() {
+  static Pinning pinning;
+  return pinning;
+}
+
 BoundedTopology::BoundedTopology(BoundedSlice package_slice,
                                  BoundedSlice cluster_slice,
                                  BoundedSlice lp_slice) {
-  // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
-  LPS enabled_lps;
-  if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
-    const size_t num_lps = hwy::TotalLogicalProcessors();
-    fprintf(stderr,
-            "Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
-            num_lps);
-    for (size_t lp = 0; lp < num_lps; ++lp) {
-      enabled_lps.Set(lp);
-    }
-  }
-
-  // Without threading support, only keep the first enabled LP; it might still
-  // make sense to pin the main thread to avoid migrations.
-  if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
-    HWY_ASSERT(enabled_lps.Any());
-    const size_t lp = enabled_lps.First();
-    enabled_lps = LPS();
-    enabled_lps.Set(lp);
-    fprintf(stderr,
-            "Warning, threads not supported, using only the main thread\n.");
-  }
+  const LPS enabled_lps = GetPinning().EnabledLPs();
 
 #if !GEMMA_DISABLE_TOPOLOGY
   if (HWY_LIKELY(!topology_.packages.empty())) {
@@ -110,19 +185,33 @@ BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
       AddLP(lp);
 
-      // Set `node` once, and ensure subsequent nodes match - we assume there
-      // is only one NUMA node per cluster.
+      // Set fields once, and ensure subsequent LPs match - we assume there
+      // is only one NUMA node per cluster, with the same L2/L3 size.
       const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
       if (is_first_lp) {
         is_first_lp = false;
         node_ = lp_node;
+        private_kib_ = tcluster.private_kib;
+        shared_kib_ = tcluster.shared_kib;
       } else {
         static bool warned = false;
-        if (lp_node != node_ && !warned) {
-          warned = true;
-          fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n",
-                  lp, lp_node, node_);
-        }
+        if (HWY_LIKELY(!warned)) {
+          if (HWY_UNLIKELY(lp_node != node_)) {
+            warned = true;
+            fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n",
+                    lp, lp_node, node_);
+          }
+          if (HWY_UNLIKELY(private_kib_ != tcluster.private_kib)) {
+            warned = true;
+            fprintf(stderr, "WARNING: lp %zu private_kib %zu != cluster %zu.\n",
+                    lp, private_kib_, tcluster.private_kib);
+          }
+          if (HWY_UNLIKELY(shared_kib_ != tcluster.shared_kib)) {
+            warned = true;
+            fprintf(stderr, "WARNING: lp %zu shared_kib %zu != cluster %zu.\n",
+                    lp, shared_kib_, tcluster.shared_kib);
+          }
+        }  // !warned
       }
     });
 }
@@ -141,6 +230,7 @@ BoundedTopology::Package::Package(const LPS& enabled_lps,
       "cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
         const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx];
         Cluster cluster(enabled_lps, topology.lps, tcluster);
+
         // Skip if empty, i.e. too few `enabled_lps`.
         if (HWY_LIKELY(cluster.Size() != 0)) {
           clusters.push_back(std::move(cluster));
return std::make_unique<hwy::ThreadPool>(num_threads); return std::make_unique<hwy::ThreadPool>(num_threads);
} }
static bool InContainer() {
return false;}
class NestedPools::Pinning {
public:
Pinning(Tristate pin, const BoundedTopology& topology) {
if (pin == Tristate::kDefault) {
// Pinning is unreliable inside containers because the hypervisor might
// periodically change our affinity mask, or other processes might also
// pin themselves to the same LPs.
pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
}
want_pin_ = (pin == Tristate::kTrue);
}
// If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
// and sets `any_error_` if any fails.
void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) {
if (HWY_UNLIKELY(!want_pin_)) return;
const std::vector<size_t> lps = cluster.LPVector();
HWY_ASSERT(pool->NumWorkers() <= lps.size());
pool->Run(
0, pool->NumWorkers(),
[this, &pool, &lps](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
fprintf(stderr,
"Pinning failed for task %zu of %zu to %zu (size %zu)\n",
task, pool->NumWorkers(), lps[task], lps.size());
(void)any_error_.test_and_set();
}
});
}
bool WantPin() const { return want_pin_; }
// Called ONCE after all MaybePin because it invalidates the error status.
bool AllPinned() {
// If !want_pin_, MaybePin will return without setting any_error_, but in
// that case we still want to return false to avoid spinning.
// .test() was only added in C++20, so we use .test_and_set() instead.
return want_pin_ && !any_error_.test_and_set();
}
private:
std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
bool want_pin_; // set in ctor
}; // Pinning
// Used to divide max_threads and max_workers_per_package across packages and // Used to divide max_threads and max_workers_per_package across packages and
// clusters. Ensures small upper bounds are respected. // clusters. Ensures small upper bounds are respected.
static size_t DivideMaxAcross(const size_t max, const size_t instances) { static size_t DivideMaxAcross(const size_t max, const size_t instances) {
@@ -333,7 +373,7 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
                          BoundedSlice package_slice, BoundedSlice cluster_slice,
                          BoundedSlice lp_slice)
     : topology_(package_slice, cluster_slice, lp_slice) {
-  Pinning pinning(pin, topology_);
+  GetPinning().SetPolicy(pin);
   packages_.resize(topology_.NumPackages());
   all_packages_ = MakePool(packages_.size());
   const size_t max_workers_per_package =
@@ -344,14 +384,11 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
   all_packages_->Run(
       0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) {
         HWY_ASSERT(package_idx == thread);  // each thread has one task
-        packages_[package_idx] = Package(
-            topology_, package_idx, max_workers_per_package, pinning, lp_slice);
+        packages_[package_idx] =
+            Package(topology_, package_idx, max_workers_per_package, lp_slice);
       });
-  all_pinned_ = pinning.AllPinned();
-  pin_string_ = all_pinned_ ? "pinned"
-                : pinning.WantPin() ? "pinning failed"
-                                    : "pinning skipped";
+  all_pinned_ = GetPinning().AllPinned(&pin_string_);
 
   // For mapping package/cluster/thread to noncontiguous TLS indices, in case
   // cluster/thread counts differ.
@@ -368,14 +405,9 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
   HWY_ASSERT(max_workers_per_cluster_ <= 256);
 }
 
-// `max_or_zero` == 0 means no limit.
-static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
-  return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
-}
-
 NestedPools::Package::Package(const BoundedTopology& topology,
                               size_t package_idx,
-                              size_t max_workers_per_package, Pinning& pinning,
+                              size_t max_workers_per_package,
                               BoundedSlice lp_slice) {
   // Pre-allocate because elements are set concurrently.
   clusters_.resize(topology.NumClusters(package_idx));
@@ -393,7 +425,7 @@ NestedPools::Package::Package(const BoundedTopology& topology,
         clusters_[cluster_idx] =
             MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
         // Pin workers AND the calling thread from `all_clusters`.
-        pinning.MaybePin(cluster, clusters_[cluster_idx]);
+        GetPinning().MaybePin(cluster, clusters_[cluster_idx]);
       });
 }

View File

@@ -17,14 +17,17 @@
 #define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
 
 #include <stddef.h>
+#include <stdint.h>
 
 #include <memory>  // std::unique_ptr
 #include <vector>
 
+// IWYU pragma: begin_exports
 #include "util/basics.h"  // Tristate
 #include "hwy/base.h"  // HWY_ASSERT
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/contrib/thread_pool/topology.h"
+// IWYU pragma: end_exports
 
 #ifndef GEMMA_DISABLE_TOPOLOGY
 #define GEMMA_DISABLE_TOPOLOGY 0
@@ -32,6 +35,15 @@
 
 namespace gcpp {
 
+static inline size_t SaturatingSub(size_t a, size_t b) {
+  return a - HWY_MIN(a, b);
+}
+
+// `max_or_zero` == 0 means no limit.
+static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
+  return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
+}
+
 // A slice of a 1D integer range such as the indices of packages or clusters.
 // This allows assigning them to multiple instances of our binary.
 class BoundedSlice {
@@ -86,6 +98,7 @@ using PoolPtr = std::unique_ptr<hwy::ThreadPool>;
 // back to a single package and cluster.
 class BoundedTopology {
  public:
+  // Thread-hostile, typically called from main thread.
  BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice,
                  BoundedSlice lp_slice);
@@ -112,6 +125,8 @@
    }
 
    size_t Node() const { return node_; }
+    size_t PrivateKiB() const { return private_kib_; }
+    size_t SharedKiB() const { return shared_kib_; }
 
   private:
    void AddLP(size_t lp) {
@@ -126,6 +141,10 @@ class BoundedTopology {
     size_t num_workers_ = 0;
     // NUMA node, set from hwy::Topology::LP::node.
     size_t node_ = 0;
+    // L2 cache size in KiB, or 0 if unknown.
+    size_t private_kib_ = 0;
+    // L3 cache size in KiB, or 0 if unknown.
+    size_t shared_kib_ = 0;
   };  // Cluster
 
   size_t NumClusters(size_t package_idx) const {
@@ -145,6 +164,10 @@ class BoundedTopology {
     return package.clusters[cluster_idx];
   }
 
+#if !GEMMA_DISABLE_TOPOLOGY
+  const hwy::Topology& FullTopology() const { return topology_; }
+#endif
+
  private:
   struct Package {
     // Topology is unknown, rely on OS affinity and user-specified slice.
@ -257,6 +280,8 @@ class NestedPools {
// Returns the first of `cluster.NumWorkers()` TLS indices, to which callers // Returns the first of `cluster.NumWorkers()` TLS indices, to which callers
// add the worker index given by `cluster.Run`. // add the worker index given by `cluster.Run`.
size_t WorkerOffset(size_t package_idx, size_t cluster_idx) const { size_t WorkerOffset(size_t package_idx, size_t cluster_idx) const {
HWY_DASSERT(package_idx < packages_.size());
HWY_DASSERT(cluster_idx < packages_[package_idx].NumClusters());
return (package_idx * max_clusters_per_package_ + cluster_idx) * return (package_idx * max_clusters_per_package_ + cluster_idx) *
max_workers_per_cluster_; max_workers_per_cluster_;
} }
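// Illustrative sketch only: how the worker index reported by a cluster pool
// combines with `WorkerOffset` into a process-wide TLS slot. The caller is
// assumed to pass a `cluster_pool` matching `pkg`/`cluster_idx` and a vector
// sized for all workers.
static inline void ExampleWorkerOffset(const NestedPools& pools, size_t pkg,
                                       size_t cluster_idx,
                                       hwy::ThreadPool& cluster_pool,
                                       std::vector<int>& per_worker) {
  const size_t base = pools.WorkerOffset(pkg, cluster_idx);
  cluster_pool.Run(0, 64, [&](uint64_t /*task*/, size_t thread) {
    per_worker[base + thread] += 1;  // No two workers share a slot.
  });
}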
@ -267,26 +292,25 @@ class NestedPools {
const char* TopologyString() const { return topology_.TopologyString(); } const char* TopologyString() const { return topology_.TopologyString(); }
const char* PinString() const { return pin_string_; } const char* PinString() const { return pin_string_; }
// Returns a single pool on the first package: either one thread per cluster // Returns a single pool on the given package: either one thread per cluster
// if there is more than one, which maximizes available memory bandwidth, or // if there is more than one, which maximizes available memory bandwidth, or
// the first cluster, which is typically the whole package. For use by callers // the first cluster, which is typically the whole package. For use by callers
// that only parallelize over a 1D range, as opposed to the nested // that only have a single parallel-for.
// parallelism of `StaticPartitionRowsAndCols`. hwy::ThreadPool& Pool(size_t package_idx = 0) {
hwy::ThreadPool& Pool() {
// Only one cluster: use its pool, typically a whole socket. // Only one cluster: use its pool, typically a whole socket.
if (AllClusters(0).NumWorkers() == 1) return Cluster(0, 0); if (AllClusters(package_idx).NumWorkers() == 1) {
return AllClusters(0); return Cluster(package_idx, 0);
}
// One worker per cluster to maximize bandwidth availability.
return AllClusters(package_idx);
} }
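// Illustrative sketch only: a caller with a single flat parallel-for; `pools`
// and `num_items` are placeholders.
static inline void ExampleFlatParallelFor(NestedPools& pools, size_t num_items) {
  pools.Pool().Run(0, num_items, [&](uint64_t item, size_t /*thread*/) {
    (void)item;  // Process item `item` here; with several clusters, one worker
                 // per cluster runs this, maximizing memory bandwidth.
  });
}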
private: private:
class Pinning;
class Package { class Package {
public: public:
Package() = default; // for vector Package() = default; // for vector
Package(const BoundedTopology& topology, size_t package_idx, Package(const BoundedTopology& topology, size_t package_idx,
size_t max_workers_per_package, Pinning& pinning, size_t max_workers_per_package, BoundedSlice lp_slice);
BoundedSlice lp_slice);
size_t NumClusters() const { return clusters_.size(); } size_t NumClusters() const { return clusters_.size(); }
size_t MaxWorkersPerCluster() const { size_t MaxWorkersPerCluster() const {
@ -330,11 +354,134 @@ class NestedPools {
std::vector<Package> packages_; std::vector<Package> packages_;
PoolPtr all_packages_; PoolPtr all_packages_;
// For TLS indices. // For TLS indices. One might think this belongs in BoundedTopology, but it
// depends on max_threads, which is passed to the NestedPools constructor.
size_t max_clusters_per_package_ = 0; size_t max_clusters_per_package_ = 0;
size_t max_workers_per_cluster_ = 0; size_t max_workers_per_cluster_ = 0;
}; };
// Splits `range` into subranges of size `task_size`, except for the last,
// which receives the remainder. Used with the `ParallelizeOneRange` etc.
// functions below.
class IndexRangePartition {
public:
IndexRangePartition(const IndexRange& range, const size_t task_size)
: range_(range), task_size_(task_size) {
const size_t num = range.Num();
HWY_DASSERT(task_size_ != 0);
num_tasks_ = hwy::DivCeil(num, task_size_);
HWY_DASSERT(num_tasks_ != 0);
if constexpr (HWY_IS_DEBUG_BUILD) {
const size_t handled = num_tasks_ * task_size_;
// The last task may extend beyond `num` items, but at most by (task_size_ - 1).
HWY_DASSERT(num <= handled && handled < num + task_size_);
}
}
size_t TaskSize() const { return task_size_; }
size_t NumTasks() const { return num_tasks_; }
IndexRange Range(size_t task_idx) const {
HWY_DASSERT(task_idx < NumTasks());
return MakeIndexRange(range_.begin() + task_idx * task_size_, range_.end(),
task_size_);
}
private:
IndexRange range_;
size_t task_size_;
size_t num_tasks_;
};
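// Illustrative sketch only: iterating the sub-ranges of a partition serially;
// the numbers are arbitrary examples.
static inline void ExampleIndexRangePartition() {
  const IndexRangePartition partition(IndexRange(0, 10), /*task_size=*/4);
  HWY_ASSERT(partition.NumTasks() == 3);  // [0, 4) [4, 8) [8, 10)
  size_t total = 0;
  for (size_t task = 0; task < partition.NumTasks(); ++task) {
    total += partition.Range(task).Num();
  }
  HWY_ASSERT(total == 10);  // Sub-ranges tile the input exactly once.
}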
// Starts with `max_size` and rounds DOWN to a multiple of `size_multiple`
// unless that would be zero. It is the caller's responsibility to choose
// `size_multiple` to avoid two heavily imbalanced tasks.
// Use when the number of tasks does not matter, but each must fit into caches.
static inline IndexRangePartition MaxSizePartition(const IndexRange& range,
const size_t max_size,
const size_t size_multiple) {
HWY_DASSERT(size_multiple != 0);
size_t size = HWY_MIN(range.Num(), max_size);
if (size > size_multiple) size = hwy::RoundDownTo(size, size_multiple);
return IndexRangePartition(range, size);
}
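// Illustrative sketch only: cache-sized tasks of at most 48 items in
// multiples of 16 (arbitrary example values).
static inline void ExampleMaxSizePartition() {
  const IndexRangePartition partition = MaxSizePartition(
      IndexRange(0, 100), /*max_size=*/48, /*size_multiple=*/16);
  HWY_ASSERT(partition.TaskSize() == 48);  // Already a multiple of 16.
  HWY_ASSERT(partition.NumTasks() == 3);   // [0, 48) [48, 96) [96, 100)
}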
// Up to `max_tasks` tasks, each with size rounded UP to a multiple of
// `size_multiple` unless that would exceed the range. Use when the number of
// tasks is known, e.g. one per ThreadPool worker.
static inline IndexRangePartition StaticPartition(const IndexRange& range,
const size_t max_tasks,
const size_t size_multiple) {
HWY_DASSERT(max_tasks != 0);
size_t size =
hwy::RoundUpTo(hwy::DivCeil(range.Num(), max_tasks), size_multiple);
size = HWY_MIN(size, range.Num());
return IndexRangePartition(range, size);
}
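// Illustrative sketch only: one task per worker of an assumed 8-worker pool,
// with sizes rounded up to a multiple of 4.
static inline void ExampleStaticPartition() {
  const IndexRangePartition partition = StaticPartition(
      IndexRange(0, 100), /*max_tasks=*/8, /*size_multiple=*/4);
  HWY_ASSERT(partition.TaskSize() == 16);  // ceil(100 / 8) = 13, rounded up.
  HWY_ASSERT(partition.NumTasks() == 7);   // The last task covers only 4 items.
}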
// Parallel-for over a single range. This takes care of translating the task
// index to a range.
template <class Func>
void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
const Func& func) {
const size_t num_tasks = get1.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
const IndexRange range1 = get1.Range(task);
func(range1, thread);
});
}
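// Illustrative sketch only: a parallel sum over row ranges. Uses one
// accumulator per worker instead of atomics; `v` is assumed non-empty.
static inline size_t ExampleParallelizeOneRange(hwy::ThreadPool& pool,
                                                const std::vector<size_t>& v) {
  HWY_ASSERT(!v.empty());
  const IndexRangePartition partition = MaxSizePartition(
      IndexRange(0, v.size()), /*max_size=*/256, /*size_multiple=*/1);
  std::vector<size_t> partial(pool.NumWorkers(), 0);
  ParallelizeOneRange(partition, pool,
                      [&](const IndexRange& range, size_t thread) {
                        for (size_t i = range.begin(); i < range.end(); ++i) {
                          partial[thread] += v[i];
                        }
                      });
  size_t sum = 0;
  for (size_t p : partial) sum += p;
  return sum;
}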
// Parallel-for over the Cartesian product of the two sets of ranges. This
// combines their indices into a single 'task' so they can be executed by one
// `pool.Run`, which increases the amount of work available to workers and
// reduces fork-join overhead vs. nested parallel-for loops. Calls `func` with
// the two ranges and the thread index within `pool`.
template <class Func>
void ParallelizeTwoRanges(const IndexRangePartition& get1,
const IndexRangePartition& get2,
hwy::ThreadPool& pool, const Func& func) {
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
HWY_DASSERT(task < (uint64_t{1} << 32));
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
HWY_DASSERT(idx1 < get1.NumTasks());
HWY_DASSERT(idx2 < get2.NumTasks());
const IndexRange range1 = get1.Range(idx1);
const IndexRange range2 = get2.Range(idx2);
func(range1, range2, thread);
});
}
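// Illustrative sketch only: covering an M x N index space with 2D tiles and
// checking that every element is visited exactly once. Tile sizes are
// arbitrary example values; M and N are assumed nonzero.
static inline void ExampleParallelizeTwoRanges(hwy::ThreadPool& pool, size_t M,
                                               size_t N) {
  HWY_DASSERT(M != 0 && N != 0);
  const IndexRangePartition rows =
      MaxSizePartition(IndexRange(0, M), /*max_size=*/64, /*size_multiple=*/8);
  const IndexRangePartition cols = MaxSizePartition(
      IndexRange(0, N), /*max_size=*/256, /*size_multiple=*/16);
  std::vector<size_t> area(pool.NumWorkers(), 0);  // Per-worker, no atomics.
  ParallelizeTwoRanges(rows, cols, pool,
                       [&](const IndexRange& rows_range,
                           const IndexRange& cols_range, size_t thread) {
                         area[thread] += rows_range.Num() * cols_range.Num();
                       });
  size_t total = 0;
  for (size_t a : area) total += a;
  HWY_ASSERT(total == M * N);  // The tiles partition the index space.
}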
// As above, for three ranges.
template <class Func>
void ParallelizeThreeRanges(const IndexRangePartition& get1,
const IndexRangePartition& get2,
const IndexRangePartition& get3,
hwy::ThreadPool& pool, const Func& func) {
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
const size_t num12 = get1.NumTasks() * get2.NumTasks();
const hwy::Divisor div12(static_cast<uint32_t>(num12));
const size_t num_tasks = num12 * get3.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
HWY_DASSERT(task < (uint64_t{1} << 32));
const size_t idx3 = div12.Divide(static_cast<uint32_t>(task));
const size_t task12 = div12.Remainder(static_cast<uint32_t>(task));
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task12));
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task12));
HWY_DASSERT(idx1 < get1.NumTasks());
HWY_DASSERT(idx2 < get2.NumTasks());
HWY_DASSERT(idx3 < get3.NumTasks());
const IndexRange range1 = get1.Range(idx1);
const IndexRange range2 = get2.Range(idx2);
const IndexRange range3 = get3.Range(idx3);
func(range1, range2, range3, thread);
});
}
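// Illustrative sketch only: 3D blocking over (batch, row, col) ranges with
// arbitrary example extents, verifying that all task combinations run.
static inline void ExampleParallelizeThreeRanges(hwy::ThreadPool& pool) {
  const IndexRangePartition batches =
      StaticPartition(IndexRange(0, 4), /*max_tasks=*/4, /*size_multiple=*/1);
  const IndexRangePartition rows = MaxSizePartition(
      IndexRange(0, 128), /*max_size=*/32, /*size_multiple=*/32);
  const IndexRangePartition cols = MaxSizePartition(
      IndexRange(0, 96), /*max_size=*/48, /*size_multiple=*/16);
  std::vector<size_t> visited(pool.NumWorkers(), 0);
  ParallelizeThreeRanges(batches, rows, cols, pool,
                         [&](const IndexRange&, const IndexRange&,
                             const IndexRange&, size_t thread) {
                           ++visited[thread];
                         });
  size_t total = 0;
  for (size_t v : visited) total += v;
  HWY_ASSERT(total ==
             batches.NumTasks() * rows.NumTasks() * cols.NumTasks());  // 32
}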
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_


@ -23,6 +23,7 @@
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
namespace { namespace {
@ -111,5 +112,229 @@ TEST(ThreadingTest, TestBoundedTopology) {
} }
} }
TEST(ThreadingTest, TestMaxSizePartition) {
const IndexRange range(0, 100);
// Round down
{
const IndexRangePartition partition = MaxSizePartition(range, 55, 32);
HWY_ASSERT(partition.TaskSize() == 32);
HWY_ASSERT(partition.NumTasks() == 4);
}
// Huge `max_size`: single task
{
const IndexRangePartition partition = MaxSizePartition(range, 9999, 1);
HWY_ASSERT(partition.TaskSize() == 100);
HWY_ASSERT(partition.NumTasks() == 1);
}
// Huge `max_size`: `size_multiple` is still respected
{
const IndexRangePartition partition = MaxSizePartition(range, 9999, 64);
HWY_ASSERT(partition.TaskSize() == 64);
HWY_ASSERT(partition.NumTasks() == 2);
}
// `size_multiple` larger than range: ignore multiple
{
const IndexRangePartition partition = MaxSizePartition(range, 55, 128);
HWY_ASSERT(partition.TaskSize() == 55);
HWY_ASSERT(partition.NumTasks() == 2);
}
// Small `max_size`: small tasks
{
const IndexRangePartition partition = MaxSizePartition(range, 2, 1);
HWY_ASSERT(partition.TaskSize() == 2);
HWY_ASSERT(partition.NumTasks() == 50);
}
// Large `max_size`: two tasks with lots of overhang
{
const IndexRangePartition partition = MaxSizePartition(range, 98, 1);
HWY_ASSERT(partition.TaskSize() == 98);
HWY_ASSERT(partition.NumTasks() == 2);
}
// `size_multiple` almost as large as a different, smaller range: imbalanced
{
const IndexRangePartition partition =
MaxSizePartition(IndexRange(0, 6), 6, 4);
HWY_ASSERT(partition.TaskSize() == 4);
HWY_ASSERT(partition.NumTasks() == 2);
}
}
TEST(ThreadingTest, TestStaticPartition) {
const IndexRange range(0, 100);
// Round up
{
const IndexRangePartition partition = StaticPartition(range, 2, 64);
HWY_ASSERT(partition.TaskSize() == 64);
HWY_ASSERT(partition.NumTasks() == 2);
}
// No `size_multiple`: division still rounds up
{
const IndexRangePartition partition = StaticPartition(range, 3, 1);
HWY_ASSERT(partition.TaskSize() == 34);
HWY_ASSERT(partition.NumTasks() == 3);
}
// Huge `max_tasks`: one each
{
const IndexRangePartition partition = StaticPartition(range, 9999, 1);
HWY_ASSERT(partition.TaskSize() == 1);
HWY_ASSERT(partition.NumTasks() == 100);
}
// `size_multiple` larger than range: single task
{
const IndexRangePartition partition = StaticPartition(range, 2, 128);
HWY_ASSERT(partition.TaskSize() == 100);
HWY_ASSERT(partition.NumTasks() == 1);
}
// `max_tasks` = 1: single task, even if rounding up would exceed the range
{
const IndexRangePartition partition = StaticPartition(range, 1, 8);
HWY_ASSERT(partition.TaskSize() == 100);
HWY_ASSERT(partition.NumTasks() == 1);
}
}
TEST(ThreadingTest, TestParallelizeOneRange) {
const IndexRange range(0, 10);
const IndexRangePartition partition = StaticPartition(range, 2, 4);
hwy::ThreadPool null_pool(0);
size_t calls = 0;
ParallelizeOneRange(partition, null_pool,
[&](const IndexRange& range, size_t) {
if (++calls == 1) {
HWY_ASSERT(range.begin() == 0 && range.end() == 8);
} else {
HWY_ASSERT(range.begin() == 8 && range.end() == 10);
}
});
HWY_ASSERT(calls == 2);
}
TEST(ThreadingTest, TestParallelizeTwoRanges) {
const IndexRangePartition partition1 =
StaticPartition(IndexRange(0, 10), 2, 4);
const IndexRangePartition partition2 =
MaxSizePartition(IndexRange(128, 256), 32, 32);
HWY_ASSERT(partition2.NumTasks() == 4);
hwy::ThreadPool null_pool(0);
{
size_t calls = 0;
ParallelizeTwoRanges(
partition1, partition2, null_pool,
[&](const IndexRange& range1, const IndexRange& range2, size_t) {
++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
HWY_ASSERT(range2.begin() % 32 == 0);
HWY_ASSERT(range2.Num() % 32 == 0);
});
HWY_ASSERT(calls == 2 * 4);
}
// Also swap order to test Remainder() logic.
{
size_t calls = 0;
ParallelizeTwoRanges(
partition2, partition1, null_pool,
[&](const IndexRange& range2, const IndexRange& range1, size_t) {
++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
HWY_ASSERT(range2.begin() % 32 == 0);
HWY_ASSERT(range2.Num() % 32 == 0);
});
HWY_ASSERT(calls == 2 * 4);
}
}
TEST(ThreadingTest, TestParallelizeThreeRanges) {
// Named according to number of tasks.
const IndexRangePartition partition3 =
StaticPartition(IndexRange(0, 8), 3, 1); // [0, 3) [3, 6) [6, 8)
HWY_ASSERT(partition3.NumTasks() == 3);
const IndexRangePartition partition2 =
MaxSizePartition(IndexRange(10, 30), 10, 10); // [10, 20), [20, 30)
HWY_ASSERT(partition2.NumTasks() == 2);
const IndexRangePartition partition4 =
MaxSizePartition(IndexRange(100, 500), 100, 100); // 100, 200, 300, 400
HWY_ASSERT(partition4.NumTasks() == 4);
const auto check_ranges = [&](const IndexRange& range3,
const IndexRange& range2,
const IndexRange& range4) {
HWY_ASSERT(range3.begin() == 0 || range3.begin() == 3 ||
range3.begin() == 6);
HWY_ASSERT(range2.begin() == 10 || range2.begin() == 20);
HWY_ASSERT(range4.begin() % 100 == 0);
};
hwy::ThreadPool null_pool(0);
// All 6 permutations of the three ranges to test the Remainder() logic:
// 3, 2, 4
{
size_t calls = 0;
ParallelizeThreeRanges(
partition3, partition2, partition4, null_pool,
[&](IndexRange range3, IndexRange range2, IndexRange range4, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 3, 4, 2
{
size_t calls = 0;
ParallelizeThreeRanges(
partition3, partition4, partition2, null_pool,
[&](IndexRange range3, IndexRange range4, IndexRange range2, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 4, 2, 3
{
size_t calls = 0;
ParallelizeThreeRanges(
partition4, partition2, partition3, null_pool,
[&](IndexRange range4, IndexRange range2, IndexRange range3, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 4, 3, 2
{
size_t calls = 0;
ParallelizeThreeRanges(
partition4, partition3, partition2, null_pool,
[&](IndexRange range4, IndexRange range3, IndexRange range2, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 2, 3, 4
{
size_t calls = 0;
ParallelizeThreeRanges(
partition2, partition3, partition4, null_pool,
[&](IndexRange range2, IndexRange range3, IndexRange range4, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 2, 4, 3
{
size_t calls = 0;
ParallelizeThreeRanges(
partition2, partition4, partition3, null_pool,
[&](IndexRange range2, IndexRange range4, IndexRange range3, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
}
} // namespace } // namespace
} // namespace gcpp } // namespace gcpp