mirror of https://github.com/google/gemma.cpp.git
Threading/infra improvements.
* Add Parallelize*Range helpers and partitioning helpers
* Refactor Pinning class, store original affinity (required to construct another NestedPools after pinning happened)

Compress:
* prevent Compress printing stats in tests
* zero-pad tensors

Matmul:
* add matmul_unit_test (TODO) and bench_matmul
* matmul_test: change norm to row vectors (that is what is added) and include bf16 rounding error
* Prepare for L2/L3 retrieval

PiperOrigin-RevId: 700603811
This commit is contained in:
parent 109a4d9f85
commit f74d496879

46 BUILD.bazel
@@ -71,6 +71,7 @@ cc_test(
         "@googletest//:gtest_main",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
+        "@highway//:thread_pool",
     ],
 )

@@ -166,6 +167,26 @@ cc_test(
     ],
 )

+cc_test(
+    name = "matmul_unit_test",
+    size = "small",
+    timeout = "long",
+    srcs = ["ops/matmul_unit_test.cc"],
+    local_defines = ["HWY_IS_TEST"],
+    # for test_suite.
+    tags = ["hwy_ops_test"],
+    deps = [
+        ":allocator",
+        ":basics",
+        ":ops",
+        ":test_util",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "//compression:compress",
+        "@highway//:hwy",
+        "@highway//:hwy_test_util",
+    ],
+)
+
 cc_test(
     name = "matmul_test",
     size = "small",

@@ -178,7 +199,28 @@ cc_test(
         ":allocator",
         ":basics",
         ":ops",
-        ":test_util",
+        ":threading",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "//compression:compress",
+        "@highway//:hwy",
+        "@highway//:hwy_test_util",
+        "@highway//:nanobenchmark",
+        "@highway//:thread_pool",
+    ],
+)
+
+cc_test(
+    name = "bench_matmul",
+    size = "small",
+    timeout = "long",
+    srcs = ["ops/bench_matmul.cc"],
+    local_defines = ["HWY_IS_TEST"],
+    # for test_suite.
+    tags = ["hwy_ops_test"],
+    deps = [
+        ":allocator",
+        ":basics",
+        ":ops",
         ":threading",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",

@@ -343,6 +385,7 @@ cc_library(
         ":basics",
         ":common",
         ":gemma_lib",
+        ":ops",
         ":threading",
         "//compression:io",
         "@highway//:hwy",

@@ -624,6 +667,7 @@ cc_test(
         "mem": "28g",
     },
     deps = [
+        ":allocator",
        ":backprop",
        ":basics",
        ":common",

@@ -163,9 +163,11 @@ set(GEMMA_TEST_FILES
   compression/sfp_test.cc
   evals/gemma_test.cc
   gemma/tensor_index_test.cc
+  ops/bench_matmul.cc
   ops/dot_test.cc
   ops/gemma_matvec_test.cc
   ops/matmul_test.cc
+  ops/matmul_unit_test.cc
   ops/ops_test.cc
   paligemma/image_test.cc
   paligemma/paligemma_test.cc

@@ -33,6 +33,7 @@
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "util/allocator.h"
 #include "util/basics.h"
 #include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"

@@ -42,6 +43,7 @@ namespace gcpp {
 TEST(OptimizeTest, GradientDescent) {
   NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
                     BoundedSlice(0, 1));
+  Allocator::Init(pools.Topology());
   hwy::ThreadPool& pool = pools.Pool();
   std::mt19937 gen(42);

@@ -51,6 +51,12 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;

+#ifdef HWY_IS_TEST
+static constexpr bool kIsTest = true;
+#else
+static constexpr bool kIsTest = false;
+#endif
+
 // Enables generic code independent of compression type.
 template <typename T>  // primary, must specialize
 struct CompressTraits {};

@@ -438,7 +444,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
     }
   }

-  const bool want_bench = num > 1024 * 1024 || COMPRESS_STATS;
+  const bool want_bench = COMPRESS_STATS || !kIsTest;
   const double t0 = want_bench ? hwy::platform::Now() : 0.0;

   using Traits = CompressTraits<Packed>;

@@ -38,6 +38,7 @@
 #include "util/basics.h"
 // IWYU pragma: end_exports
 #include "util/allocator.h"
+#include "hwy/per_target.h"
 #if COMPRESS_STATS
 #include "compression/distortion.h"
 #include "hwy/stats.h"

@@ -360,7 +361,11 @@ class MatStorageT : public MatPtrT<MatT> {
     } else {
       this->num_elements_ = num_elements;
     }
-    data_ = Allocator::Alloc<MatT>(num_elements);
+    // Pad to allow overrunning the last row by 2 BF16 vectors, hence at most
+    // `2 * VectorBytes / sizeof(BF16)` elements of MatT.
+    const size_t padding = hwy::VectorBytes();
+    data_ = Allocator::Alloc<MatT>(num_elements + padding);
+    hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT));
     this->ptr_ = data_.get();
   }

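The padding arithmetic above is easy to sanity-check. The snippet below is an illustration added here, not part of the commit; it assumes a 64-byte vector as an example (the real code queries hwy::VectorBytes() at runtime) and that BF16 is the smallest supported MatT.

#include <cstddef>
#include <cstdint>
// Illustration only: VectorBytes() *elements* of MatT cover an overrun of two
// BF16 vectors, because sizeof(MatT) >= sizeof(BF16) == 2 for all MatT here.
constexpr size_t kVectorBytes = 64;                 // example; hwy::VectorBytes() at runtime
constexpr size_t kOverrunBytes = 2 * kVectorBytes;  // two BF16 vectors
constexpr size_t kPaddingBytes = kVectorBytes * sizeof(uint16_t);  // BF16 case
static_assert(kPaddingBytes >= kOverrunBytes, "padding covers the overrun");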
@@ -56,6 +56,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
                    const AppArgs& app)
     : pools_(CreatePools(app)) {
+  Allocator::Init(pools_.Topology());
   InferenceArgs mutable_inference = inference;
   AbortIfInvalidArgs(mutable_inference);
   LoaderArgs mutable_loader = loader;

@@ -59,6 +59,7 @@ int main(int argc, char** argv) {

   // Instantiate model and KV Cache
   gcpp::NestedPools pools = gcpp::CreatePools(app);
+  gcpp::Allocator::Init(pools.Topology());
   gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
   gcpp::KVCache kv_cache =
       gcpp::KVCache::Create(model.GetModelConfig(),

@@ -0,0 +1,214 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Benchmark of large MatMul instances for which the MatMulSlow would be too
+// slow. This lacks a reference and is only useful for performance measurement.
+
+#include "hwy/detect_compiler_arch.h"
+#ifndef HWY_DISABLED_TARGETS
+// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
+// double-precision support.
+#if HWY_ARCH_ARM_V7
+#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
+#else
+#define HWY_DISABLED_TARGETS HWY_SCALAR
+#endif
+#endif
+
+#include <stddef.h>
+#include <stdio.h>
+
+#include <memory>
+
+#include "compression/compress.h"
+#include "compression/shared.h"
+#include "ops/matmul.h"
+#include "util/allocator.h"
+#include "util/basics.h"
+#include "util/threading.h"
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/timer.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/bench_matmul.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#include "hwy/highway.h"
+// After highway.h
+#include "compression/compress-inl.h"
+#include "ops/matmul-inl.h"
+#include "hwy/tests/test_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+// For running BenchAllMatMul only once. Defined within HWY_ONCE.
+extern int64_t first_target;
+
+namespace HWY_NAMESPACE {
+
+using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
+
+template <typename MatT>
+using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
+
+// Generates inputs: deterministic, within max SfpStream range.
+template <typename MatT>
+MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
+                                hwy::ThreadPool& pool) {
+  gcpp::CompressWorkingSet ws;
+  auto mat =
+      std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
+  FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
+  HWY_ASSERT(content);
+  const float scale = SfpStream::kMax / (mat->NumElements());
+  pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
+    for (size_t c = 0; c < extents.cols; c++) {
+      float f = static_cast<float>(r * extents.cols + c) * scale;
+      if ((r + c) & 1) f = -f;  // Also generate some negative values.
+      content[r * extents.cols + c] = f;
+    }
+  });
+
+  CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
+  mat->set_scale(0.6f);  // Arbitrary value, different from 1.
+  return mat;
+}
+
+// extents describes the transposed matrix.
+template <typename MatT>
+MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
+                                          hwy::ThreadPool& pool) {
+  gcpp::CompressWorkingSet ws;
+  auto mat =
+      std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
+  FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
+  const float scale = SfpStream::kMax / (mat->NumElements());
+  pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
+    for (size_t c = 0; c < extents.cols; c++) {
+      float f = static_cast<float>(c * extents.rows + r) * scale;
+      if ((r + c) & 1) f = -f;  // Also generate some negative values.
+      content[r * extents.cols + c] = f;
+    }
+  });
+
+  CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
+  // Arbitrary value, different from 1, must match GenerateMat.
+  mat->set_scale(0.6f);
+  return mat;
+}
+
+void PrintSpeed(const char* algo, const Extents2D& A_extents,
+                const Extents2D& B_extents, double elapsed) {
+  const size_t num_b = B_extents.Area();
+  // 2x because of FMA.
+  fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
+          elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
+}
+
+// Generates inputs and prints observed throughput of MatMul.
+template <typename MatTA, typename MatTB = MatTA>
+void BenchMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
+                 MatMulEnv& env) {
+  hwy::ThreadPool& pool = env.Pool();
+  fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
+          rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
+          TypeName<MatTB>());
+
+  const Extents2D A_extents(rows_ac, cols_a_rows_b);
+  const Extents2D B_extents(cols_bc, cols_a_rows_b);  // already transposed
+  const Extents2D C_extents(rows_ac, cols_bc);
+
+  MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
+  MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
+  RowVectorBatch<float> c_slow_batch(C_extents);
+  RowVectorBatch<float> c_batch(C_extents);
+  HWY_ASSERT(a && b_trans);
+
+  std::unique_ptr<MatStorageT<float>> add_storage;
+  if (add) {
+    add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
+    HWY_ASSERT(add_storage);
+    add_storage->set_scale(1.0f);
+  }
+
+  const auto A = ConstMatFromWeights(*a);
+  const auto B = ConstMatFromWeights(*b_trans);
+  const float* add_row = add ? add_storage->data_scale1() : nullptr;
+  const RowPtrF C = RowPtrFromBatch(c_batch);
+
+  double min_elapsed = hwy::HighestValue<double>();
+  for (int rep = 0; rep < 3; ++rep) {
+    const double start_tiled = hwy::platform::Now();
+    MatMul(A, B, add_row, env, C);
+    min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
+  }
+  PrintSpeed("MatMul", A_extents, B_extents, min_elapsed);
+}
+
+using F32 = float;
+using SFP = SfpStream;
+
+void BenchAllMatMul() {
+  if (first_target == 0) first_target = HWY_TARGET;
+  if (HWY_TARGET != first_target) return;
+
+  for (size_t max_packages : {1, 2}) {
+    const size_t max_threads = 0;  // no limit
+    NestedPools pools(max_threads, Tristate::kDefault,
+                      BoundedSlice(0, max_packages));
+#if GEMMA_DISABLE_TOPOLOGY
+    if (max_packages == 2) break;  // we only have one package
+#else
+    // If less than the limit, we have already tested all num_packages.
+    if (pools.Topology().FullTopology().packages.size() < max_packages) break;
+#endif
+    fprintf(stderr, "BenchAllMatMul %zu: %s %s\n", max_packages,
+            pools.TopologyString(), pools.PinString());
+
+    Tristate use_spinning = Tristate::kDefault;
+    pools.MaybeStartSpinning(use_spinning);
+    Allocator::Init(pools.Topology());
+    MatMulEnv env(pools);
+
+    for (size_t batch_size : {1, /*4, 64,*/ 128}) {
+      BenchMatMul<F32, F32>(batch_size, 24576, 3072, /*add=*/false, env);
+      BenchMatMul<F32, F32>(batch_size, 3072, 24576, /*add=*/false, env);
+      BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, /*add=*/false, env);
+      BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, /*add=*/false, env);
+      BenchMatMul<F32, SFP>(batch_size, 24576, 3072, /*add=*/false, env);
+      BenchMatMul<F32, SFP>(batch_size, 3072, 24576, /*add=*/false, env);
+    }
+    pools.MaybeStopSpinning(use_spinning);
+  }
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+
+namespace gcpp {
+int64_t first_target = 0;  // none run yet
+HWY_BEFORE_TEST(BenchMatMul);
+HWY_EXPORT_AND_TEST_P(BenchMatMul, BenchAllMatMul);
+HWY_AFTER_TEST();
+
+}  // namespace gcpp
+
+#endif

@@ -1109,6 +1109,7 @@ void TestAllDot() {
   const size_t num = 24 * 1024;
   NestedPools pools(kMaxWorkers - 1, /*pin=*/Tristate::kDefault,
                     BoundedSlice(0, 1), BoundedSlice(0, 1));
+  Allocator::Init(pools.Topology());
   RowVectorBatch<float> a(Extents2D(kMaxWorkers, num));
   RowVectorBatch<float> b(Extents2D(kMaxWorkers, num));
   RowVectorBatch<double> bufs(Extents2D(kMaxWorkers, num));

@@ -38,22 +38,6 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;

-// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
-// loads, we reuse the same A row for several B columns, which are also loaded
-// once for several rows of C. Thus we produce one 'tile' of C at a time of
-// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are
-// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4
-// enables the `StoreInterleaved4` transpose in `AddHorizontalSums`. We assume
-// and verify that `C.cols % kRegCols == 0`.
-constexpr size_t kRegCols = 4;
-
-// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because
-// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile.
-// In general, `batch_size` (C rows) is not a multiple of `kRegRows`. Thus
-// functions that load or store a tile are parameterized on `kNumRows`, which is
-// generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
-constexpr size_t kRegRows = kRegCols;
-
 // Loads two vectors at a time with element type hn::TFromD<DR> from a row of
 // transposed B. Called in a loop over col_ab. No bounds checking because
 // `kRow` is from B columns, which we checked is a multiple of `kRegCols`.

35 ops/matmul.h
@@ -21,6 +21,7 @@
 // IWYU pragma: begin_exports
 #include "util/basics.h"
 #include "util/threading.h"
+#include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 // IWYU pragma: end_exports

@@ -28,6 +29,40 @@

 namespace gcpp {

+// TODO: remove deprecated typedef.
+using Range1D = IndexRange;
+
+// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
+// loads, we reuse the same A row for several B columns, which are also loaded
+// once for several rows of C. Thus we produce one 'tile' of C at a time of
+// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are
+// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4
+// enables the `StoreInterleaved4` transpose in `StoreHorizontalSums`. We assume
+// and verify that `C.Cols() % kRegCols == 0`.
+constexpr size_t kRegCols = 4;
+
+// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because
+// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile.
+// In general, `batch_size` (A/C rows) is not a multiple of `kRegRows`. Thus
+// functions that load or store a tile are parameterized on `kRowsPerTile`:
+// usually `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
+constexpr size_t kRegRows = kRegCols;
+
+struct CacheSizes {
+  CacheSizes() = default;
+  CacheSizes(const BoundedTopology::Cluster& cluster) {
+    // Assumes each package and cluster has the same cache sizes, and uses
+    // reasonable defaults if unknown.
+    l1_bytes = 32 * 1024;  // typical size, rarely changes
+    l2_bytes = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) * 1024;
+    l3_bytes = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) * 1024;
+  }
+
+  size_t l1_bytes;
+  size_t l2_bytes;
+  size_t l3_bytes;
+};
+
 // Allocations and threads, shared across MatMul calls.
 class MatMulEnv {
  public:

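A quick worked example of the load-to-FMA ratio the moved comment describes, for the 4x4 register tile; this snippet is an illustration added here, not code from the commit.

#include <cstddef>
// Per 4x4 tile: kRegCols + kRegRows = 8 vector loads feed
// kRegRows * kRegCols = 16 FMAs, i.e. two FMAs per load.
constexpr size_t kCols = 4, kRows = 4;            // kRegCols, kRegRows above
constexpr size_t kLoadsPerTile = kCols + kRows;   // 8
constexpr size_t kFMAsPerTile = kCols * kRows;    // 16
static_assert(2 * kLoadsPerTile == kFMAsPerTile, "square tile halves loads per FMA");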
@@ -13,6 +13,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+// End to end test of MatMul, comparing against a reference implementation.
+
+#include "hwy/detect_compiler_arch.h"
 #ifndef HWY_DISABLED_TARGETS
 // Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
 // double-precision support.

@@ -23,14 +26,14 @@
 #endif
 #endif

-#include "ops/matmul.h"
-
 #include <stddef.h>
 #include <stdio.h>

 #include <memory>

 #include "compression/compress.h"
+#include "compression/shared.h"
+#include "ops/matmul.h"
 #include "util/allocator.h"
 #include "util/basics.h"
 #include "util/threading.h"

@@ -52,7 +55,11 @@

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
+// For running TestBatchSizes only once. Defined within HWY_ONCE.
+extern int64_t first_target;
+
 namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;

 using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;

@@ -71,8 +78,9 @@ MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
   const float scale = SfpStream::kMax / (mat->NumElements());
   pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
     for (size_t c = 0; c < extents.cols; c++) {
-      content[r * extents.cols + c] =
-          static_cast<float>(r * extents.cols + c) * scale;
+      float f = static_cast<float>(r * extents.cols + c) * scale;
+      if ((r + c) & 1) f = -f;  // Also generate some negative values.
+      content[r * extents.cols + c] = f;
     }
   });

@@ -92,8 +100,9 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
   const float scale = SfpStream::kMax / (mat->NumElements());
   pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
     for (size_t c = 0; c < extents.cols; c++) {
-      content[r * extents.cols + c] =
-          static_cast<float>(c * extents.rows + r) * scale;
+      float f = static_cast<float>(c * extents.rows + r) * scale;
+      if ((r + c) & 1) f = -f;  // Also generate some negative values.
+      content[r * extents.cols + c] = f;
     }
   });

@@ -104,16 +113,28 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
 }

 // Returns 1-norm, used for estimating tolerable numerical differences.
-double MaxColAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
-  double max_col_abs_sum = 0.0;
-  for (size_t c = 0; c < extents.cols; c++) {
-    double col_abs_sum = 0.0;
-    for (size_t r = 0; r < extents.rows; r++) {
-      col_abs_sum += hwy::ScalarAbs(a[r * extents.cols + c]);
-    }
-    max_col_abs_sum = HWY_MAX(max_col_abs_sum, col_abs_sum);
-  }
-  return max_col_abs_sum;
+double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
+  double max_row_abs_sum = 0.0;
+  for (size_t r = 0; r < extents.rows; r++) {
+    const float* row = a + r * extents.cols;
+    double row_abs_sum = 0.0;
+    for (size_t c = 0; c < extents.cols; c++) {
+      row_abs_sum += hwy::ScalarAbs(row[c]);
+    }
+    max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum);
+  }
+  return max_row_abs_sum;
+}
+
+// Returns the maximum absolute value of `a`.
+float MaxAbs(const float* HWY_RESTRICT a, const Extents2D& extents) {
+  float max_abs = 0.0f;
+  for (size_t c = 0; c < extents.cols; c++) {
+    for (size_t r = 0; r < extents.rows; r++) {
+      max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(a[r * extents.cols + c]));
+    }
+  }
+  return max_abs;
 }

 // B is already transposed.

@@ -132,12 +153,25 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
   DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a);
   DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b);

-  const double norm = MaxColAbsSum(a.get(), A.Extents()) *
-                      MaxColAbsSum(b_trans.get(), B.Extents());
-  // Dot(float,BF16) rounds both to BF16.
-  using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
-  const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
-  const double tolerance = 200.0 * norm * epsilon;
+  // MatMul rounds inputs to BF16, so error is proportional to the max input
+  // magnitude, but also to f32 accumulation of rows in A and B.
+  const double norm = MaxRowAbsSum(a.get(), A.Extents()) *
+                      MaxRowAbsSum(b_trans.get(), B.Extents());
+  const float max_abs =
+      MaxAbs(a.get(), A.Extents()) * MaxAbs(b_trans.get(), B.Extents());
+  const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
+  const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
+  double tolerance = 8 * norm * eps_f32;
+  // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
+  // tolerance there.
+  if (IsF32<MatTA>() && IsF32<MatTB>()) {
+    tolerance += 4 * max_abs * eps_bf16;
+  }
+  EXPECT_GE(tolerance, 1E-4);
+  if (tolerance > 4.0) {
+    fprintf(stderr, "WARN: high tolerance %f norm %f maxabs %f\n", tolerance,
+            norm, max_abs);
+  }

   for (size_t r = 0; r < A.extents.rows; r++) {
     const float* expected_row = C_slow.Row(r);

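To see what the new tolerance works out to, here is a small standalone illustration. The sizes and magnitudes are assumed, not taken from the test, and the epsilon values merely approximate hwy::Epsilon<float>() and hwy::Epsilon<BF16>().

#include <cmath>
#include <cstdio>
int main() {
  const double eps_f32 = std::ldexp(1.0, -23);  // ~1.2e-7
  const double eps_bf16 = std::ldexp(1.0, -7);  // ~7.8e-3
  const double norm = 512.0 * 512.0;  // product of max row 1-norms, 512x512 inputs of magnitude ~1
  const double max_abs = 1.0;         // product of max |A| and max |B|
  double tolerance = 8 * norm * eps_f32;  // f32 accumulation error, scaled by row 1-norms
  tolerance += 4 * max_abs * eps_bf16;    // extra BF16 rounding when both inputs are F32
  std::printf("tolerance ~= %f\n", tolerance);  // ~0.28 for this example
  return 0;
}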
@@ -148,10 +182,11 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,

       if (!(expected_value - tolerance <= actual_value &&
             actual_value <= expected_value + tolerance)) {
-        fprintf(
-            stderr,
-            "(%zu,%zu): expected %f, actual %f, norm %f eps %E tolerance %f\n",
-            r, c, expected_value, actual_value, norm, epsilon, tolerance);
+        fprintf(stderr,
+                "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
+                "tolerance %f\n",
+                r, c, expected_value, actual_value, norm, max_abs, tolerance);
+        return;
       }
     }
   }

@@ -171,20 +206,31 @@ HWY_INLINE void MatMulSlow(const ConstMat<MatTA> A, const ConstMat<MatTB> B,
   const hn::ScalableTag<float> df;  // lane type is ignored
   const PackedSpan<const MatTB> b_span =
       MakeSpan(B.ptr, B.ofs + B.extents.Area());
-  const Extents2D C_extents(A.extents.rows, C.Cols());
+  const IndexRange all_rows_c(0, A.Extents().rows);
+  const IndexRange all_cols_c(0, C.Cols());

-  StaticPartitionRowsAndCols(
-      env.Pools(), C_extents, sizeof(MatTB),
-      [&](const Range2D& C_range, const TaskLocation& loc) {
-        loc.cluster.Run(
-            C_range.rows.begin(), C_range.rows.end(),
-            [&](const uint64_t row, size_t /*thread*/) {
-              float* HWY_RESTRICT C_row = C.Row(row);
-              for (size_t row_b_col_c : C_range.cols) {
-                const float add = add_row ? add_row[row_b_col_c] : 0.0f;
-                C_row[row_b_col_c] =
-                    add + scale * Dot(df, b_span, row_b_col_c * B.extents.cols,
-                                      A.ptr + A.Row(row), A.extents.cols);
+  NestedPools& pools = env.Pools();
+  hwy::ThreadPool& all_packages = pools.AllPackages();
+  const IndexRangePartition get_row_c =
+      StaticPartition(all_rows_c, all_packages.NumWorkers(), 1);
+  ParallelizeOneRange(
+      get_row_c, all_packages,
+      [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
+        hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
+        const size_t multiple = Allocator::Alignment() / sizeof(MatTB);
+        const IndexRangePartition get_col_c =
+            StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
+        ParallelizeOneRange(
+            get_col_c, all_clusters,
+            [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
+              for (size_t r : rows_c) {
+                float* HWY_RESTRICT C_row = C.Row(r);
+                for (size_t c : cols_c) {
+                  const float add = add_row ? add_row[c] : 0.0f;
+                  C_row[c] =
+                      add + scale * Dot(df, b_span, c * B.extents.cols,
+                                        A.ptr + A.Row(r), A.extents.cols);
+                }
               }
             });
       });

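For reference, the same two helpers can be used outside MatMulSlow. The sketch below is an illustration with assumed surrounding code (the row-summing function and the exact header locations are hypothetical); the StaticPartition and ParallelizeOneRange call shapes mirror their use in the hunk above.

#include <cstddef>
#include "util/basics.h"     // IndexRange (assumed header)
#include "util/threading.h"  // StaticPartition, ParallelizeOneRange (assumed header)
#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical example: one contiguous row range per pool worker.
void SumRows(const float* m, size_t rows, size_t cols, double* row_sums,
             hwy::ThreadPool& pool) {
  const gcpp::IndexRange all_rows(0, rows);
  const gcpp::IndexRangePartition ranges =
      gcpp::StaticPartition(all_rows, pool.NumWorkers(), /*multiple=*/1);
  gcpp::ParallelizeOneRange(
      ranges, pool, [&](const gcpp::IndexRange& my_rows, size_t /*thread*/) {
        for (size_t r : my_rows) {
          double sum = 0.0;
          for (size_t c = 0; c < cols; ++c) sum += m[r * cols + c];
          row_sums[r] = sum;
        }
      });
}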
@@ -250,6 +296,40 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   AssertClose(A, B, C_slow, C);
 }

+using F32 = float;
+using SFP = SfpStream;
+
+// Sweep batch_size for a single input type and Highway target, to verify the
+// row partitioning.
+void TestBatchSizes() {
+  if (first_target == 0) first_target = HWY_TARGET;
+  if (HWY_TARGET != first_target) return;
+
+  for (size_t max_packages : {1, 2}) {
+    const size_t max_threads = 0;  // no limit
+    NestedPools pools(max_threads, Tristate::kDefault,
+                      BoundedSlice(0, max_packages));
+#if GEMMA_DISABLE_TOPOLOGY
+    if (max_packages == 2) break;  // we only have one package
+#else
+    // If less than the limit, we have already tested all num_packages.
+    if (pools.Topology().FullTopology().packages.size() < max_packages) break;
+#endif
+    fprintf(stderr, "TestBatchSizes %zu: %s %s\n", max_packages,
+            pools.TopologyString(), pools.PinString());
+
+    Tristate use_spinning = Tristate::kDefault;
+    pools.MaybeStartSpinning(use_spinning);
+    Allocator::Init(pools.Topology());
+    MatMulEnv env(pools);
+
+    for (size_t batch_size = 1; batch_size <= 3 * kRegRows; ++batch_size) {
+      TestMatMul<F32, F32>(batch_size, 256, 256, /*add=*/false, env);
+    }
+    pools.MaybeStopSpinning(use_spinning);
+  }
+}
+
 void TestAllMatMul() {
   // Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
   if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||

@@ -257,32 +337,30 @@ void TestAllMatMul() {
     return;
   }

-  NestedPools pools(4, /*pin=*/Tristate::kDefault);
+  NestedPools pools(0);  // no limits
   Tristate use_spinning = Tristate::kDefault;
   pools.MaybeStartSpinning(use_spinning);
   Allocator::Init(pools.Topology());
   MatMulEnv env(pools);

-  using F32 = float;
-  using SFP = SfpStream;
-
-  // large-scale test: batch_size=128 is better than 64 or 256 for SKX.
-  // TestMatMul<F32, SFP>(128, 24576, 3072, /*add=*/false, env);
-  // TestMatMul<F32, SFP>(128, 3072, 24576, /*add=*/false, env);
-  TestMatMul<F32, F32>(1, 24576, 3072, /*add=*/false, env);
-  TestMatMul<F32, F32>(1, 3072, 24576, /*add=*/false, env);
-  TestMatMul<F32, SFP>(1, 24576, 3072, /*add=*/false, env);
-  TestMatMul<F32, SFP>(1, 3072, 24576, /*add=*/false, env);
-
-  // medium-sized square test - temporarily disabled for faster testing.
-  if constexpr (false) {
-    TestMatMul<F32>(512, 512, 512, /*add=*/false, env);
-    TestMatMul<BF16>(512, 512, 512, /*add=*/true, env);
-    TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env);
-    TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env);
-    TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env);
-    TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env);
-  }
+  // Sizes seen in gemma_test 2B.
+  TestMatMul<F32>(1, 2048, 512, /*add=*/false, env);
+  TestMatMul<F32>(1, 2048, 2048, /*add=*/false, env);
+  TestMatMul<F32>(1, 2048, 16384, /*add=*/false, env);
+  TestMatMul<F32>(1, 16384, 2048, /*add=*/false, env);
+  TestMatMul<F32>(1, 2048, 256000, /*add=*/false, env);
+  TestMatMul<F32>(5, 2048, 512, /*add=*/false, env);
+  TestMatMul<F32>(5, 2048, 2048, /*add=*/false, env);
+  TestMatMul<F32>(5, 2048, 16384, /*add=*/false, env);
+  TestMatMul<F32>(5, 16384, 2048, /*add=*/false, env);
+
+  // medium-sized square
+  TestMatMul<F32>(512, 512, 512, /*add=*/false, env);
+  TestMatMul<BF16>(512, 512, 512, /*add=*/true, env);
+  TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env);
+  TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env);
+  TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env);
+  TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env);

   // minimal non-square test. kColsARowsB must be at least 2 vectors.
   TestMatMul<F32>(35, 128, 32, /*add=*/false, env);

@@ -325,8 +403,10 @@ HWY_AFTER_NAMESPACE();
 #if HWY_ONCE

 namespace gcpp {
-HWY_BEFORE_TEST(MatmulTest);
-HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllMatMul);
+int64_t first_target = 0;  // none run yet
+HWY_BEFORE_TEST(MatMulTest);
+HWY_EXPORT_AND_TEST_P(MatMulTest, TestBatchSizes);
+HWY_EXPORT_AND_TEST_P(MatMulTest, TestAllMatMul);
 HWY_AFTER_TEST();

 }  // namespace gcpp

@@ -0,0 +1,17 @@
+// Copyright 2023 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO: Tests of individual MatMul components.
+int main() { return 0; }

@@ -103,7 +103,7 @@ size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
 // which means it would have to be called before pages are faulted in, but
 // `aligned_allocator.h` modifies the first bytes for its bookkeeping.
 // May overwrite some of the memory with zeros.
-static void BindMemory(void* ptr, size_t bytes, size_t node) {
+void BindMemory(void* ptr, size_t bytes, size_t node) {
   constexpr size_t kMaxNodes = 1024;  // valid for x86/x64, and "enough"
   // Avoid mbind because it does not report why it failed, which is most likely
   // because pages are busy, in which case we want to know which.

@@ -159,24 +159,7 @@ static void BindMemory(void* ptr, size_t bytes, size_t node) {

 #else
 // TODO: support other OSes.
-static void BindMemory(void*, size_t, size_t) {}
+void BindMemory(void*, size_t, size_t) {}
 #endif  // GEMMA_NUMA && HWY_OS_LINUX

-void BindTensor(NestedPools& nested, const Extents2D& extents,
-                size_t bytes_per_col, void* ptr) {
-  if (!Allocator::UseNUMA()) return;
-  uint8_t* p8 = static_cast<uint8_t*>(ptr);
-  const size_t bytes_per_row = extents.cols * bytes_per_col;
-  StaticPartitionRowsAndCols(
-      nested, extents, bytes_per_col,
-      [&](const Range2D& r, const TaskLocation& loc) {
-        for (size_t row : r.rows) {
-          uint8_t* slice =
-              p8 + row * bytes_per_row + r.cols.begin() * bytes_per_col;
-          const size_t slice_size = r.cols.Num() * bytes_per_col;
-          BindMemory(slice, slice_size, loc.node);
-        }
-      });
-}
-
 }  // namespace gcpp

@@ -125,82 +125,8 @@ class Allocator {
   static size_t alignment_;
 };

-// For shorter arguments to the StaticPartitionRowsAndCols functor.
-struct TaskLocation {
-  TaskLocation(size_t node, size_t package_idx, hwy::ThreadPool& cluster,
-               size_t worker_offset)
-      : node(node),
-        package_idx(package_idx),
-        cluster(cluster),
-        worker_offset(worker_offset) {}
-  size_t node;
-  size_t package_idx;
-  hwy::ThreadPool& cluster;
-  const size_t worker_offset;
-};
-
-// Used in MatMul and allocator.h. Defined here because it depends on
-// Allocator::Alignment().
-template <class Func>
-void StaticPartitionRowsAndCols(NestedPools& nested, Extents2D extents,
-                                size_t bytes_per_element, const Func& func) {
-  // Both rows and cols must be a multiple of the alignment to avoid
-  // touching remote pages.
-  const size_t multiple = Allocator::Alignment() / bytes_per_element;
-
-  // Static partitioning of columns across packages. We assume that column
-  // sharding is more expensive, hence we distribute columns across packages,
-  // of which there are usually only one or two. For MatMul, the final result is
-  // the sum of each package's partial dot products.
-  hwy::ThreadPool& all_packages = nested.AllPackages();
-  const size_t num_packages = all_packages.NumWorkers();
-  const size_t cols_per_package =
-      hwy::RoundUpTo(hwy::DivCeil(extents.cols, num_packages), multiple);
-  const size_t col_tasks = hwy::DivCeil(extents.cols, cols_per_package);
-  HWY_ASSERT(col_tasks <= num_packages);
-  all_packages.Run(
-      0, col_tasks, [&](uint64_t package_idx, size_t package_thread) {
-        HWY_ASSERT(package_idx == package_thread);  // one task per worker
-        const size_t col_begin = package_idx * cols_per_package;
-        const Range1D col_range =
-            MakeRange1D(col_begin, extents.cols, cols_per_package);
-
-        // Static partitioning of rows across the package's clusters. We assume
-        // that row sharding is cheaper. In MatMul, results can indeed be
-        // computed independently for each row of B.
-        hwy::ThreadPool& all_clusters = nested.AllClusters(package_idx);
-        const size_t num_clusters = all_clusters.NumWorkers();
-        const size_t rows_per_cluster =
-            hwy::RoundUpTo(hwy::DivCeil(extents.rows, num_clusters), multiple);
-        const size_t row_tasks = hwy::DivCeil(extents.rows, rows_per_cluster);
-        HWY_ASSERT(row_tasks <= num_clusters);
-        all_clusters.Run(
-            0, row_tasks, [&](uint64_t cluster_idx, size_t cluster_thread) {
-              HWY_ASSERT(cluster_idx == cluster_thread);  // one task per worker
-
-              // For binding to NUMA node.
-              const size_t node = nested.Node(package_idx, cluster_idx);
-              // Older CPUs that predate chiplets typically have only one
-              // cluster, so callers should also parallelize using this
-              // per-cluster pool.
-              hwy::ThreadPool& cluster =
-                  nested.Cluster(package_idx, cluster_idx);
-              // This plus the worker from `cluster->Run` is the TLS index.
-              const size_t worker_offset =
-                  nested.WorkerOffset(package_idx, cluster_idx);
-
-              const size_t row_begin = cluster_idx * rows_per_cluster;
-              const Range1D row_range =
-                  MakeRange1D(row_begin, extents.rows, rows_per_cluster);
-
-              func(Range2D(row_range, col_range),
-                   TaskLocation(node, package_idx, cluster, worker_offset));
-            });
-      });
-}
-
-void BindTensor(NestedPools& nested, size_t rows, size_t cols,
-                size_t bytes_per_col, void* ptr);
+// For future NUMA support. TODO: use.
+void BindMemory(void* ptr, size_t bytes, size_t node);

 }  // namespace gcpp

@@ -27,6 +27,7 @@
 #include "compression/io.h"  // Path
 #include "gemma/common.h"
 #include "gemma/gemma.h"  // For CreateGemma
+#include "ops/matmul.h"
 #include "util/args.h"
 #include "util/basics.h"  // Tristate
 #include "util/threading.h"

@@ -115,6 +116,7 @@ class AppArgs : public ArgsBase<AppArgs> {
   }
 };

+// Callers must call Allocator::Init(pools.Topology()) after this.
 static inline NestedPools CreatePools(const AppArgs& app) {
   return NestedPools(app.max_threads, app.pin,
                      BoundedSlice(app.skip_packages, app.max_packages),

@@ -21,7 +21,7 @@
 #include <stdint.h>

 #include "hwy/aligned_allocator.h"
-#include "hwy/base.h"  // HWY_IS_MSAN
+#include "hwy/base.h"
 // IWYU pragma: end_exports

 #if HWY_IS_MSAN

@@ -60,7 +60,7 @@ struct TokenAndProb {
   float prob;
 };

-// Entire size of a 2D array. By contrast, Range2D is a subrange.
+// Entire size of a 2D array.
 struct Extents2D {
   Extents2D() : rows(0), cols(0) {}
   Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {

@@ -74,11 +74,13 @@ struct Extents2D {
   size_t cols;
 };

-// Range2D consists of two Range1D.
-struct Range1D {
-  Range1D(size_t begin, size_t end) : begin_(begin), end_(end) {
+struct IndexRange {
+  IndexRange(size_t begin, size_t end) : begin_(begin), end_(end) {
     HWY_DASSERT(begin < end);
   }
+  IndexRange(const IndexRange& other) = default;
+  IndexRange& operator=(const IndexRange& other) = default;

   size_t Num() const { return end_ - begin_; }

   // Enable range-based for loops.

@@ -101,22 +103,15 @@ struct Range1D {
   Iterator begin() const { return Iterator(begin_); }
   Iterator end() const { return Iterator(end_); }

-  const size_t begin_;
-  const size_t end_;
+  size_t begin_;
+  size_t end_;
 };

-static inline Range1D MakeRange1D(size_t begin, size_t end, size_t max_size) {
-  return Range1D(begin, HWY_MIN(begin + max_size, end));
+static inline IndexRange MakeIndexRange(size_t begin, size_t end,
+                                        size_t max_size) {
+  return IndexRange(begin, HWY_MIN(begin + max_size, end));
 }

-// In MatMul, the two axes are used independently, hence we do not define
-// Range2D as a top-left and extents.
-struct Range2D {
-  Range2D(Range1D rows, Range1D cols) : rows(rows), cols(cols) {}
-  const Range1D rows;
-  const Range1D cols;
-};
-
 // Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
 // it is always float and does not support compressed T, but does support an
 // arbitrary stride >= cols.

@@ -125,6 +120,10 @@ class RowPtr {
  public:
   RowPtr(T* HWY_RESTRICT row0, size_t cols)
       : row0_(row0), cols_(cols), stride_(cols) {}
+  RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
+      : row0_(row0), cols_(cols), stride_(stride) {
+    HWY_DASSERT(stride >= cols);
+  }

   T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
   size_t Cols() const { return cols_; }

@@ -207,6 +206,7 @@ struct ConstMat {
   }

   const Extents2D& Extents() const { return extents; }
+  size_t Stride() const { return extents.cols; }

   // Shrinks the row-extent of this matrix view, i.e. reduces the view to a
   // subrange of the original rows starting at row 0.

@ -39,31 +39,106 @@ static void SortByDescendingSize(std::vector<T>& groups) {
|
||||||
[](const T& a, const T& b) { return a.Size() > b.Size(); });
|
[](const T& a, const T& b) { return a.Size() > b.Size(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Singleton, holds the original process affinity and the pinning status.
|
||||||
|
class Pinning {
|
||||||
|
static bool InContainer() {
|
||||||
|
return false; }
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Returns set of LPs available for use. Subsequent calls return the same
|
||||||
|
// set as the first, because pinning overwrites the main thread's affinity.
|
||||||
|
// Thread-hostile, not called concurrently.
|
||||||
|
LPS EnabledLPs() {
|
||||||
|
if (original_affinity_.Any()) return original_affinity_;
|
||||||
|
|
||||||
|
// Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
|
||||||
|
LPS enabled_lps;
|
||||||
|
if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
|
||||||
|
const size_t num_lps = hwy::TotalLogicalProcessors();
|
||||||
|
fprintf(
|
||||||
|
stderr,
|
||||||
|
"Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
|
||||||
|
num_lps);
|
||||||
|
for (size_t lp = 0; lp < num_lps; ++lp) {
|
||||||
|
enabled_lps.Set(lp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Without threading support, only keep the first enabled LP; it might still
|
||||||
|
// make sense to pin the main thread to avoid migrations.
|
||||||
|
if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
|
||||||
|
HWY_ASSERT(enabled_lps.Any());
|
||||||
|
const size_t lp = enabled_lps.First();
|
||||||
|
enabled_lps = LPS();
|
||||||
|
enabled_lps.Set(lp);
|
||||||
|
fprintf(stderr,
|
||||||
|
"Warning, threads not supported, using only the main thread\n.");
|
||||||
|
}
|
||||||
|
|
||||||
|
original_affinity_ = enabled_lps;
|
||||||
|
return enabled_lps;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetPolicy(Tristate pin) {
|
||||||
|
if (pin == Tristate::kDefault) {
|
||||||
|
// Pinning is unreliable inside containers because the hypervisor might
|
||||||
|
// periodically change our affinity mask, or other processes might also
|
||||||
|
// pin themselves to the same LPs.
|
||||||
|
pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
|
||||||
|
}
|
||||||
|
want_pin_ = (pin == Tristate::kTrue);
|
||||||
|
any_error_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
|
||||||
|
// and sets `any_error_` if any fails.
|
||||||
|
void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) {
|
||||||
|
if (HWY_UNLIKELY(!want_pin_)) return;
|
||||||
|
|
||||||
|
const std::vector<size_t> lps = cluster.LPVector();
|
||||||
|
HWY_ASSERT(pool->NumWorkers() <= lps.size());
|
||||||
|
pool->Run(
|
||||||
|
0, pool->NumWorkers(),
|
||||||
|
[this, &pool, &lps](uint64_t task, size_t thread) {
|
||||||
|
HWY_ASSERT(task == thread); // each worker has one task
|
||||||
|
if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"Pinning failed for task %zu of %zu to %zu (size %zu)\n",
|
||||||
|
task, pool->NumWorkers(), lps[task], lps.size());
|
||||||
|
(void)any_error_.test_and_set();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Called ONCE after all MaybePin because it invalidates the error status.
|
||||||
|
bool AllPinned(const char** pin_string) {
|
||||||
|
// If !want_pin_, MaybePin will return without setting any_error_, but in
|
||||||
|
// that case we still want to return false to avoid spinning.
|
||||||
|
// .test() was only added in C++20, so we use .test_and_set() instead.
|
||||||
|
const bool all_pinned = want_pin_ && !any_error_.test_and_set();
|
||||||
|
*pin_string = all_pinned ? "pinned"
|
||||||
|
: want_pin_ ? "pinning failed"
|
||||||
|
: "pinning skipped";
|
||||||
|
return all_pinned;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
|
||||||
|
bool want_pin_; // set in SetPolicy
|
||||||
|
LPS original_affinity_;
|
||||||
|
}; // Pinning
|
||||||
|
|
||||||
|
// Singleton saves global affinity across all BoundedTopology instances because
|
||||||
|
// pinning overwrites it.
|
||||||
|
static Pinning& GetPinning() {
|
||||||
|
static Pinning pinning;
|
||||||
|
return pinning;
|
||||||
|
}
+
 BoundedTopology::BoundedTopology(BoundedSlice package_slice,
                                  BoundedSlice cluster_slice,
                                  BoundedSlice lp_slice) {
-  // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
-  LPS enabled_lps;
-  if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
-    const size_t num_lps = hwy::TotalLogicalProcessors();
-    fprintf(stderr,
-            "Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
-            num_lps);
-    for (size_t lp = 0; lp < num_lps; ++lp) {
-      enabled_lps.Set(lp);
-    }
-  }
-
-  // Without threading support, only keep the first enabled LP; it might still
-  // make sense to pin the main thread to avoid migrations.
-  if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
-    HWY_ASSERT(enabled_lps.Any());
-    const size_t lp = enabled_lps.First();
-    enabled_lps = LPS();
-    enabled_lps.Set(lp);
-    fprintf(stderr,
-            "Warning, threads not supported, using only the main thread\n.");
-  }
+  const LPS enabled_lps = GetPinning().EnabledLPs();
 
 #if !GEMMA_DISABLE_TOPOLOGY
   if (HWY_LIKELY(!topology_.packages.empty())) {
@@ -110,19 +185,33 @@ BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
       AddLP(lp);
 
-      // Set `node` once, and ensure subsequent nodes match - we assume there
-      // is only one NUMA node per cluster.
+      // Set fields once, and ensure subsequent LPs match - we assume there
+      // is only one NUMA node per cluster, with the same L2/L3 size.
       const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
       if (is_first_lp) {
         is_first_lp = false;
         node_ = lp_node;
+        private_kib_ = tcluster.private_kib;
+        shared_kib_ = tcluster.shared_kib;
       } else {
         static bool warned = false;
-        if (lp_node != node_ && !warned) {
-          warned = true;
-          fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n",
-                  lp, lp_node, node_);
-        }
+        if (HWY_LIKELY(!warned)) {
+          if (HWY_UNLIKELY(lp_node != node_)) {
+            warned = true;
+            fprintf(stderr,
+                    "WARNING: lp %zu on node %zu != cluster node %zu.\n", lp,
+                    lp_node, node_);
+          }
+          if (HWY_UNLIKELY(private_kib_ != tcluster.private_kib)) {
+            warned = true;
+            fprintf(stderr, "WARNING: lp %zu private_kib %zu != cluster %zu.\n",
+                    lp, private_kib_, tcluster.private_kib);
+          }
+          if (HWY_UNLIKELY(shared_kib_ != tcluster.shared_kib)) {
+            warned = true;
+            fprintf(stderr, "WARNING: lp %zu shared_kib %zu != cluster %zu.\n",
+                    lp, shared_kib_, tcluster.shared_kib);
+          }
+        }  // !warned
       }
     });
 }
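The new private_kib_/shared_kib_ fields are groundwork for the L2/L3-aware MatMul mentioned in the commit message. A hypothetical sizing helper (not part of this commit) showing how a caller might turn the cluster's PrivateKiB() into a row-block size; the fallback value and the bf16 element type are assumptions:

#include <stddef.h>
#include <stdint.h>

// Rows of `cols` bf16 elements that fit in roughly half of the per-core (L2)
// cache, given private_kib from Cluster::PrivateKiB() (0 means unknown).
size_t RowsPerL2(size_t private_kib, size_t cols) {
  if (cols == 0) return 1;
  if (private_kib == 0) private_kib = 256;  // assumed fallback, not from source
  const size_t bytes_per_row = cols * sizeof(uint16_t);  // bf16
  const size_t rows = (private_kib * 1024 / 2) / bytes_per_row;
  return rows == 0 ? 1 : rows;
}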
@@ -141,6 +230,7 @@ BoundedTopology::Package::Package(const LPS& enabled_lps,
       "cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
         const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx];
         Cluster cluster(enabled_lps, topology.lps, tcluster);
 
         // Skip if empty, i.e. too few `enabled_lps`.
         if (HWY_LIKELY(cluster.Size() != 0)) {
           clusters.push_back(std::move(cluster));
@@ -267,56 +357,6 @@ static PoolPtr MakePool(size_t num_workers) {
   return std::make_unique<hwy::ThreadPool>(num_threads);
 }
 
-static bool InContainer() {
-  return false;}
-
-class NestedPools::Pinning {
- public:
-  Pinning(Tristate pin, const BoundedTopology& topology) {
-    if (pin == Tristate::kDefault) {
-      // Pinning is unreliable inside containers because the hypervisor might
-      // periodically change our affinity mask, or other processes might also
-      // pin themselves to the same LPs.
-      pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
-    }
-    want_pin_ = (pin == Tristate::kTrue);
-  }
-
-  // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
-  // and sets `any_error_` if any fails.
-  void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) {
-    if (HWY_UNLIKELY(!want_pin_)) return;
-
-    const std::vector<size_t> lps = cluster.LPVector();
-    HWY_ASSERT(pool->NumWorkers() <= lps.size());
-    pool->Run(
-        0, pool->NumWorkers(),
-        [this, &pool, &lps](uint64_t task, size_t thread) {
-          HWY_ASSERT(task == thread);  // each worker has one task
-          if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
-            fprintf(stderr,
-                    "Pinning failed for task %zu of %zu to %zu (size %zu)\n",
-                    task, pool->NumWorkers(), lps[task], lps.size());
-            (void)any_error_.test_and_set();
-          }
-        });
-  }
-
-  bool WantPin() const { return want_pin_; }
-
-  // Called ONCE after all MaybePin because it invalidates the error status.
-  bool AllPinned() {
-    // If !want_pin_, MaybePin will return without setting any_error_, but in
-    // that case we still want to return false to avoid spinning.
-    // .test() was only added in C++20, so we use .test_and_set() instead.
-    return want_pin_ && !any_error_.test_and_set();
-  }
-
- private:
-  std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
-  bool want_pin_;  // set in ctor
-};  // Pinning
-
 // Used to divide max_threads and max_workers_per_package across packages and
 // clusters. Ensures small upper bounds are respected.
 static size_t DivideMaxAcross(const size_t max, const size_t instances) {
@@ -333,7 +373,7 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
                          BoundedSlice package_slice, BoundedSlice cluster_slice,
                          BoundedSlice lp_slice)
     : topology_(package_slice, cluster_slice, lp_slice) {
-  Pinning pinning(pin, topology_);
+  GetPinning().SetPolicy(pin);
   packages_.resize(topology_.NumPackages());
   all_packages_ = MakePool(packages_.size());
   const size_t max_workers_per_package =
@@ -344,14 +384,11 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
   all_packages_->Run(
       0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) {
         HWY_ASSERT(package_idx == thread);  // each thread has one task
-        packages_[package_idx] = Package(
-            topology_, package_idx, max_workers_per_package, pinning, lp_slice);
+        packages_[package_idx] =
+            Package(topology_, package_idx, max_workers_per_package, lp_slice);
       });
 
-  all_pinned_ = pinning.AllPinned();
-  pin_string_ = all_pinned_ ? "pinned"
-                : pinning.WantPin() ? "pinning failed"
-                                    : "pinning skipped";
+  all_pinned_ = GetPinning().AllPinned(&pin_string_);
 
   // For mapping package/cluster/thread to noncontiguous TLS indices, in case
   // cluster/thread counts differ.
@@ -368,14 +405,9 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
   HWY_ASSERT(max_workers_per_cluster_ <= 256);
 }
 
-// `max_or_zero` == 0 means no limit.
-static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
-  return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
-}
-
 NestedPools::Package::Package(const BoundedTopology& topology,
                               size_t package_idx,
-                              size_t max_workers_per_package, Pinning& pinning,
+                              size_t max_workers_per_package,
                               BoundedSlice lp_slice) {
   // Pre-allocate because elements are set concurrently.
   clusters_.resize(topology.NumClusters(package_idx));
@@ -393,7 +425,7 @@ NestedPools::Package::Package(const BoundedTopology& topology,
         clusters_[cluster_idx] =
             MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
         // Pin workers AND the calling thread from `all_clusters`.
-        pinning.MaybePin(cluster, clusters_[cluster_idx]);
+        GetPinning().MaybePin(cluster, clusters_[cluster_idx]);
       });
 }
169  util/threading.h
@@ -17,14 +17,17 @@
 #define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
 
 #include <stddef.h>
+#include <stdint.h>
 
 #include <memory>  // std::unique_ptr
 #include <vector>
 
+// IWYU pragma: begin_exports
 #include "util/basics.h"  // Tristate
 #include "hwy/base.h"  // HWY_ASSERT
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/contrib/thread_pool/topology.h"
+// IWYU pragma: end_exports
 
 #ifndef GEMMA_DISABLE_TOPOLOGY
 #define GEMMA_DISABLE_TOPOLOGY 0
@@ -32,6 +35,15 @@
 
 namespace gcpp {
 
+static inline size_t SaturatingSub(size_t a, size_t b) {
+  return a - HWY_MIN(a, b);
+}
+
+// `max_or_zero` == 0 means no limit.
+static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
+  return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
+}
+
 // A slice of a 1D integer range such as the indices of packages or clusters.
 // This allows assigning them to multiple instances of our binary.
 class BoundedSlice {
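A quick illustration of the two helpers' edge cases, written against the definitions above:

#include "util/threading.h"

void HelperDemo() {
  HWY_ASSERT(gcpp::SaturatingSub(5, 3) == 2);
  HWY_ASSERT(gcpp::SaturatingSub(3, 5) == 0);   // clamps at zero, no wraparound
  HWY_ASSERT(gcpp::CapIfNonZero(10, 4) == 4);   // capped
  HWY_ASSERT(gcpp::CapIfNonZero(10, 0) == 10);  // 0 means "no limit"
}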
@@ -86,6 +98,7 @@ using PoolPtr = std::unique_ptr<hwy::ThreadPool>;
 // back to a single package and cluster.
 class BoundedTopology {
  public:
+  // Thread-hostile, typically called from main thread.
   BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice,
                   BoundedSlice lp_slice);
 
@@ -112,6 +125,8 @@ class BoundedTopology {
     }
 
     size_t Node() const { return node_; }
+    size_t PrivateKiB() const { return private_kib_; }
+    size_t SharedKiB() const { return shared_kib_; }
 
    private:
     void AddLP(size_t lp) {
@@ -126,6 +141,10 @@ class BoundedTopology {
     size_t num_workers_ = 0;
     // NUMA node, set from hwy::Topology::LP::node.
     size_t node_ = 0;
+    // L2 cache size in KiB, or 0 if unknown.
+    size_t private_kib_ = 0;
+    // L3 cache size in KiB, or 0 if unknown.
+    size_t shared_kib_ = 0;
   };  // Cluster
 
   size_t NumClusters(size_t package_idx) const {
@@ -145,6 +164,10 @@ class BoundedTopology {
     return package.clusters[cluster_idx];
   }
 
+#if !GEMMA_DISABLE_TOPOLOGY
+  const hwy::Topology& FullTopology() const { return topology_; }
+#endif
+
  private:
   struct Package {
     // Topology is unknown, rely on OS affinity and user-specified slice.
@@ -257,6 +280,8 @@ class NestedPools {
   // Returns the first of `cluster.NumWorkers()` TLS indices, to which callers
   // add the worker index given by `cluster.Run`.
   size_t WorkerOffset(size_t package_idx, size_t cluster_idx) const {
+    HWY_DASSERT(package_idx < packages_.size());
+    HWY_DASSERT(cluster_idx < packages_[package_idx].NumClusters());
     return (package_idx * max_clusters_per_package_ + cluster_idx) *
            max_workers_per_cluster_;
   }
@@ -267,26 +292,25 @@ class NestedPools {
   const char* TopologyString() const { return topology_.TopologyString(); }
   const char* PinString() const { return pin_string_; }
 
-  // Returns a single pool on the first package: either one thread per cluster
+  // Returns a single pool on the given package: either one thread per cluster
   // if there is more than one, which maximizes available memory bandwidth, or
   // the first cluster, which is typically the whole package. For use by callers
-  // that only parallelize over a 1D range, as opposed to the nested
-  // parallelism of `StaticPartitionRowsAndCols`.
-  hwy::ThreadPool& Pool() {
+  // that only have a single parallel-for.
+  hwy::ThreadPool& Pool(size_t package_idx = 0) {
     // Only one cluster: use its pool, typically a whole socket.
-    if (AllClusters(0).NumWorkers() == 1) return Cluster(0, 0);
-    return AllClusters(0);
+    if (AllClusters(package_idx).NumWorkers() == 1) {
+      return Cluster(package_idx, 0);
+    }
+    // One worker per cluster to maximize bandwidth availability.
+    return AllClusters(package_idx);
   }
 
  private:
-  class Pinning;
-
   class Package {
    public:
     Package() = default;  // for vector
     Package(const BoundedTopology& topology, size_t package_idx,
-            size_t max_workers_per_package, Pinning& pinning,
-            BoundedSlice lp_slice);
+            size_t max_workers_per_package, BoundedSlice lp_slice);
 
     size_t NumClusters() const { return clusters_.size(); }
     size_t MaxWorkersPerCluster() const {
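A minimal sketch of how a caller with a single flat parallel-for would use the widened accessor; the task count and the work inside the lambda are placeholders.

#include "util/threading.h"

void FlatParallelFor(gcpp::NestedPools& pools) {
  hwy::ThreadPool& pool = pools.Pool();  // package 0 by default
  pool.Run(0, 100, [](uint64_t task, size_t thread) {
    (void)task;    // placeholder per-task work
    (void)thread;
  });
}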
@@ -330,11 +354,134 @@ class NestedPools {
   std::vector<Package> packages_;
   PoolPtr all_packages_;
 
-  // For TLS indices.
+  // For TLS indices. One might think this belongs in BoundedTopology, but it
+  // depends on max_threads, which is passed to the NestedPools constructor.
   size_t max_clusters_per_package_ = 0;
   size_t max_workers_per_cluster_ = 0;
 };
 
+// Splits `range` into subranges of size `task_size`, except for the last,
+// which receives the remainder. Used with the `ParallelizeOneRange` etc.
+// functions below.
+class IndexRangePartition {
+ public:
+  IndexRangePartition(const IndexRange& range, const size_t task_size)
+      : range_(range), task_size_(task_size) {
+    const size_t num = range.Num();
+    HWY_DASSERT(task_size_ != 0);
+    num_tasks_ = hwy::DivCeil(num, task_size_);
+    HWY_DASSERT(num_tasks_ != 0);
+    if constexpr (HWY_IS_DEBUG_BUILD) {
+      const size_t handled = num_tasks_ * task_size_;
+      // The last task may extend beyond items, but at most by (task_size_ - 1).
+      HWY_DASSERT(num <= handled && handled < num + task_size_);
+    }
+  }
+
+  size_t TaskSize() const { return task_size_; }
+  size_t NumTasks() const { return num_tasks_; }
+
+  IndexRange Range(size_t task_idx) const {
+    HWY_DASSERT(task_idx < NumTasks());
+    return MakeIndexRange(range_.begin() + task_idx * task_size_, range_.end(),
+                          task_size_);
+  }
+
+ private:
+  IndexRange range_;
+  size_t task_size_;
+  size_t num_tasks_;
+};
+
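A small usage sketch of the class above, assuming `IndexRange` comes from util/basics.h as included by util/threading.h; the last task absorbs the remainder.

#include "util/threading.h"

void PartitionDemo() {
  const gcpp::IndexRange range(0, 10);
  const gcpp::IndexRangePartition partition(range, /*task_size=*/4);
  HWY_ASSERT(partition.NumTasks() == 3);       // [0,4) [4,8) [8,10)
  HWY_ASSERT(partition.Range(0).Num() == 4);
  HWY_ASSERT(partition.Range(2).begin() == 8);
  HWY_ASSERT(partition.Range(2).end() == 10);  // remainder: only 2 items
}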
|
// Starts with `max_size` and rounds DOWN to a multiple of `size_multiple`
|
||||||
|
// unless that would be zero. It is the caller's responsibility to choose
|
||||||
|
// `size_multiple` to avoid two heavily imbalanced tasks.
|
||||||
|
// Use when the number of tasks does not matter, but each must fit into caches.
|
||||||
|
static inline IndexRangePartition MaxSizePartition(const IndexRange& range,
|
||||||
|
const size_t max_size,
|
||||||
|
const size_t size_multiple) {
|
||||||
|
HWY_DASSERT(size_multiple != 0);
|
||||||
|
size_t size = HWY_MIN(range.Num(), max_size);
|
||||||
|
if (size > size_multiple) size = hwy::RoundDownTo(size, size_multiple);
|
||||||
|
return IndexRangePartition(range, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Up to `max_tasks` tasks, each rounded UP to `size_multiple`, unless that
|
||||||
|
// would be more than the range. Use when the number of tasks is known, e.g.
|
||||||
|
// one per ThreadPool worker.
|
||||||
|
static inline IndexRangePartition StaticPartition(const IndexRange& range,
|
||||||
|
const size_t max_tasks,
|
||||||
|
const size_t size_multiple) {
|
||||||
|
HWY_DASSERT(max_tasks != 0);
|
||||||
|
size_t size =
|
||||||
|
hwy::RoundUpTo(hwy::DivCeil(range.Num(), max_tasks), size_multiple);
|
||||||
|
size = HWY_MIN(size, range.Num());
|
||||||
|
return IndexRangePartition(range, size);
|
||||||
|
}
|
||||||
|
|
||||||
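A sketch contrasting the two constructors under assumed sizes: `MaxSizePartition` bounds how large each task may be (cache-driven), while `StaticPartition` bounds how many tasks there are (worker-driven).

#include "util/threading.h"

void PartitionChoice() {
  const gcpp::IndexRange rows(0, 1000);
  // At most 256 rows per task, in multiples of 64: four tasks of 256/256/256/232.
  const gcpp::IndexRangePartition by_size =
      gcpp::MaxSizePartition(rows, /*max_size=*/256, /*size_multiple=*/64);
  HWY_ASSERT(by_size.TaskSize() == 256 && by_size.NumTasks() == 4);
  // At most 8 tasks (e.g. one per worker), in multiples of 64: eight tasks of 128.
  const gcpp::IndexRangePartition by_count =
      gcpp::StaticPartition(rows, /*max_tasks=*/8, /*size_multiple=*/64);
  HWY_ASSERT(by_count.TaskSize() == 128 && by_count.NumTasks() == 8);
}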
|
// Parallel-for over a single range. This takes care of translating the task
|
||||||
|
// index to a range.
|
||||||
|
template <class Func>
|
||||||
|
void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
|
||||||
|
const Func& func) {
|
||||||
|
const size_t num_tasks = get1.NumTasks();
|
||||||
|
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
|
||||||
|
const IndexRange range1 = get1.Range(task);
|
||||||
|
func(range1, thread);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parallel-for over the Cartesian product of the two sets of ranges. This
|
||||||
|
// combines their indices into a single 'task' so they can be executed by one
|
||||||
|
// `pool.Run`, which increases the amount of work available to workers and
|
||||||
|
// reduces fork-join overhead vs. nested parallel-for loops. Calls `func` with
|
||||||
|
// the two ranges and the thread index within `pool`.
|
||||||
|
template <class Func>
|
||||||
|
void ParallelizeTwoRanges(const IndexRangePartition& get1,
|
||||||
|
const IndexRangePartition& get2,
|
||||||
|
hwy::ThreadPool& pool, const Func& func) {
|
||||||
|
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
|
||||||
|
|
||||||
|
const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
|
||||||
|
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
|
||||||
|
HWY_DASSERT(task < (uint64_t{1} << 32));
|
||||||
|
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
|
||||||
|
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
|
||||||
|
HWY_DASSERT(idx1 < get1.NumTasks());
|
||||||
|
HWY_DASSERT(idx2 < get2.NumTasks());
|
||||||
|
const IndexRange range1 = get1.Range(idx1);
|
||||||
|
const IndexRange range2 = get2.Range(idx2);
|
||||||
|
func(range1, range2, thread);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
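A usage sketch for the function above: tiling a rows x cols loop nest through a single flat `pool.Run`. The tile sizes are illustrative and the caller is assumed to own the `hwy::ThreadPool`.

#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/threading.h"

void ForEachTile(hwy::ThreadPool& pool) {
  const gcpp::IndexRangePartition row_tiles =
      gcpp::MaxSizePartition(gcpp::IndexRange(0, 512), 128, 64);
  const gcpp::IndexRangePartition col_tiles =
      gcpp::MaxSizePartition(gcpp::IndexRange(0, 2048), 256, 256);
  gcpp::ParallelizeTwoRanges(
      row_tiles, col_tiles, pool,
      [](const gcpp::IndexRange& rows, const gcpp::IndexRange& cols,
         size_t /*thread*/) {
        // One task per (row tile, col tile) pair; process
        // [rows.begin(), rows.end()) x [cols.begin(), cols.end()) here.
        (void)rows;
        (void)cols;
      });
}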
|
// As above, for three ranges.
|
||||||
|
template <class Func>
|
||||||
|
void ParallelizeThreeRanges(const IndexRangePartition& get1,
|
||||||
|
const IndexRangePartition& get2,
|
||||||
|
const IndexRangePartition& get3,
|
||||||
|
hwy::ThreadPool& pool, const Func& func) {
|
||||||
|
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
|
||||||
|
const size_t num12 = get1.NumTasks() * get2.NumTasks();
|
||||||
|
const hwy::Divisor div12(static_cast<uint32_t>(num12));
|
||||||
|
|
||||||
|
const size_t num_tasks = num12 * get3.NumTasks();
|
||||||
|
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
|
||||||
|
HWY_DASSERT(task < (uint64_t{1} << 32));
|
||||||
|
const size_t idx3 = div12.Divide(static_cast<uint32_t>(task));
|
||||||
|
const size_t task12 = div12.Remainder(static_cast<uint32_t>(task));
|
||||||
|
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task12));
|
||||||
|
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task12));
|
||||||
|
HWY_DASSERT(idx1 < get1.NumTasks());
|
||||||
|
HWY_DASSERT(idx2 < get2.NumTasks());
|
||||||
|
HWY_DASSERT(idx3 < get3.NumTasks());
|
||||||
|
const IndexRange range1 = get1.Range(idx1);
|
||||||
|
const IndexRange range2 = get2.Range(idx2);
|
||||||
|
const IndexRange range3 = get3.Range(idx3);
|
||||||
|
func(range1, range2, range3, thread);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
|
||||||
|
|
|
||||||
|
|
@@ -23,6 +23,7 @@
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include "hwy/base.h"  // HWY_ASSERT
+#include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 namespace {
@@ -111,5 +112,229 @@ TEST(ThreadingTest, TestBoundedTopology) {
   }
 }
 
+TEST(ThreadingTest, TestMaxSizePartition) {
+  const IndexRange range(0, 100);
+  // Round down
+  {
+    const IndexRangePartition partition = MaxSizePartition(range, 55, 32);
+    HWY_ASSERT(partition.TaskSize() == 32);
+    HWY_ASSERT(partition.NumTasks() == 4);
+  }
+  // Huge `max_size`: single task
+  {
+    const IndexRangePartition partition = MaxSizePartition(range, 9999, 1);
+    HWY_ASSERT(partition.TaskSize() == 100);
+    HWY_ASSERT(partition.NumTasks() == 1);
+  }
+  // Huge `max_size`: `size_multiple` is still respected
+  {
+    const IndexRangePartition partition = MaxSizePartition(range, 9999, 64);
+    HWY_ASSERT(partition.TaskSize() == 64);
+    HWY_ASSERT(partition.NumTasks() == 2);
+  }
+  // `size_multiple` larger than range: ignore multiple
+  {
+    const IndexRangePartition partition = MaxSizePartition(range, 55, 128);
+    HWY_ASSERT(partition.TaskSize() == 55);
+    HWY_ASSERT(partition.NumTasks() == 2);
+  }
+  // Small `max_size`: small tasks
+  {
+    const IndexRangePartition partition = MaxSizePartition(range, 2, 1);
+    HWY_ASSERT(partition.TaskSize() == 2);
+    HWY_ASSERT(partition.NumTasks() == 50);
+  }
+  // Large `max_size`: two tasks with lots of overhang
+  {
+    const IndexRangePartition partition = MaxSizePartition(range, 98, 1);
+    HWY_ASSERT(partition.TaskSize() == 98);
+    HWY_ASSERT(partition.NumTasks() == 2);
+  }
+  // `size_multiple` almost as large as a different, smaller range: imbalanced
+  {
+    const IndexRangePartition partition =
+        MaxSizePartition(IndexRange(0, 6), 6, 4);
+    HWY_ASSERT(partition.TaskSize() == 4);
+    HWY_ASSERT(partition.NumTasks() == 2);
+  }
+}
+
+TEST(ThreadingTest, TestStaticPartition) {
+  const IndexRange range(0, 100);
+  // Round up
+  {
+    const IndexRangePartition partition = StaticPartition(range, 2, 64);
+    HWY_ASSERT(partition.TaskSize() == 64);
+    HWY_ASSERT(partition.NumTasks() == 2);
+  }
+  // No `size_multiple`: division still rounds up
+  {
+    const IndexRangePartition partition = StaticPartition(range, 3, 1);
+    HWY_ASSERT(partition.TaskSize() == 34);
+    HWY_ASSERT(partition.NumTasks() == 3);
+  }
+  // Huge `max_tasks`: one each
+  {
+    const IndexRangePartition partition = StaticPartition(range, 9999, 1);
+    HWY_ASSERT(partition.TaskSize() == 1);
+    HWY_ASSERT(partition.NumTasks() == 100);
+  }
+  // `size_multiple` larger than range: single task
+  {
+    const IndexRangePartition partition = StaticPartition(range, 2, 128);
+    HWY_ASSERT(partition.TaskSize() == 100);
+    HWY_ASSERT(partition.NumTasks() == 1);
+  }
+  // `max_tasks` = 1: single task, even if rounding up would exceed the range
+  {
+    const IndexRangePartition partition = StaticPartition(range, 1, 8);
+    HWY_ASSERT(partition.TaskSize() == 100);
+    HWY_ASSERT(partition.NumTasks() == 1);
+  }
+}
+
+TEST(ThreadingTest, TestParallelizeOneRange) {
+  const IndexRange range(0, 10);
+  const IndexRangePartition partition = StaticPartition(range, 2, 4);
+  hwy::ThreadPool null_pool(0);
+  size_t calls = 0;
+  ParallelizeOneRange(partition, null_pool,
+                      [&](const IndexRange& range, size_t) {
+                        if (++calls == 1) {
+                          HWY_ASSERT(range.begin() == 0 && range.end() == 8);
+                        } else {
+                          HWY_ASSERT(range.begin() == 8 && range.end() == 10);
+                        }
+                      });
+  HWY_ASSERT(calls == 2);
+}
+
+TEST(ThreadingTest, TestParallelizeTwoRanges) {
+  const IndexRangePartition partition1 =
+      StaticPartition(IndexRange(0, 10), 2, 4);
+  const IndexRangePartition partition2 =
+      MaxSizePartition(IndexRange(128, 256), 32, 32);
+  HWY_ASSERT(partition2.NumTasks() == 4);
+  hwy::ThreadPool null_pool(0);
+  {
+    size_t calls = 0;
+    ParallelizeTwoRanges(
+        partition1, partition2, null_pool,
+        [&](const IndexRange& range1, const IndexRange& range2, size_t) {
+          ++calls;
+          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
+          HWY_ASSERT(range2.begin() % 32 == 0);
+          HWY_ASSERT(range2.Num() % 32 == 0);
+        });
+    HWY_ASSERT(calls == 2 * 4);
+  }
+
+  // Also swap order to test Remainder() logic.
+  {
+    size_t calls = 0;
+    ParallelizeTwoRanges(
+        partition2, partition1, null_pool,
+        [&](const IndexRange& range2, const IndexRange& range1, size_t) {
+          ++calls;
+          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
+          HWY_ASSERT(range2.begin() % 32 == 0);
+          HWY_ASSERT(range2.Num() % 32 == 0);
+        });
+    HWY_ASSERT(calls == 2 * 4);
+  }
+}
+
+TEST(ThreadingTest, TestParallelizeThreeRanges) {
+  // Named according to number of tasks.
+  const IndexRangePartition partition3 =
+      StaticPartition(IndexRange(0, 8), 3, 1);  // [0, 3) [3, 6) [6, 8)
+  HWY_ASSERT(partition3.NumTasks() == 3);
+  const IndexRangePartition partition2 =
+      MaxSizePartition(IndexRange(10, 30), 10, 10);  // [10, 20), [20, 30)
+  HWY_ASSERT(partition2.NumTasks() == 2);
+  const IndexRangePartition partition4 =
+      MaxSizePartition(IndexRange(100, 500), 100, 100);  // 100, 200, 300, 400
+  HWY_ASSERT(partition4.NumTasks() == 4);
+
+  const auto check_ranges = [&](const IndexRange& range3,
+                                const IndexRange& range2,
+                                const IndexRange& range4) {
+    HWY_ASSERT(range3.begin() == 0 || range3.begin() == 3 ||
+               range3.begin() == 6);
+    HWY_ASSERT(range2.begin() == 10 || range2.begin() == 20);
+    HWY_ASSERT(range4.begin() % 100 == 0);
+  };
+
+  hwy::ThreadPool null_pool(0);
+  // All 6 permutations of the three ranges to test the Remainder() logic:
+  // 3, 2, 4
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition3, partition2, partition4, null_pool,
+        [&](IndexRange range3, IndexRange range2, IndexRange range4, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+  // 3, 4, 2
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition3, partition4, partition2, null_pool,
+        [&](IndexRange range3, IndexRange range4, IndexRange range2, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+
+  // 4, 2, 3
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition4, partition2, partition3, null_pool,
+        [&](IndexRange range4, IndexRange range2, IndexRange range3, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+  // 4, 3, 2
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition4, partition3, partition2, null_pool,
+        [&](IndexRange range4, IndexRange range3, IndexRange range2, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+
+  // 2, 3, 4
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition2, partition3, partition4, null_pool,
+        [&](IndexRange range2, IndexRange range3, IndexRange range4, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+  // 2, 4, 3
+  {
+    size_t calls = 0;
+    ParallelizeThreeRanges(
+        partition2, partition4, partition3, null_pool,
+        [&](IndexRange range2, IndexRange range4, IndexRange range3, size_t) {
+          ++calls;
+          check_ranges(range3, range2, range4);
+        });
+    HWY_ASSERT(calls == 3 * 2 * 4);
+  }
+}
 }  // namespace
 }  // namespace gcpp