Infra improvements:

allocator: support mmap, fix Bind, add padding
bench_matmul: add PreventElision
BUILD: add ops_test build target
matmul.h: move ConstMat here; dynamic alloc of MatMulEnv
matmul_test: remove benchmarking
replace fprintf with HWY_WARN
threading.cc: support splitting large clusters (disabled); package_idx->pkg_idx, smaller IndexRangePartition
PiperOrigin-RevId: 717512274
Jan Wassenberg 2025-01-20 06:22:17 -08:00 committed by Copybara-Service
parent 493688f6f1
commit c4398fc72d
14 changed files with 729 additions and 638 deletions

View File

@ -76,6 +76,12 @@ cc_test(
],
)
# For building all tests in one command, so we can test several.
test_suite(
name = "ops_tests",
tags = ["ops_tests"],
)
cc_library(
name = "ops",
hdrs = [
@ -110,7 +116,7 @@ cc_test(
srcs = ["ops/dot_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":ops",
@ -135,7 +141,7 @@ cc_test(
srcs = ["ops/ops_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":common",
@ -157,7 +163,7 @@ cc_test(
srcs = ["ops/gemma_matvec_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":ops",
"@googletest//:gtest_main", # buildcleaner: keep
@ -175,7 +181,7 @@ cc_test(
srcs = ["ops/matmul_unit_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
@ -195,7 +201,7 @@ cc_test(
srcs = ["ops/matmul_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
@ -205,7 +211,6 @@ cc_test(
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
],
)
@ -217,7 +222,7 @@ cc_test(
srcs = ["ops/bench_matmul.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
@ -228,6 +233,7 @@ cc_test(
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
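With the `ops_tests` test_suite added above and the cc_test targets retagged to match, all ops tests can presumably be run in one command, e.g. `bazel test :ops_tests` (assuming Bazel's usual behavior that a `test_suite` with only a `tags` attribute collects every test in the package carrying those tags).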

View File

@ -309,13 +309,6 @@ decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
}
}
template <typename T>
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
mat.scale = m.scale();
return mat;
}
// MatStorageT adds the actual data storage to MatPtrT.
// TODO: use Extents2D instead of rows and cols.
template <typename MatT>
@ -361,7 +354,7 @@ class MatStorageT : public MatPtrT<MatT> {
}
private:
hwy::AlignedFreeUniquePtr<MatT[]> data_;
AlignedPtr<MatT> data_;
};
// MatStorage allows heterogeneous tensors to be stored in a single vector.

View File

@ -273,11 +273,11 @@ struct PackedSpan {
// Ensures callers can read or write `num_accessible` elements starting at
// `packed_ofs`.
void BoundsCheck(size_t packed_ofs, size_t num_accessible) const {
// For NUQ, there can be fewer Packed than the number of elements, hence
// check the compressed count and ensure we have that many.
const size_t required =
CompressedArrayElements<Packed>(packed_ofs + num_accessible);
if constexpr (HWY_IS_DEBUG_BUILD) {
// For NUQ, there can be fewer Packed than the number of elements, hence
// check the compressed count and ensure we have that many.
const size_t required =
CompressedArrayElements<Packed>(packed_ofs + num_accessible);
if (num < required) {
HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed",
packed_ofs, num_accessible, required, num);

View File

@ -19,6 +19,7 @@
#include <stddef.h>
#include <cmath>
#include <memory> // std::unique_ptr
#include "compression/shared.h" // BF16
#include "gemma/configs.h"
@ -63,7 +64,8 @@ struct Activations {
// Rope
RowVectorBatch<float> inv_timescale;
MatMulEnv env;
// Dynamic because no default ctor and only initialized in `Allocate`.
std::unique_ptr<MatMulEnv> env;
PostQKType post_qk = PostQKType::Rope;
// And the config.
@ -122,7 +124,7 @@ struct Activations {
inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
env = MatMulEnv(pools);
env = std::make_unique<MatMulEnv>(pools);
}
};

View File

@ -81,7 +81,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Griffin");
KVCache& kv_cache = kv_caches[0];
hwy::ThreadPool& pool = activations.env.Pool();
hwy::ThreadPool& pool = activations.env->Pool();
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const size_t model_dim = layer_weights->layer_config.model_dim;
@ -252,7 +252,7 @@ class GemmaAttention {
const size_t w1_rows = heads * layer_config_.QStride();
w_q1.ShrinkRows(w1_rows);
MatMul(pre_att_rms_out, w_q1,
/*add=*/nullptr, activations_.env, RowPtrFromBatch(activations_.q));
/*add=*/nullptr, *activations_.env, RowPtrFromBatch(activations_.q));
if (is_mha_) {
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
@ -275,7 +275,7 @@ class GemmaAttention {
RowPtrF kv_rows(kv, w_rows_kv_cols);
kv_rows.SetStride(cache_pos_size_);
MatMul(pre_att_rms_out, w_q2,
/*add=*/nullptr, activations_.env, kv_rows);
/*add=*/nullptr, *activations_.env, kv_rows);
} else {
// Proceed row by row because there will be wraparound.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
@ -464,7 +464,7 @@ class GemmaAttention {
: nullptr;
MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
ConstMatFromWeights(layer_weights_.att_weights), add,
activations_.env, RowPtrFromBatch(activations_.att_sums));
*activations_.env, RowPtrFromBatch(activations_.att_sums));
}
public:
@ -514,7 +514,7 @@ class GemmaAttention {
layer_weights_(*layer_weights),
div_seq_len_(div_seq_len),
kv_caches_(kv_caches),
pool_(activations.env.Pool()) {
pool_(activations.env->Pool()) {
HWY_DASSERT(num_queries_ <= kv_caches_.size());
HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
"query heads must be a multiple of key-value heads");
@ -587,7 +587,7 @@ class VitAttention {
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
layer_weights_.vit.qkv_einsum_b.data_scale1(), activations_.env,
layer_weights_.vit.qkv_einsum_b.data_scale1(), *activations_.env,
RowPtrFromBatch(qkv));
}
@ -641,7 +641,7 @@ class VitAttention {
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
auto att_sums = RowPtrFromBatch(activations_.att_sums);
MatMul(att_out, att_weights, bias, activations_.env, att_sums);
MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
}
public:
@ -652,7 +652,7 @@ class VitAttention {
activations_(activations),
layer_weights_(*layer_weights),
layer_config_(layer_weights->layer_config),
pool_(activations.env.Pool()) {}
pool_(activations.env->Pool()) {}
HWY_INLINE void operator()() {
ComputeQKV();
@ -728,8 +728,8 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
auto w_output = ConstMatFromWeights(layer_weights->linear_w);
// Compute the hidden layer activations.
MatMul(x, w1, bias1, activations.env, hidden_activations);
MatMul(x, w2, bias2, activations.env, multiplier);
MatMul(x, w1, bias1, *activations.env, hidden_activations);
MatMul(x, w2, bias2, *activations.env, multiplier);
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
@ -739,7 +739,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
auto activations_mat = MakeConstMat(
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim));
MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
}
// Same as FFWNoVit, but with different layer_weights members and no second
@ -769,7 +769,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
// Compute the hidden layer activations.
MatMul(x, w1, bias1, activations.env, hidden_activations);
MatMul(x, w1, bias1, *activations.env, hidden_activations);
// Activation (Gelu), store in act.
RowPtrF multiplier = RowPtrF(nullptr, 0);
@ -780,7 +780,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
auto activations_mat = MakeConstMat(
hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim));
MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
}
// `batch_idx` indicates which row of `x` to write to.
@ -1063,7 +1063,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
// MatMul(
// MatFromBatch(kVitSeqLen, image_patches),
// MatFromWeights(weights.vit_img_embedding_kernel),
// weights.vit_img_embedding_bias.data_scale1(), activations.env,
// weights.vit_img_embedding_bias.data_scale1(), *activations.env,
// RowPtrF(activations.x.All(), kVitModelDim));
// However, MatMul currently requires that
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
@ -1073,7 +1073,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
image_patches[i].get(),
weights.vit_img_embedding_bias.data_scale1(),
activations.x.Batch(i), activations.env.Pool());
activations.x.Batch(i), activations.env->Pool());
}
// Add position embeddings.
AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
@ -1108,7 +1108,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
// Apply head embedding into image_tokens of size of the LLM kModelDim.
MatMul(ConstMatFromBatch(num_tokens, activations.x),
ConstMatFromWeights(weights.vit_img_head_kernel),
weights.vit_img_head_bias.data_scale1(), activations.env,
weights.vit_img_head_bias.data_scale1(), *activations.env,
RowPtrFromBatch(image_tokens));
}
@ -1281,7 +1281,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
Activations prefill_activations(weights.weights_config);
if (use_prefill_activations) {
prefill_activations.Allocate(runtime_config.prefill_tbatch_size,
activations.env.Pools());
activations.env->Pools());
}
Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
query_idx_start, weights,
@ -1326,7 +1326,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
// Compute logits from last layer activations.
MatMul(ConstMatFromBatch(num_queries, activations.x),
ConstMatFromWeights(weights.embedder_input_embedding),
/*add=*/nullptr, activations.env,
/*add=*/nullptr, *activations.env,
RowPtrFromBatch(activations.logits));
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");

View File

@ -16,7 +16,7 @@
// Benchmark of large MatMul instances for which the MatMulSlow would be too
// slow. This lacks a reference and is only useful for performance measurement.
#include "hwy/detect_compiler_arch.h"
#include "hwy/base.h"
#ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
// double-precision support.
@ -30,7 +30,9 @@
#include <stddef.h>
#include <stdio.h>
#include <algorithm>
#include <memory>
#include <vector>
#include "compression/compress.h"
#include "compression/shared.h"
@ -38,8 +40,8 @@
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/nanobenchmark.h"
#include "hwy/timer.h"
// clang-format off
@ -51,6 +53,7 @@
// After highway.h
#include "compression/compress-inl.h"
#include "ops/matmul-inl.h"
#include "hwy/profiler.h" // also uses SIMD
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
@ -74,7 +77,8 @@ MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
HWY_ASSERT(content);
const float scale = SfpStream::kMax / (mat->NumElements());
const float scale =
SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(r * extents.cols + c) * scale;
@ -96,7 +100,8 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
auto mat =
std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
const float scale = SfpStream::kMax / (mat->NumElements());
const float scale =
SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(c * extents.rows + r) * scale;
@ -111,52 +116,63 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
return mat;
}
void PrintSpeed(const char* algo, const Extents2D& A_extents,
const Extents2D& B_extents, double elapsed) {
void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
std::vector<double>& times) {
std::sort(times.begin(), times.end());
// Many measurements are with suboptimal configs, so report the best like
// bench_dnn, but also the ratio to the 3rd best.
const double elapsed = times[0];
const double ratio = times[2] / HWY_MAX(elapsed, 1E-6);
const size_t num_b = B_extents.Area();
// 2x because of FMA.
fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
fprintf(stderr, "%.1f\t%.2f\n", 2 * 1E-9 * A_extents.rows * num_b / elapsed,
ratio);
}
// Generates inputs and prints observed throughput of MatMul.
// M = A rows, K = A cols, N = C cols.
template <typename MatTA, typename MatTB = MatTA>
void BenchMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env) {
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
hwy::ThreadPool& pool = env.Pool();
fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
TypeName<MatTB>());
fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", M,
K, N, add, TypeName<MatTA>(), TypeName<MatTB>());
const Extents2D A_extents(rows_ac, cols_a_rows_b);
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
const Extents2D C_extents(rows_ac, cols_bc);
const Extents2D A_extents(M, K);
const Extents2D B_extents(N, K); // already transposed
const Extents2D C_extents(M, N);
MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
RowVectorBatch<float> c_slow_batch(C_extents);
RowVectorBatch<float> c_batch(C_extents);
HWY_ASSERT(a && b_trans);
std::unique_ptr<MatStorageT<float>> add_storage;
if (add) {
add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
add_storage = GenerateMat<float>(Extents2D(1, N), pool);
HWY_ASSERT(add_storage);
add_storage->set_scale(1.0f);
}
MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
HWY_ASSERT(a && b_trans);
const auto A = ConstMatFromWeights(*a);
const auto B = ConstMatFromWeights(*b_trans);
const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtrF C = RowPtrFromBatch(c_batch);
double min_elapsed = hwy::HighestValue<double>();
for (int rep = 0; rep < 3; ++rep) {
const double start_tiled = hwy::platform::Now();
std::vector<double> times;
times.reserve(20);
double result = 0.0;
for (;;) {
const double t0 = hwy::platform::Now();
MatMul(A, B, add_row, env, C);
min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
times.push_back(hwy::platform::Now() - t0);
result += C.Row(0)[hwy::Unpredictable1()];
if (times.size() >= 20) break;
}
PrintSpeed("MatMul", A_extents, B_extents, min_elapsed);
hwy::PreventElision(result);
PrintSpeed(A_extents, B_extents, times);
}
using F32 = float;
@ -184,16 +200,15 @@ void BenchAllMatMul() {
Allocator::Init(pools.Topology());
MatMulEnv env(pools);
for (size_t batch_size : {1, /*4, 64,*/ 128}) {
BenchMatMul<F32, F32>(batch_size, 24576, 3072, /*add=*/false, env);
BenchMatMul<F32, F32>(batch_size, 3072, 24576, /*add=*/false, env);
BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, /*add=*/false, env);
BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, /*add=*/false, env);
BenchMatMul<F32, SFP>(batch_size, 24576, 3072, /*add=*/false, env);
BenchMatMul<F32, SFP>(batch_size, 3072, 24576, /*add=*/false, env);
for (size_t batch_size : {1, /* 4, 128,*/ 512}) {
constexpr bool kAdd = false;
BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
}
pools.MaybeStopSpinning(use_spinning);
}
PROFILER_PRINT_RESULTS();
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
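As a worked example of the throughput reported above: PrintSpeed counts 2 * M * K * N floating-point operations per MatMul call (each FMA counted as two), so the BF16/SFP case with M=512, K=3072, N=24576 performs about 77.3 GFLOP; if the best of the 20 timed repetitions took 0.05 s, the output would be roughly 1546 GFLOPS, followed by the ratio of the third-best time to the best. The `hwy::Unpredictable1()` index and the final `hwy::PreventElision(result)` keep the compiler from optimizing the timed MatMul calls away.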

View File

@ -19,6 +19,8 @@
#include <stddef.h>
// IWYU pragma: begin_exports
#include "compression/compress.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/base.h"
@ -81,6 +83,70 @@ class MatMulEnv {
NestedPools* pools_;
};
// Used for the A and B arguments of `MatMul`, which are always const.
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
// `ofs` required for compressed T.
template <typename T>
struct ConstMat {
ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0)
: ptr(ptr), extents(extents), ofs(ofs) {
HWY_DASSERT(ptr != nullptr);
}
// TODO: support stride for page alignment.
size_t Row(size_t r) const {
if constexpr (HWY_IS_DEBUG_BUILD) {
if (r >= extents.rows) {
HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
}
}
return ofs + extents.cols * r;
}
const Extents2D& Extents() const { return extents; }
size_t Stride() const { return extents.cols; }
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a
// subrange of the original rows starting at row 0.
void ShrinkRows(size_t rows) {
HWY_ASSERT(rows <= extents.rows);
extents.rows = rows;
}
const T* HWY_RESTRICT ptr;
Extents2D extents;
// `scale` allows expanding the smaller range of `SfpStream` to the original
// values. MatFromWeights sets this from `MatPtr`.
float scale = 1.0f;
// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
size_t ofs;
};
// For deducing T.
template <typename T>
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
size_t ofs = 0) {
return ConstMat<T>(ptr, extents, ofs);
}
// For A argument to MatMul (activations).
template <typename T>
ConstMat<T> ConstMatFromBatch(size_t batch_size,
const RowVectorBatch<T>& row_vectors) {
HWY_DASSERT(batch_size <= row_vectors.BatchSize());
return MakeConstMat(const_cast<T*>(row_vectors.Const()),
Extents2D(batch_size, row_vectors.Cols()));
}
template <typename T>
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
mat.scale = m.scale();
return mat;
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
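A minimal sketch of how the helpers now defined in this header compose, mirroring the call sites in gemma-inl.cc earlier in this diff; `num_tokens`, `activations`, and `weights.some_matrix` are placeholders, not actual members:

// A: per-token activations viewed as a ConstMat (rows = batch size).
const auto A = ConstMatFromBatch(num_tokens, activations.x);
// B: transposed weights; ConstMatFromWeights also copies the MatPtr scale.
const auto B = ConstMatFromWeights(weights.some_matrix);
// C: output rows; RowPtrFromBatch picks up the RowVectorBatch stride.
MatMul(A, B, /*add=*/nullptr, *activations.env,
       RowPtrFromBatch(activations.logits));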

View File

@ -39,7 +39,6 @@
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
@ -55,7 +54,7 @@
HWY_BEFORE_NAMESPACE();
namespace gcpp {
// For running TestBatchSizes only once. Defined within HWY_ONCE.
// For running TestTiny only once. Defined within HWY_ONCE.
extern int64_t first_target;
namespace HWY_NAMESPACE {
@ -144,10 +143,10 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
const hn::ScalableTag<float> df;
const size_t num_a = A.extents.Area();
const size_t num_b = B.extents.Area();
HWY_ASSERT(num_a % hn::Lanes(df) == 0); // for DecompressAndZeroPad
HWY_ASSERT(num_b % hn::Lanes(df) == 0); // for DecompressAndZeroPad
FloatPtr a = hwy::AllocateAligned<float>(num_a);
FloatPtr b_trans = hwy::AllocateAligned<float>(num_b);
const size_t N = hn::Lanes(df);
// Round up for DecompressAndZeroPad.
FloatPtr a = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_a, N));
FloatPtr b_trans = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_b, N));
HWY_ASSERT(a && b_trans);
HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a);
@ -164,13 +163,11 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
double tolerance = 8 * norm * eps_f32;
// Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
// tolerance there.
if (IsF32<MatTA>() && IsF32<MatTB>()) {
if (IsF32<MatTA>() && !IsF32<MatTB>()) {
tolerance += 4 * max_abs * eps_bf16;
}
EXPECT_GE(tolerance, 1E-4);
if (tolerance > 4.0) {
fprintf(stderr, "WARN: high tolerance %f norm %f maxabs %f\n", tolerance,
norm, max_abs);
if (tolerance > 8.0) {
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
}
for (size_t r = 0; r < A.extents.rows; r++) {
@ -182,11 +179,10 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f\n",
r, c, expected_value, actual_value, norm, max_abs, tolerance);
return;
HWY_ABORT(
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f\n",
r, c, expected_value, actual_value, norm, max_abs, tolerance);
}
}
}
@ -217,7 +213,7 @@ HWY_INLINE void MatMulSlow(const ConstMat<MatTA> A, const ConstMat<MatTB> B,
get_row_c, all_packages,
[&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
const size_t multiple = Allocator::Alignment() / sizeof(MatTB);
const size_t multiple = Allocator::QuantumBytes() / sizeof(MatTB);
const IndexRangePartition get_col_c =
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
ParallelizeOneRange(
@ -248,7 +244,6 @@ template <typename MatTA, typename MatTB = MatTA>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env) {
hwy::ThreadPool& pool = env.Pool();
const bool want_bench = cols_bc > 2000; // avoid spam for small matrices
fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
TypeName<MatTB>());
@ -276,32 +271,17 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch);
const RowPtrF C = RowPtrFromBatch(c_batch);
const double start_slow = hwy::platform::Now();
MatMulSlow(A, B, add_row, env, C_slow);
if (want_bench) {
PrintSpeed("MatMulSlow", A_extents, B_extents,
hwy::platform::Now() - start_slow);
}
double min_elapsed = hwy::HighestValue<double>();
for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) {
const double start_tiled = hwy::platform::Now();
MatMul(A, B, add_row, env, C);
min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
}
if (want_bench) {
PrintSpeed("MatMul", A_extents, B_extents, min_elapsed);
}
MatMul(A, B, add_row, env, C);
AssertClose(A, B, C_slow, C);
}
using F32 = float;
using SFP = SfpStream;
// Sweep batch_size for a single input type and Highway target, to verify the
// row partitioning.
void TestBatchSizes() {
// Sweep all dimensions for a single input type and Highway target, to verify
// the remainder handling.
void TestTiny() {
if (first_target == 0) first_target = HWY_TARGET;
if (HWY_TARGET != first_target) return;
@ -315,7 +295,7 @@ void TestBatchSizes() {
// If less than the limit, we have already tested all num_packages.
if (pools.Topology().FullTopology().packages.size() < max_packages) break;
#endif
fprintf(stderr, "TestBatchSizes %zu: %s %s\n", max_packages,
fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
pools.TopologyString(), pools.PinString());
Tristate use_spinning = Tristate::kDefault;
@ -405,7 +385,7 @@ HWY_AFTER_NAMESPACE();
namespace gcpp {
int64_t first_target = 0; // none run yet
HWY_BEFORE_TEST(MatMulTest);
HWY_EXPORT_AND_TEST_P(MatMulTest, TestBatchSizes);
HWY_EXPORT_AND_TEST_P(MatMulTest, TestTiny);
HWY_EXPORT_AND_TEST_P(MatMulTest, TestAllMatMul);
HWY_AFTER_TEST();

View File

@ -17,45 +17,160 @@
#include <stdio.h>
#include <atomic>
#include <cstdio>
#include <vector>
#include "util/basics.h" // MaybeCheckInitialized
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/per_target.h" // VectorBytes
#if GEMMA_NUMA
#if HWY_OS_WIN
#ifndef NOMINMAX
#define NOMINMAX
// To avoid a dependency on libnuma, use syscalls directly. We require six
// arguments, which has been supported by glibc since around 2010.
#if defined(__GLIBC__) && defined(__GLIBC_PREREQ)
#if HWY_OS_LINUX && __GLIBC_PREREQ(2, 11)
#define GEMMA_LINUX_SYSCALL6
#endif
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
#elif HWY_OS_LINUX
#ifndef GEMMA_BIND // allow override
#if defined(GEMMA_LINUX_SYSCALL6) && !defined(__ANDROID_API__)
#define GEMMA_BIND 1
#else
#define GEMMA_BIND 0
#endif
#endif // GEMMA_BIND
#if GEMMA_BIND && HWY_OS_LINUX
// `move_pages` requires anonymous/private mappings, hence mmap.
#include <sys/mman.h>
#include <sys/syscall.h>
#include <cerrno>
#endif // HWY_OS_*
#endif // GEMMA_NUMA
#endif // GEMMA_BIND && HWY_OS_LINUX
namespace gcpp {
namespace {
/*static*/ size_t Allocator::bytes_per_page_;
/*static*/ bool Allocator::use_numa_;
/*static*/ size_t Allocator::alignment_;
size_t DetectLineBytes() {
if (const hwy::Cache* caches = hwy::DataCaches()) {
// Might not have an L3.
return HWY_MAX(caches[2].bytes_per_line, caches[3].bytes_per_line);
} else {
return HWY_ALIGNMENT;
}
}
/*static*/ size_t Allocator::DetectPageSize() {
#if HWY_OS_WIN
SYSTEM_INFO sys_info;
GetSystemInfo(&sys_info);
return sys_info.dwPageSize;
#elif HWY_OS_LINUX
return sysconf(_SC_PAGESIZE);
size_t DetectPageSize() {
#if HWY_OS_LINUX
size_t page_bytes = static_cast<size_t>(sysconf(_SC_PAGESIZE));
HWY_ASSERT(page_bytes <= (4 << 20));
return page_bytes;
#else
return 0;
#endif
}
#if GEMMA_NUMA && HWY_OS_LINUX
} // namespace
static size_t line_bytes_;
static size_t vector_bytes_;
static size_t quantum_bytes_;
static size_t l1_bytes_;
static size_t l2_bytes_;
static bool should_bind_ = false;
size_t Allocator::LineBytes() { return line_bytes_; }
size_t Allocator::VectorBytes() { return vector_bytes_; }
size_t Allocator::QuantumBytes() { return quantum_bytes_; }
size_t Allocator::L1Bytes() { return l1_bytes_; }
size_t Allocator::L2Bytes() { return l2_bytes_; }
bool Allocator::ShouldBind() { return should_bind_; }
void Allocator::Init(const BoundedTopology& topology) {
line_bytes_ = DetectLineBytes();
vector_bytes_ = hwy::VectorBytes();
quantum_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); // may overwrite below
if (const hwy::Cache* caches = hwy::DataCaches()) {
l1_bytes_ = caches[1].size_kib << 10;
l2_bytes_ = caches[2].size_kib << 10;
} else { // Unknown, make reasonable assumptions.
const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0);
l1_bytes_ = 32 << 10;
l2_bytes_ = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) << 10;
}
// Prerequisites for binding:
// - supported by the OS (currently Linux only),
// - the page size is known and 'reasonably small', preferably less than
// a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
// - we successfully detected topology and there are multiple nodes;
// - there are multiple packages, because we shard by package_idx.
if constexpr (GEMMA_BIND) {
const size_t page_bytes = DetectPageSize();
if ((page_bytes != 0 && page_bytes <= 16 * 1024) &&
topology.NumNodes() > 1 && topology.NumPackages() > 1) {
// Ensure pages meet the alignment requirements of `AllocBytes`.
HWY_ASSERT(page_bytes >= quantum_bytes_);
quantum_bytes_ = page_bytes;
should_bind_ = true;
}
}
}
Allocator::PtrAndDeleter Allocator::AllocBytes(size_t bytes) {
// If we are not binding, the Highway allocator is cheaper than `mmap`, and
// defends against 2K aliasing.
if (!should_bind_) {
// Perf warning if Highway's alignment is less than we want.
if (HWY_ALIGNMENT < QuantumBytes()) {
HWY_WARN(
"HWY_ALIGNMENT %d < QuantumBytes %zu: either vector or cache lines "
"are huge, enable GEMMA_BIND to avoid this warning.",
HWY_ALIGNMENT, QuantumBytes());
}
auto p = hwy::AllocateAligned<uint8_t>(bytes);
// The `hwy::AlignedFreeUniquePtr` deleter is unfortunately specific to the
// alignment scheme in aligned_allocator.cc and does not work for
// already-aligned pointers as returned by `mmap`, hence we wrap the Highway
// pointer in our own deleter.
auto call_free = [](void* ptr, size_t /*bytes*/) {
hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
};
return PtrAndDeleter{p.release(), Deleter(call_free, bytes)};
}
// Binding, or large vector/cache line size: use platform-specific allocator.
#if HWY_OS_LINUX && !defined(__ANDROID_API__)
// `move_pages` is documented to require an anonymous/private mapping or
// `MAP_SHARED`. A normal allocation might not suffice, so we use `mmap`.
// `Init` verified that the page size is a multiple of `QuantumBytes()`.
const int prot = PROT_READ | PROT_WRITE;
const int flags = MAP_ANONYMOUS | MAP_PRIVATE;
const int fd = -1;
// Encourage transparent hugepages by rounding up to a multiple of 2 MiB.
bytes = hwy::RoundUpTo(bytes, 2ull << 20);
void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
if (p == MAP_FAILED) p = nullptr;
const auto call_munmap = [](void* ptr, size_t bytes) {
const int ret = munmap(ptr, bytes);
HWY_ASSERT(ret == 0);
};
return PtrAndDeleter{p, Deleter(call_munmap, bytes)};
#elif HWY_OS_WIN
const auto call_free = [](void* ptr, void*) { _aligned_free(ptr); };
const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
return PtrAndDeleter{_aligned_malloc(bytes, alignment),
Deleter(call_free, bytes)};
#else
return PtrAndDeleter{nullptr, Deleter(nullptr, 0)};
#endif
}
#if GEMMA_BIND && HWY_OS_LINUX
using Ret = long; // NOLINT(runtime/int)
using UL = unsigned long; // NOLINT(runtime/int)
@ -76,90 +191,91 @@ struct SyscallWrappers {
MaybeCheckInitialized(status, count * sizeof(int));
return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
}
static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr,
unsigned flags) {
return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags);
}
};
// Returns the number of pages that are currently busy (hence not yet moved),
// and warns if there are any other reasons for not moving a page. Note that
// `move_pages` can return 0 regardless of whether all pages were moved.
size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
const int* status) {
// Return value 0 does not actually guarantee all pages were moved.
size_t num_busy = 0;
for (size_t i = 0; i < num_pages; ++i) {
if (status[i] == -EBUSY) {
++num_busy;
// Touch
hwy::ZeroBytes(pages[i], 8);
} else if (status[i] != static_cast<int>(node)) {
fprintf(stderr, "Error %d moving pages[%zu]=%p to node %zu (errno %d)\n",
status[i], i, pages[i], node, errno);
static std::atomic_flag first = ATOMIC_FLAG_INIT;
if (!first.test_and_set()) {
HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).",
status[i], i, pages[i], node, errno);
}
}
}
return num_busy;
}
// Attempts to move(!) memory to the given NUMA node, typically obtained from
// `BoundedTopology::GetCluster(package_idx, cluster_idx).node`. Using `mbind`
// directly is easier than calling libnuma's `numa_move_pages`, which requires
// an array of pages. Note that `numa_tonode_memory` is insufficient because
// it does not specify the `MPOL_MF_MOVE` flag, so it only sets the policy,
// which means it would have to be called before pages are faulted in, but
// `aligned_allocator.h` modifies the first bytes for its bookkeeping.
// May overwrite some of the memory with zeros.
void BindMemory(void* ptr, size_t bytes, size_t node) {
bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) {
HWY_DASSERT(should_bind_);
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
if constexpr (HWY_IS_DEBUG_BUILD) {
// Ensure the requested `node` is allowed.
UL nodes[kMaxNodes / 64] = {0};
const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED
HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes,
nullptr, flags) == 0);
HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64)));
}
// Avoid mbind because it does not report why it failed, which is most likely
// because pages are busy, in which case we want to know which.
#if 0
// nodemask with only the given node set.
UL nodes[hwy::DivCeil(kMaxNodes, ULBits)] = {};
nodes[node / ULBits] = 1ULL << (node % ULBits);
const int mode = 2; // MPOL_BIND
const unsigned flags = 3; // MPOL_MF_MOVE | MPOL_MF_STRICT
const int ret =
SyscallWrappers::mbind(ptr, bytes, mode, nodes, kMaxNodes, flags);
if (ret != 0) {
fprintf(stderr, "Failed to bind %p %zu to node %zu (errno %d)\n", ptr,
bytes, node, errno);
}
#elif 1
// `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set.
const unsigned flags = 2; // MPOL_MF_MOVE
const size_t bytes_per_page = static_cast<size_t>(sysconf(_SC_PAGESIZE));
HWY_ASSERT(bytes % bytes_per_page == 0);
const size_t num_pages = bytes / bytes_per_page;
HWY_ASSERT(bytes % quantum_bytes_ == 0);
const size_t num_pages = bytes / quantum_bytes_;
std::vector<void*> pages;
pages.reserve(num_pages);
for (size_t i = 0; i < num_pages; ++i) {
pages.push_back(static_cast<uint8_t*>(ptr) + i * bytes_per_page);
pages.push_back(static_cast<uint8_t*>(ptr) + i * quantum_bytes_);
// Ensure the page is faulted in to prevent `move_pages` from failing,
// because freshly allocated pages may be mapped to a shared 'zero page'.
hwy::ZeroBytes(pages.back(), 8);
}
std::vector<int> nodes(num_pages, node);
std::vector<int> status(num_pages, static_cast<int>(kMaxNodes));
Ret ret = SyscallWrappers::move_pages(
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
size_t num_busy =
CountBusyPages(num_pages, node, pages.data(), status.data());
if (num_busy != 0) {
// Try again
ret = SyscallWrappers::move_pages(
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
const size_t num_busy_before = num_busy;
num_busy = CountBusyPages(num_pages, node, pages.data(), status.data());
fprintf(
stderr,
"second try still %zu busy, was %zu. 2nd ret %d status %d %d %d %d\n",
num_busy, num_busy_before, static_cast<int>(ret), status[0], status[1],
status[2], status[3]);
if (ret < 0) {
HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr,
bytes, node, errno, status[0]);
return false;
}
if (ret < 0) {
fprintf(stderr,
"Failed to bind %p %zu to node %zu (errno %d) status %d %d\n", ptr,
bytes, node, errno, status[0], status[1]);
const size_t num_busy =
CountBusyPages(num_pages, node, pages.data(), status.data());
if (HWY_UNLIKELY(num_busy != 0)) {
// Trying again is usually enough to succeed.
usleep(5); // NOLINT(runtime/sleep)
(void)SyscallWrappers::move_pages(
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
const size_t still_busy =
CountBusyPages(num_pages, node, pages.data(), status.data());
if (HWY_UNLIKELY(still_busy != 0)) {
HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.",
still_busy, num_busy);
}
}
#endif
return true;
}
#else
// TODO: support other OSes.
void BindMemory(void*, size_t, size_t) {}
#endif // GEMMA_NUMA && HWY_OS_LINUX
bool Allocator::BindMemory(void*, size_t, size_t) { return false; }
#endif // GEMMA_BIND && HWY_OS_LINUX
} // namespace gcpp

View File

@ -19,114 +19,232 @@
#include <stddef.h>
#include <stdint.h>
#include <cstdlib> // std::aligned_alloc / _aligned_malloc
// IWYU pragma: begin_exports
#include <memory>
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// IWYU pragma: end_exports
#ifndef GEMMA_NUMA
// The check below requires two #if, hence start with 0 and redefine to 1.
#define GEMMA_NUMA 0
// To avoid a dependency on libnuma, use syscalls directly. We require six
// arguments, which has been supported by glibc since around 2010.
#if defined(__GLIBC__) && defined(__GLIBC_PREREQ)
#if HWY_OS_LINUX && __GLIBC_PREREQ(2, 11)
#undef GEMMA_NUMA
#define GEMMA_NUMA 1
#endif
#endif
#endif // GEMMA_NUMA
#include "hwy/aligned_allocator.h"
namespace gcpp {
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
// Points to an adapter lambda that calls `FreeAlignedBytes` or `munmap`. The
// `bytes` argument is required for the latter.
using FreeFunc = void (*)(void* mem, size_t bytes);
template <typename T>
ByteStorageT AllocateSizeof() {
return hwy::AllocateAligned<uint8_t>(sizeof(T));
}
// Stateful in order to know whether to bind to NUMA nodes. `Monostate` for
// convenience - avoids passing around a reference.
class Allocator {
// Custom deleter for std::unique_ptr that calls `FreeFunc`.
class Deleter {
public:
static void Init(const BoundedTopology& topology) {
bytes_per_page_ = DetectPageSize();
HWY_ASSERT(bytes_per_page_ <= (4 << 20));
// NUMA only makes sense if:
// - the page size is known and 'reasonably small', preferably less than
// a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
// - we successfully detected topology and there are multiple nodes;
// - there are multiple packages, because we shard by package_idx.
use_numa_ = (bytes_per_page_ != 0 && bytes_per_page_ <= 16 * 1024) &&
topology.NumNodes() > 1 && topology.NumPackages() > 1;
// TODO: remove once tensors are page-aligned.
use_numa_ = false;
fprintf(stderr, "Warning: disabling use_numa_\n");
alignment_ = use_numa_ ? bytes_per_page_ : HWY_ALIGNMENT;
}
static bool UseNUMA() { return use_numa_; }
// BindTensor requires row pointers and lengths be a multiple of this.
static size_t Alignment() { return alignment_; }
// `MatStorageT` requires this to be default-constructible.
Deleter() : free_func_(nullptr), bytes_(0) {}
Deleter(FreeFunc free_func, size_t bytes)
: free_func_(free_func), bytes_(bytes) {}
template <typename T>
static hwy::AlignedFreeUniquePtr<T[]> Alloc(size_t num) {
// For non-NUMA, use the Highway allocator because it defends against 2k
// aliasing.
if (!use_numa_) return hwy::AllocateAligned<T>(num);
void operator()(T* p) const {
free_func_(p, bytes_);
}
private:
FreeFunc free_func_;
size_t bytes_;
};
// Unique (move-only) pointer to an aligned array of POD T.
template <typename T>
using AlignedPtr = std::unique_ptr<T[], Deleter>;
// Allocation, binding, and row accessors all depend on the sizes of memory
// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
// use `Monostate` (static members).
class Allocator {
public:
// Must be called at least once before any other function. Not thread-safe,
// hence only call this from the main thread.
static void Init(const BoundedTopology& topology);
// Bytes per cache line, or a reasonable guess if unknown. Used to choose
// ranges such that there will be no false sharing.
static size_t LineBytes();
// Bytes per full vector. Used to compute loop steps.
static size_t VectorBytes();
// Granularity of regions processed by different threads. The start and
// length of each region should be divisible by this, which is at least
// `HWY_MAX(LineBytes(), VectorBytes())`.
static size_t QuantumBytes();
static size_t L1Bytes();
static size_t L2Bytes();
// Returns pointer aligned to `QuantumBytes()`.
template <typename T>
static AlignedPtr<T> Alloc(size_t num) {
constexpr size_t kSize = sizeof(T);
// Ensure the `bytes = num * kSize` computation did not overflow.
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
// Fail if the `bytes = num * kSize` computation overflowed.
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
if (check != num) {
return hwy::AlignedFreeUniquePtr<T[]>(); // overflowed
}
if (check != num) return AlignedPtr<T>();
// AlignedFreeUniquePtr has a deleter that can call an arbitrary `free`, but
// with an extra opaque pointer, which we discard via `call_free`.
#if defined(__ANDROID_API__) && __ANDROID_API__ < 28
const auto call_free = [](void* ptr, void*) { std::free(ptr); };
void* mem = nullptr;
int err = posix_memalign(&mem, Alignment(), bytes);
HWY_ASSERT(err == 0);
T* p = static_cast<T*>(mem);
#elif HWY_OS_WIN
const auto call_free = [](void* ptr, void*) { _aligned_free(ptr); };
T* p = static_cast<T*>(_aligned_malloc(bytes, Alignment()));
#else
const auto call_free = [](void* ptr, void*) { std::free(ptr); };
T* p = static_cast<T*>(std::aligned_alloc(Alignment(), bytes));
#endif
return hwy::AlignedFreeUniquePtr<T[]>(
p, hwy::AlignedFreer(call_free, nullptr));
PtrAndDeleter pd = AllocBytes(bytes);
return AlignedPtr<T>(static_cast<T*>(pd.p), pd.deleter);
}
// Returns whether `BindMemory` can/should be called, i.e. we have page-level
// control over memory placement and multiple packages and NUMA nodes.
static bool ShouldBind();
// Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is
// typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`.
// Writes zeros to SOME of the memory. Only call if `ShouldBind()`.
// `p` and `bytes` must be multiples of `QuantumBytes()`.
static bool BindMemory(void* p, size_t bytes, size_t node);
private:
// Type-erased so this can be implemented in allocator.cc.
struct PtrAndDeleter {
void* p;
Deleter deleter;
};
static PtrAndDeleter AllocBytes(size_t bytes);
};
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() = default;
// Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
// we default to tightly packed rows (`stride = cols`).
// WARNING: not all call sites support `stride` != cols.
RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
if (stride == 0) {
stride_ = extents_.cols;
} else {
HWY_ASSERT(stride >= extents_.cols);
stride_ = stride;
}
mem_ = Allocator::Alloc<T>(extents_.rows * stride_);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return extents_.rows; }
size_t Cols() const { return extents_.cols; }
size_t Stride() const { return stride_; }
Extents2D Extents() const { return extents_; }
// Returns the given row vector of length `Cols()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
const T* Batch(size_t batch_idx) const {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
// For MatMul or other operations that process the entire batch at once.
// TODO: remove once we only use Mat.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
private:
AlignedPtr<T> mem_;
Extents2D extents_;
size_t stride_;
};
// Returns `num` rounded up to an odd number of cache lines. This is used to
// compute strides. An odd number of cache lines prevents 2K aliasing and is
// coprime with the cache associativity, which reduces conflict misses.
template <typename T>
static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
HWY_DASSERT(line_bytes >= 32);
HWY_DASSERT(line_bytes % sizeof(T) == 0);
const size_t lines = hwy::DivCeil(num * sizeof(T), line_bytes);
const size_t padded_num = (lines | 1) * line_bytes / sizeof(T);
HWY_DASSERT(padded_num >= num);
return padded_num;
}
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
// it is always float and does not support compressed T, but does support an
// arbitrary stride >= cols.
#pragma pack(push, 1) // power of two size
template <typename T>
class RowPtr {
public:
RowPtr() = default; // for `MMPtrs`.
RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
: row0_(row0),
stride_(stride),
step_(static_cast<uint32_t>(
HWY_MAX(Allocator::LineBytes(), Allocator::VectorBytes()))),
cols_(static_cast<uint32_t>(cols)),
row_mask_(Allocator::QuantumBytes() / step_ - 1) {
HWY_DASSERT(stride >= cols);
HWY_DASSERT(row_mask_ != ~size_t{0});
row_mask_ = 0; // TODO: remove
}
RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
T* HWY_RESTRICT Row(size_t r) const {
// How much of the previous row's padding to consume.
const size_t pad_bytes = (r & row_mask_) * step_;
HWY_DASSERT(pad_bytes < Allocator::QuantumBytes());
return row0_ + stride_ * r - pad_bytes;
}
size_t Cols() const { return cols_; }
size_t Stride() const { return stride_; }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
// The caller might not have padded enough, so disable the padding in Row().
// Rows will now be exactly `stride` elements apart. This is used when
// writing to the KV cache via MatMul.
row_mask_ = 0;
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
RowPtr<T> View(size_t r, size_t c, size_t cols) const {
HWY_DASSERT(c < cols_);
HWY_DASSERT(cols <= cols_ - c);
return RowPtr<T>(Row(r) + c, cols, stride_);
}
private:
static size_t DetectPageSize();
// Required for BindMemory. Usually 4K, but can differ on Arm.
static size_t bytes_per_page_;
static bool use_numa_;
static size_t alignment_;
T* HWY_RESTRICT row0_;
size_t stride_;
uint32_t step_; // Copy from Allocator::LineBytes() to improve locality.
uint32_t cols_;
size_t row_mask_;
};
#pragma pack(pop)
// For future NUMA support. TODO: use.
void BindMemory(void* ptr, size_t bytes, size_t node);
using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;
// For C argument to MatMul.
template <typename T>
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(row_vectors.All(), row_vectors.Cols(), row_vectors.Stride());
}
} // namespace gcpp
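A minimal usage sketch of the new monostate Allocator and the padding helper, based only on the declarations above; `topology`, the sizes, and the node index are placeholders rather than actual gemma.cpp call sites:

// Must be called once on the main thread before Alloc/BindMemory.
Allocator::Init(topology);  // `topology` is a BoundedTopology built elsewhere

// Memory aligned to QuantumBytes(); AlignedPtr's Deleter later calls the
// matching FreeFunc (FreeAlignedBytes or munmap).
const size_t num = size_t{1} << 20;
AlignedPtr<float> buf = Allocator::Alloc<float>(num);

// Optionally migrate the pages to a NUMA node, typically
// BoundedTopology::GetCluster(pkg_idx, cluster_idx).node. Pointer and length
// must be multiples of QuantumBytes(); 1M floats = 4 MiB satisfies that even
// for 4 KiB pages.
if (Allocator::ShouldBind()) {
  Allocator::BindMemory(buf.get(), num * sizeof(float), /*node=*/0);
}

// RoundUpToOddLines pads strides to an odd number of cache lines to avoid 2K
// aliasing: with 64-byte lines and T=float, 512 cols = 2048 bytes = 32 lines,
// rounded up to 33 lines, i.e. a stride of 528 floats.
const size_t stride = RoundUpToOddLines<float>(512, Allocator::LineBytes());
RowVectorBatch<float> batch(Extents2D(32, 512), stride);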

View File

@ -64,8 +64,8 @@ struct TokenAndProb {
// Entire size of a 2D array.
struct Extents2D {
Extents2D() : rows(0), cols(0) {}
Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
constexpr Extents2D() : rows(0), cols(0) {}
constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
HWY_DASSERT(rows != 0);
HWY_DASSERT(cols != 0);
}
@ -77,6 +77,7 @@ struct Extents2D {
};
struct IndexRange {
IndexRange() = default;
IndexRange(size_t begin, size_t end) : begin_(begin), end_(end) {
HWY_DASSERT(begin < end);
}
@ -113,144 +114,6 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end,
size_t max_size) {
return IndexRange(begin, HWY_MIN(begin + max_size, end));
}
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
// it is always float and does not support compressed T, but does support an
// arbitrary stride >= cols.
template <typename T>
class RowPtr {
public:
RowPtr(T* HWY_RESTRICT row0, size_t cols)
: row0_(row0), cols_(cols), stride_(cols) {}
RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
: row0_(row0), cols_(cols), stride_(stride) {
HWY_DASSERT(stride >= cols);
}
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
size_t Cols() const { return cols_; }
size_t Stride() const { return stride_; }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
}
private:
T* HWY_RESTRICT row0_;
size_t stride_;
size_t cols_;
};
using RowPtrF = RowPtr<float>;
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() = default;
// Main ctor, called from Activations::Allocate.
RowVectorBatch(Extents2D extents) : extents_(extents) {
mem_ = hwy::AllocateAligned<T>(extents_.rows * extents_.cols);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return extents_.rows; }
size_t Cols() const { return extents_.cols; }
Extents2D Extents() const { return extents_; }
// Returns the given row vector of length `Cols()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * Cols();
}
const T* Batch(size_t batch_idx) const {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * Cols();
}
// For MatMul or other operations that process the entire batch at once.
// TODO: remove once we only use Mat.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return BatchSize() * Cols() * sizeof(T); }
private:
hwy::AlignedFreeUniquePtr<T[]> mem_;
Extents2D extents_;
};
// Used for the A and B arguments of `MatMul`, which are always const.
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
// `ofs` required for compressed T.
template <typename T>
struct ConstMat {
ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0)
: ptr(ptr), extents(extents), ofs(ofs) {
HWY_DASSERT(ptr != nullptr);
}
// TODO: support stride for page alignment.
size_t Row(size_t r) const {
if constexpr (HWY_IS_DEBUG_BUILD) {
if (r >= extents.rows) {
HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
}
}
return ofs + extents.cols * r;
}
const Extents2D& Extents() const { return extents; }
size_t Stride() const { return extents.cols; }
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a
// subrange of the original rows starting at row 0.
void ShrinkRows(size_t rows) {
HWY_ASSERT(rows <= extents.rows);
extents.rows = rows;
}
const T* HWY_RESTRICT ptr;
Extents2D extents;
// `scale` allows expanding the smaller range of `SfpStream` to the original
// values. MatFromWeights sets this from `MatPtr`.
float scale = 1.0f;
// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
size_t ofs;
};
// For deducing T.
template <typename T>
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
size_t ofs = 0) {
return ConstMat<T>(ptr, extents, ofs);
}
// For A argument to MatMul (activations).
template <typename T>
ConstMat<T> ConstMatFromBatch(size_t batch_size,
const RowVectorBatch<T>& row_vectors) {
HWY_DASSERT(batch_size <= row_vectors.BatchSize());
return MakeConstMat(const_cast<T*>(row_vectors.Const()),
Extents2D(batch_size, row_vectors.Cols()));
}
// For C argument to MatMul.
template <typename T>
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(row_vectors.All(), row_vectors.Cols());
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_

View File

@ -55,10 +55,8 @@ class Pinning {
LPS enabled_lps;
if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
const size_t num_lps = hwy::TotalLogicalProcessors();
fprintf(
stderr,
"Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
num_lps);
HWY_WARN("unknown OS affinity, considering all %zu LPs enabled.",
num_lps);
for (size_t lp = 0; lp < num_lps; ++lp) {
enabled_lps.Set(lp);
}
@ -71,8 +69,7 @@ class Pinning {
const size_t lp = enabled_lps.First();
enabled_lps = LPS();
enabled_lps.Set(lp);
fprintf(stderr,
"Warning, threads not supported, using only the main thread\n.");
HWY_WARN("Warning, threads not supported, using only the main thread.");
}
original_affinity_ = enabled_lps;
@ -155,23 +152,10 @@ BoundedTopology::BoundedTopology(BoundedSlice package_slice,
HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
}
// Topology is unknown, rely on OS affinity and user-specified slice.
BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
BoundedSlice lp_slice) {
// Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
// we honor both the OS affinity and the user-specified slice. Note that
// this can be used to exclude hyperthreads because Linux groups LPs by
// sibling index. For example, the first `num_cores` are not siblings.
const size_t detected = enabled_lps.Count();
size_t enabled_idx = 0;
enabled_lps.Foreach([&](size_t lp) {
if (lp_slice.Contains(detected, enabled_idx++)) {
AddLP(lp);
}
});
// lp_slice can only reduce the number of `enabled_lps`, and not below 1.
HWY_ASSERT(num_workers_ != 0);
// Topology is unknown, take the given set of LPs.
BoundedTopology::Cluster::Cluster(const LPS& lps) {
lps_ = lps;
num_workers_ = lps.Count();
}
BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
@ -183,7 +167,9 @@ BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
// Skip if not first-hyperthread or disabled.
if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;
AddLP(lp);
HWY_ASSERT(!lps_.Get(lp)); // Foreach ensures uniqueness
lps_.Set(lp);
++num_workers_;
// Set fields once, and ensure subsequent LPs match - we assume there
// is only one NUMA node per cluster, with the same L2/L3 size.
@ -198,30 +184,63 @@ BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
if (HWY_LIKELY(!warned)) {
if (HWY_UNLIKELY(lp_node != node_)) {
warned = true;
fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n",
lp, lp_node, node_);
HWY_WARN("lp %zu on node %zu != cluster node %zu.", lp, lp_node,
node_);
}
if (HWY_UNLIKELY(private_kib_ != tcluster.private_kib)) {
warned = true;
fprintf(stderr, "WARNING: lp %zu private_kib %zu != cluster %zu.\n",
lp, private_kib_, tcluster.private_kib);
HWY_WARN("lp %zu private_kib %zu != cluster %zu.", lp, private_kib_,
tcluster.private_kib);
}
if (HWY_UNLIKELY(shared_kib_ != tcluster.shared_kib)) {
warned = true;
fprintf(stderr, "WARNING: lp %zu shared_kib %zu != cluster %zu.\n",
lp, shared_kib_, tcluster.shared_kib);
HWY_WARN("lp %zu shared_kib %zu != cluster %zu.", lp, shared_kib_,
tcluster.shared_kib);
}
} // !warned
}
});
}
// CPUs without clusters rarely have more than a few dozen cores, and 6 is a
// decent number of threads in a per-cluster pool.
constexpr bool kSplitLargeClusters = false;
constexpr size_t kMaxClusters = 8;
constexpr size_t kMaxLPsPerCluster = 6;
// Topology is unknown, rely on OS affinity and user-specified slice.
BoundedTopology::Package::Package(const LPS& enabled_lps,
BoundedSlice lp_slice) {
LPS clusters_lps[kMaxClusters];
const size_t num_clusters =
kSplitLargeClusters
? HWY_MIN(kMaxClusters,
hwy::DivCeil(enabled_lps.Count(), kMaxLPsPerCluster))
: 1;
// Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
// we honor both the OS affinity and the user-specified slice. Note that
// this can be used to exclude hyperthreads because Linux groups LPs by
// sibling index. For example, the first `num_cores` are not siblings.
const size_t detected = enabled_lps.Count();
size_t enabled_idx = 0;
enabled_lps.Foreach([&](size_t lp) {
if (lp_slice.Contains(detected, enabled_idx)) {
clusters_lps[enabled_idx % num_clusters].Set(lp);
}
++enabled_idx;
});
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
clusters.push_back(Cluster(clusters_lps[cluster_idx]));
}
}
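For illustration of the round-robin split above (currently inert because kSplitLargeClusters is false, matching the "disabled" note in the commit message): with 16 enabled LPs, num_clusters = HWY_MIN(8, DivCeil(16, 6)) = 3, and the `enabled_idx % num_clusters` assignment yields clusters of 6, 5 and 5 LPs.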
// NOTE: caller is responsible for checking whether `clusters` is empty.
BoundedTopology::Package::Package(const LPS& enabled_lps,
const hwy::Topology& topology,
size_t package_idx,
const hwy::Topology& topology, size_t pkg_idx,
BoundedSlice cluster_slice) {
const hwy::Topology::Package& tpackage = topology.packages[package_idx];
const hwy::Topology::Package& tpackage = topology.packages[pkg_idx];
// Populate `clusters` with the subset of clusters in `cluster_slice` that
// have any enabled LPs. If `clusters` remains empty, the caller will
// skip this `Package`.
@ -233,10 +252,34 @@ BoundedTopology::Package::Package(const LPS& enabled_lps,
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(cluster.Size() != 0)) {
clusters.push_back(std::move(cluster));
clusters.push_back(cluster);
}
});
SortByDescendingSize(clusters);
// If there is only one large cluster, split it into smaller ones.
if (kSplitLargeClusters && clusters.size() == 1 &&
enabled_lps.Count() >= 16) {
const LPS lps = clusters[0].LPSet(); // copy so we can clear
clusters.clear();
// Split `lps` into several clusters.
LPS clusters_lps[kMaxClusters];
const size_t num_clusters =
HWY_MIN(kMaxClusters, hwy::DivCeil(lps.Count(), kMaxLPsPerCluster));
size_t num_lps = 0;
lps.Foreach(
[&](size_t lp) { clusters_lps[num_lps++ % num_clusters].Set(lp); });
HWY_DASSERT(num_lps == lps.Count());
// Create new clusters, just inserting the new LPS.
hwy::Topology::Cluster tcluster = tpackage.clusters[0]; // modifiable copy
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
tcluster.lps = clusters_lps[cluster_idx];
// Keep same `private_kib` and `shared_kib`.
clusters.push_back(Cluster(enabled_lps, topology.lps, tcluster));
}
}
}
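For example (assuming the constants above), a lone 32-LP cluster would be split into HWY_MIN(8, DivCeil(32, 6)) = 6 clusters, the first two with 6 LPs and the remaining four with 5, each inheriting the original cluster's private_kib and shared_kib.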
#if !GEMMA_DISABLE_TOPOLOGY
@ -256,10 +299,9 @@ static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters,
max_tclusters = 0;
max_tcluster_cores = 0;
max_tcluster_lps = 0;
for (size_t package_idx = 0; package_idx < topology_.packages.size();
++package_idx) {
for (size_t pkg_idx = 0; pkg_idx < topology_.packages.size(); ++pkg_idx) {
const std::vector<hwy::Topology::Cluster>& tclusters =
topology_.packages[package_idx].clusters;
topology_.packages[pkg_idx].clusters;
max_tclusters = HWY_MAX(max_tclusters, tclusters.size());
size_t tcluster_cores = 0;
size_t tcluster_lps = 0;
@ -272,10 +314,10 @@ static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters,
}
if (tclusters.size() > 1 && tcluster_cores > 8) {
fprintf(stderr,
"Package %zu: multiple clusters with max size %zu, whereas CCX "
"only have 8, may indicate a bug in hwy::Topology.\n",
package_idx, tcluster_cores);
HWY_WARN(
"Package %zu: multiple clusters with max size %zu, whereas a CCX "
"only has 8; may indicate a bug in hwy::Topology.",
pkg_idx, tcluster_cores);
}
max_tcluster_cores = HWY_MAX(max_tcluster_cores, tcluster_cores);
max_tcluster_lps = HWY_MAX(max_tcluster_lps, tcluster_lps);
@ -294,8 +336,8 @@ void BoundedTopology::InitFromTopology(const LPS& enabled_lps,
// (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
package_slice.Foreach(
"package", topology_.packages.size(), [&](size_t package_idx) {
Package package(enabled_lps, topology_, package_idx, cluster_slice);
"package", topology_.packages.size(), [&](size_t pkg_idx) {
Package package(enabled_lps, topology_, pkg_idx, cluster_slice);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(!package.clusters.empty())) {
packages_.push_back(std::move(package));
@ -313,18 +355,18 @@ void BoundedTopology::InitFromTopology(const LPS& enabled_lps,
// Scan for max BoundedTopology clusters and their size, for topology_string_.
size_t all_max_cluster_size = 0;
for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) {
size_t max_cluster_size = 0;
for (size_t cluster_idx = 0; cluster_idx < NumClusters(package_idx);
for (size_t cluster_idx = 0; cluster_idx < NumClusters(pkg_idx);
++cluster_idx) {
max_cluster_size = HWY_MAX(max_cluster_size,
GetCluster(package_idx, cluster_idx).Size());
max_cluster_size =
HWY_MAX(max_cluster_size, GetCluster(pkg_idx, cluster_idx).Size());
}
if (NumClusters(package_idx) > 1 && max_cluster_size > 8) {
fprintf(stderr,
"Package %zu: multiple clusters with max size %zu, whereas CCX "
"only have 8, may indicate a bug in BoundedTopology.\n",
package_idx, max_cluster_size);
if (NumClusters(pkg_idx) > 1 && max_cluster_size > 8) {
HWY_WARN(
"Package %zu: multiple clusters with max size %zu, whereas a CCX "
"only has 8; may indicate a bug in BoundedTopology.",
pkg_idx, max_cluster_size);
}
all_max_cluster_size = HWY_MAX(all_max_cluster_size, max_cluster_size);
}
@ -382,10 +424,10 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
// calling thread of an all_clusters->Run, and hence pinned to one of the
// `cluster.lps` if `pin`.
all_packages_->Run(
0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) {
HWY_ASSERT(package_idx == thread); // each thread has one task
packages_[package_idx] =
Package(topology_, package_idx, max_workers_per_package, lp_slice);
0, all_packages_->NumWorkers(), [&](uint64_t pkg_idx, size_t thread) {
HWY_ASSERT(pkg_idx == thread); // each thread has one task
packages_[pkg_idx] =
Package(topology_, pkg_idx, max_workers_per_package, lp_slice);
});
all_pinned_ = GetPinning().AllPinned(&pin_string_);
@ -405,12 +447,11 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
HWY_ASSERT(max_workers_per_cluster_ <= 256);
}
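For context, a hypothetical driver showing how the nested pools are typically consumed. It only uses accessors declared in this header (NumPackages, AllClusters, Cluster, Node); the function name and structure are illustrative, not part of the API:

// Outer loop over packages, inner parallel loops over clusters and workers.
void RunOnAllClusters(gcpp::NestedPools& pools) {
  for (size_t pkg_idx = 0; pkg_idx < pools.NumPackages(); ++pkg_idx) {
    hwy::ThreadPool& all_clusters = pools.AllClusters(pkg_idx);
    all_clusters.Run(
        0, all_clusters.NumWorkers(), [&](uint64_t cluster_idx, size_t) {
          // NUMA node for binding per-cluster allocations.
          const size_t node = pools.Node(pkg_idx, cluster_idx);
          hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, cluster_idx);
          cluster.Run(0, cluster.NumWorkers(),
                      [&](uint64_t task, size_t worker) {
                        // Per-worker work for `task` goes here.
                      });
          (void)node;
        });
  }
}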
NestedPools::Package::Package(const BoundedTopology& topology,
size_t package_idx,
NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx,
size_t max_workers_per_package,
BoundedSlice lp_slice) {
// Pre-allocate because elements are set concurrently.
clusters_.resize(topology.NumClusters(package_idx));
clusters_.resize(topology.NumClusters(pkg_idx));
const size_t max_workers_per_cluster =
DivideMaxAcross(max_workers_per_package, clusters_.size());
@ -421,7 +462,7 @@ NestedPools::Package::Package(const BoundedTopology& topology,
0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& cluster =
topology.GetCluster(package_idx, cluster_idx);
topology.GetCluster(pkg_idx, cluster_idx);
clusters_[cluster_idx] =
MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
// Pin workers AND the calling thread from `all_clusters`.

View File

@ -108,7 +108,7 @@ class BoundedTopology {
class Cluster {
public:
Cluster(const LPS& enabled_lps, BoundedSlice lp_slice);
Cluster(const LPS& lps);
Cluster(const LPS& enabled_lps,
const std::vector<hwy::Topology::LP>& all_lps,
const hwy::Topology::Cluster& tcluster);
@ -124,17 +124,12 @@ class BoundedTopology {
return lps;
}
const LPS& LPSet() const { return lps_; }
size_t Node() const { return node_; }
size_t PrivateKiB() const { return private_kib_; }
size_t SharedKiB() const { return shared_kib_; }
private:
void AddLP(size_t lp) {
HWY_ASSERT(!lps_.Get(lp)); // Foreach ensures uniqueness
lps_.Set(lp);
++num_workers_;
}
// Enabled LPs; if topology is known, only the ones in this cluster.
LPS lps_;
// How many workers in the per-cluster pool. If 0, this Cluster is removed.
@ -147,19 +142,19 @@ class BoundedTopology {
size_t shared_kib_ = 0;
}; // Cluster
size_t NumClusters(size_t package_idx) const {
HWY_ASSERT(package_idx < NumPackages());
return packages_[package_idx].clusters.size();
size_t NumClusters(size_t pkg_idx) const {
HWY_ASSERT(pkg_idx < NumPackages());
return packages_[pkg_idx].clusters.size();
}
const Cluster& GetCluster(size_t package_idx, size_t cluster_idx) const {
HWY_ASSERT(package_idx < NumPackages());
const Package& package = packages_[package_idx];
const Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) const {
HWY_ASSERT(pkg_idx < NumPackages());
const Package& package = packages_[pkg_idx];
HWY_ASSERT(cluster_idx < package.clusters.size());
return package.clusters[cluster_idx];
}
Cluster& GetCluster(size_t package_idx, size_t cluster_idx) {
HWY_ASSERT(package_idx < NumPackages());
Package& package = packages_[package_idx];
Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) {
HWY_ASSERT(pkg_idx < NumPackages());
Package& package = packages_[pkg_idx];
HWY_ASSERT(cluster_idx < package.clusters.size());
return package.clusters[cluster_idx];
}
@ -170,13 +165,9 @@ class BoundedTopology {
private:
struct Package {
// Topology is unknown, rely on OS affinity and user-specified slice.
Package(const LPS& enabled_lps, BoundedSlice lp_slice) {
clusters.push_back(Cluster(enabled_lps, lp_slice));
}
Package(const LPS& enabled_lps, BoundedSlice lp_slice);
Package(const LPS& enabled_lps, const hwy::Topology& topology,
size_t package_idx, BoundedSlice cluster_slice);
size_t pkg_idx, BoundedSlice cluster_slice);
// For SortByDescendingSize.
size_t Size() const { return clusters.size(); }
@ -257,33 +248,36 @@ class NestedPools {
}
}
size_t NumPackages() const { return packages_.size(); }
hwy::ThreadPool& AllPackages() { return *all_packages_; }
hwy::ThreadPool& AllClusters(size_t package_idx) {
HWY_DASSERT(package_idx < packages_.size());
return packages_[package_idx].AllClusters();
hwy::ThreadPool& AllClusters(size_t pkg_idx) {
HWY_DASSERT(pkg_idx < NumPackages());
return packages_[pkg_idx].AllClusters();
}
hwy::ThreadPool& Cluster(size_t package_idx, size_t cluster_idx) {
HWY_DASSERT(package_idx < packages_.size());
return packages_[package_idx].Cluster(cluster_idx);
hwy::ThreadPool& Cluster(size_t pkg_idx, size_t cluster_idx) {
HWY_DASSERT(pkg_idx < NumPackages());
return packages_[pkg_idx].Cluster(cluster_idx);
}
// For binding to NUMA nodes.
size_t Node(size_t package_idx, size_t cluster_idx) const {
return topology_.GetCluster(package_idx, cluster_idx).Node();
size_t Node(size_t pkg_idx, size_t cluster_idx) const {
return topology_.GetCluster(pkg_idx, cluster_idx).Node();
}
// Reasonably tight upper bound for allocating thread-local storage (TLS).
size_t MaxWorkers() const {
return packages_.size() * max_clusters_per_package_ *
max_workers_per_cluster_;
// Reasonably tight upper bounds for allocating thread-local storage (TLS).
size_t MaxWorkersPerCluster() const { return max_workers_per_cluster_; }
size_t MaxWorkersPerPackage() const {
return max_clusters_per_package_ * MaxWorkersPerCluster();
}
// Returns the first of `cluster.NumWorkers()` TLS indices, to which callers
// add the worker index given by `cluster.Run`.
size_t WorkerOffset(size_t package_idx, size_t cluster_idx) const {
HWY_DASSERT(package_idx < packages_.size());
HWY_DASSERT(cluster_idx < packages_[package_idx].NumClusters());
return (package_idx * max_clusters_per_package_ + cluster_idx) *
max_workers_per_cluster_;
size_t MaxWorkers() const { return NumPackages() * MaxWorkersPerPackage(); }
// Actual number of workers.
size_t TotalWorkers() const {
size_t total_workers = 0;
for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) {
total_workers += packages_[pkg_idx].TotalWorkers();
}
return total_workers;
}
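As a worked example (hypothetical shape, not taken from the code): with one package whose three clusters have 6, 6 and 4 workers, MaxWorkersPerCluster() is 6, MaxWorkersPerPackage() is 3 * 6 = 18 and MaxWorkers() is 18, whereas TotalWorkers() is 6 + 6 + 4 = 16, so buffers sized by the upper bounds carry a small amount of slack.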
// For Allocator
@ -296,20 +290,20 @@ class NestedPools {
// if there is more than one, which maximizes available memory bandwidth, or
// the first cluster, which is typically the whole package. For use by callers
// that only have a single parallel-for.
hwy::ThreadPool& Pool(size_t package_idx = 0) {
hwy::ThreadPool& Pool(size_t pkg_idx = 0) {
// Only one cluster: use its pool, typically a whole socket.
if (AllClusters(package_idx).NumWorkers() == 1) {
return Cluster(package_idx, 0);
if (AllClusters(pkg_idx).NumWorkers() == 1) {
return Cluster(pkg_idx, 0);
}
// One worker per cluster to maximize bandwidth availability.
return AllClusters(package_idx);
return AllClusters(pkg_idx);
}
private:
class Package {
public:
Package() = default; // for vector
Package(const BoundedTopology& topology, size_t package_idx,
Package(const BoundedTopology& topology, size_t pkg_idx,
size_t max_workers_per_package, BoundedSlice lp_slice);
size_t NumClusters() const { return clusters_.size(); }
@ -321,6 +315,13 @@ class NestedPools {
}
return max_workers_per_cluster;
}
size_t TotalWorkers() const {
size_t total_workers = 0;
for (const PoolPtr& cluster : clusters_) {
total_workers += cluster->NumWorkers();
}
return total_workers;
}
hwy::ThreadPool& AllClusters() { return *all_clusters_; }
hwy::ThreadPool& Cluster(size_t cluster_idx) {
@ -365,32 +366,34 @@ class NestedPools {
// functions below.
class IndexRangePartition {
public:
IndexRangePartition() = default; // for MMPartitions
IndexRangePartition(const IndexRange& range, const size_t task_size)
: range_(range), task_size_(task_size) {
const size_t num = range.Num();
: range_(range), task_size_(static_cast<uint32_t>(task_size)) {
const uint32_t num = static_cast<uint32_t>(range.Num());
HWY_DASSERT(task_size_ != 0);
num_tasks_ = hwy::DivCeil(num, task_size_);
HWY_DASSERT(num_tasks_ != 0);
if constexpr (HWY_IS_DEBUG_BUILD) {
const size_t handled = num_tasks_ * task_size_;
const uint32_t handled = num_tasks_ * task_size_;
// The last task may extend beyond items, but at most by (task_size_ - 1).
HWY_DASSERT(num <= handled && handled < num + task_size_);
(void)handled;
}
}
size_t TaskSize() const { return task_size_; }
size_t NumTasks() const { return num_tasks_; }
size_t TaskSize() const { return static_cast<size_t>(task_size_); }
size_t NumTasks() const { return static_cast<size_t>(num_tasks_); }
IndexRange Range(size_t task_idx) const {
HWY_DASSERT(task_idx < NumTasks());
return MakeIndexRange(range_.begin() + task_idx * task_size_, range_.end(),
task_size_);
return MakeIndexRange(range_.begin() + task_idx * TaskSize(), range_.end(),
TaskSize());
}
private:
IndexRange range_;
size_t task_size_;
size_t num_tasks_;
uint32_t task_size_;
uint32_t num_tasks_;
};
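A brief usage sketch of the partition API (assuming, as the debug comment above implies, that MakeIndexRange clamps the final task to range_.end()):

const gcpp::IndexRange rows(0, 10);
const gcpp::IndexRangePartition partition(rows, /*task_size=*/4);
// hwy::DivCeil(10, 4) == 3 tasks: [0, 4), [4, 8) and the clamped [8, 10).
for (size_t task_idx = 0; task_idx < partition.NumTasks(); ++task_idx) {
  const gcpp::IndexRange r = partition.Range(task_idx);
  // Process indices [r.begin(), r.end()).
}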
// Starts with `max_size` and rounds DOWN to a multiple of `size_multiple`
@ -455,33 +458,6 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
});
}
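Both ParallelizeTwoRanges and the removed ParallelizeThreeRanges below flatten the Cartesian product of tasks into a single pool.Run and recover the per-partition indices with hwy::Divisor. A minimal sketch of that decomposition for two partitions, written as a serial loop (the real helper dispatches via hwy::ThreadPool and also forwards the worker thread index):

void ForEachRangePair(const gcpp::IndexRangePartition& get1,
                      const gcpp::IndexRangePartition& get2) {
  const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
  const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
  for (size_t task = 0; task < num_tasks; ++task) {
    const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
    const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
    const gcpp::IndexRange range1 = get1.Range(idx1);
    const gcpp::IndexRange range2 = get2.Range(idx2);
    // func(range1, range2, thread) would be invoked here.
    (void)range1;
    (void)range2;
  }
}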
// As above, for three ranges.
template <class Func>
void ParallelizeThreeRanges(const IndexRangePartition& get1,
const IndexRangePartition& get2,
const IndexRangePartition& get3,
hwy::ThreadPool& pool, const Func& func) {
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
const size_t num12 = get1.NumTasks() * get2.NumTasks();
const hwy::Divisor div12(static_cast<uint32_t>(num12));
const size_t num_tasks = num12 * get3.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
HWY_DASSERT(task < (uint64_t{1} << 32));
const size_t idx3 = div12.Divide(static_cast<uint32_t>(task));
const size_t task12 = div12.Remainder(static_cast<uint32_t>(task));
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task12));
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task12));
HWY_DASSERT(idx1 < get1.NumTasks());
HWY_DASSERT(idx2 < get2.NumTasks());
HWY_DASSERT(idx3 < get3.NumTasks());
const IndexRange range1 = get1.Range(idx1);
const IndexRange range2 = get2.Range(idx2);
const IndexRange range3 = get3.Range(idx3);
func(range1, range2, range3, thread);
});
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_

View File

@ -138,6 +138,13 @@ TEST(ThreadingTest, TestMaxSizePartition) {
HWY_ASSERT(partition.TaskSize() == 55);
HWY_ASSERT(partition.NumTasks() == 2);
}
// `size_multiple` almost as large as range: imbalanced
{
const IndexRangePartition partition =
MaxSizePartition(IndexRange(0, 6), 6, 4);
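// max_size 6 rounds down to a multiple of 4 => task_size 4; the two tasks
// are the imbalanced [0, 4) and [4, 6).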
HWY_ASSERT(partition.TaskSize() == 4);
HWY_ASSERT(partition.NumTasks() == 2);
}
// Small `max_size`: small tasks
{
const IndexRangePartition partition = MaxSizePartition(range, 2, 1);
@ -244,97 +251,5 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) {
}
}
TEST(ThreadingTest, TestParallelizeThreeRanges) {
// Named according to number of tasks.
const IndexRangePartition partition3 =
StaticPartition(IndexRange(0, 8), 3, 1); // [0, 3) [3, 6) [6, 8)
HWY_ASSERT(partition3.NumTasks() == 3);
const IndexRangePartition partition2 =
MaxSizePartition(IndexRange(10, 30), 10, 10); // [10, 20), [20, 30)
HWY_ASSERT(partition2.NumTasks() == 2);
const IndexRangePartition partition4 =
MaxSizePartition(IndexRange(100, 500), 100, 100); // 100, 200, 300, 400
HWY_ASSERT(partition4.NumTasks() == 4);
const auto check_ranges = [&](const IndexRange& range3,
const IndexRange& range2,
const IndexRange& range4) {
HWY_ASSERT(range3.begin() == 0 || range3.begin() == 3 ||
range3.begin() == 6);
HWY_ASSERT(range2.begin() == 10 || range2.begin() == 20);
HWY_ASSERT(range4.begin() % 100 == 0);
};
hwy::ThreadPool null_pool(0);
// All 6 permutations of the three ranges to test the Remainder() logic:
// 3, 2, 4
{
size_t calls = 0;
ParallelizeThreeRanges(
partition3, partition2, partition4, null_pool,
[&](IndexRange range3, IndexRange range2, IndexRange range4, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 3, 4, 2
{
size_t calls = 0;
ParallelizeThreeRanges(
partition3, partition4, partition2, null_pool,
[&](IndexRange range3, IndexRange range4, IndexRange range2, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 4, 2, 3
{
size_t calls = 0;
ParallelizeThreeRanges(
partition4, partition2, partition3, null_pool,
[&](IndexRange range4, IndexRange range2, IndexRange range3, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 4, 3, 2
{
size_t calls = 0;
ParallelizeThreeRanges(
partition4, partition3, partition2, null_pool,
[&](IndexRange range4, IndexRange range3, IndexRange range2, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 2, 3, 4
{
size_t calls = 0;
ParallelizeThreeRanges(
partition2, partition3, partition4, null_pool,
[&](IndexRange range2, IndexRange range3, IndexRange range4, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
// 2, 4, 3
{
size_t calls = 0;
ParallelizeThreeRanges(
partition2, partition4, partition3, null_pool,
[&](IndexRange range2, IndexRange range4, IndexRange range3, size_t) {
++calls;
check_ranges(range3, range2, range4);
});
HWY_ASSERT(calls == 3 * 2 * 4);
}
}
} // namespace
} // namespace gcpp