mirror of https://github.com/google/gemma.cpp.git
Infra improvements:
allocator: support mmap, fixed Bind, add padding
bench_matmul: Add PreventElision
BUILD: add ops_test build target
matmul.h: move ConstMat here; dynamic alloc of MatMulEnv
matmul_test: remove benchmarking, replace fprintf with HWY_WARN
threading.cc: support splitting large clusters (disabled); package_idx->pkg_idx, smaller IndexRangePartition

PiperOrigin-RevId: 717512274
This commit is contained in:
parent
493688f6f1
commit
c4398fc72d
20 BUILD.bazel
@@ -76,6 +76,12 @@ cc_test(
     ],
 )
 
+# For building all tests in one command, so we can test several.
+test_suite(
+    name = "ops_tests",
+    tags = ["ops_tests"],
+)
+
 cc_library(
     name = "ops",
     hdrs = [

@@ -110,7 +116,7 @@ cc_test(
     srcs = ["ops/dot_test.cc"],
     local_defines = ["HWY_IS_TEST"],
     # for test_suite.
-    tags = ["hwy_ops_test"],
+    tags = ["ops_tests"],
     deps = [
         ":allocator",
         ":ops",

@@ -135,7 +141,7 @@ cc_test(
     srcs = ["ops/ops_test.cc"],
     local_defines = ["HWY_IS_TEST"],
     # for test_suite.
-    tags = ["hwy_ops_test"],
+    tags = ["ops_tests"],
    deps = [
         ":allocator",
         ":common",

@@ -157,7 +163,7 @@ cc_test(
     srcs = ["ops/gemma_matvec_test.cc"],
     local_defines = ["HWY_IS_TEST"],
     # for test_suite.
-    tags = ["hwy_ops_test"],
+    tags = ["ops_tests"],
     deps = [
         ":ops",
         "@googletest//:gtest_main",  # buildcleaner: keep

@@ -175,7 +181,7 @@ cc_test(
     srcs = ["ops/matmul_unit_test.cc"],
     local_defines = ["HWY_IS_TEST"],
     # for test_suite.
-    tags = ["hwy_ops_test"],
+    tags = ["ops_tests"],
     deps = [
         ":allocator",
         ":basics",

@@ -195,7 +201,7 @@ cc_test(
     srcs = ["ops/matmul_test.cc"],
     local_defines = ["HWY_IS_TEST"],
     # for test_suite.
-    tags = ["hwy_ops_test"],
+    tags = ["ops_tests"],
     deps = [
         ":allocator",
         ":basics",

@@ -205,7 +211,6 @@ cc_test(
         "//compression:compress",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
-        "@highway//:nanobenchmark",
         "@highway//:thread_pool",
     ],
 )

@@ -217,7 +222,7 @@ cc_test(
     srcs = ["ops/bench_matmul.cc"],
     local_defines = ["HWY_IS_TEST"],
     # for test_suite.
-    tags = ["hwy_ops_test"],
+    tags = ["ops_tests"],
     deps = [
         ":allocator",
         ":basics",

@@ -228,6 +233,7 @@ cc_test(
         "@highway//:hwy",
         "@highway//:hwy_test_util",
         "@highway//:nanobenchmark",
+        "@highway//:profiler",
         "@highway//:thread_pool",
     ],
 )
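Usage note: a Bazel `test_suite` whose `tags` attribute is set collects the tests that carry those tags, so the new suite can presumably be run as a group with `bazel test :ops_tests`.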
@@ -309,13 +309,6 @@ decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
   }
 }
 
-template <typename T>
-ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
-  ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
-  mat.scale = m.scale();
-  return mat;
-}
-
 // MatStorageT adds the actual data storage to MatPtrT.
 // TODO: use Extents2D instead of rows and cols.
 template <typename MatT>

@@ -361,7 +354,7 @@ class MatStorageT : public MatPtrT<MatT> {
   }
 
  private:
-  hwy::AlignedFreeUniquePtr<MatT[]> data_;
+  AlignedPtr<MatT> data_;
 };
 
 // MatStorage allows heterogeneous tensors to be stored in a single vector.
@@ -273,11 +273,11 @@ struct PackedSpan {
   // Ensures callers can read or write `num_accessible` elements starting at
   // `packed_ofs`.
   void BoundsCheck(size_t packed_ofs, size_t num_accessible) const {
-    // For NUQ, there can be fewer Packed than the number of elements, hence
-    // check the compressed count and ensure we have that many.
-    const size_t required =
-        CompressedArrayElements<Packed>(packed_ofs + num_accessible);
     if constexpr (HWY_IS_DEBUG_BUILD) {
+      // For NUQ, there can be fewer Packed than the number of elements, hence
+      // check the compressed count and ensure we have that many.
+      const size_t required =
+          CompressedArrayElements<Packed>(packed_ofs + num_accessible);
       if (num < required) {
         HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed",
                   packed_ofs, num_accessible, required, num);
@@ -19,6 +19,7 @@
 #include <stddef.h>
 
 #include <cmath>
+#include <memory>  // std::unique_ptr
 
 #include "compression/shared.h"  // BF16
 #include "gemma/configs.h"

@@ -63,7 +64,8 @@ struct Activations {
   // Rope
   RowVectorBatch<float> inv_timescale;
 
-  MatMulEnv env;
+  // Dynamic because no default ctor and only initialized in `Allocate`.
+  std::unique_ptr<MatMulEnv> env;
 
   PostQKType post_qk = PostQKType::Rope;
   // And the config.

@@ -122,7 +124,7 @@ struct Activations {
 
     inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
 
-    env = MatMulEnv(pools);
+    env = std::make_unique<MatMulEnv>(pools);
   }
 };
 
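The member change above follows a common C++ pattern: a type without a default constructor cannot be a plain member of a default-constructible struct, so it is held behind `std::unique_ptr` and constructed once its argument is available. A minimal sketch of the idea, using hypothetical `Env`/`Pools` stand-ins rather than the real `MatMulEnv`/`NestedPools`:

#include <memory>

struct Pools {};  // stand-in for NestedPools

struct Env {                                     // stand-in for MatMulEnv
  explicit Env(Pools& pools) : pools(&pools) {}  // no default constructor
  Pools* pools;
};

struct Activations {
  std::unique_ptr<Env> env;  // empty until Allocate() runs
  void Allocate(Pools& pools) { env = std::make_unique<Env>(pools); }
};

Call sites then dereference the pointer, which is why the hunks below change `activations.env` to `*activations.env` and `activations.env.Pool()` to `activations.env->Pool()`.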
@@ -81,7 +81,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
                                    const KVCaches& kv_caches) {
   PROFILER_ZONE("Gen.Griffin");
   KVCache& kv_cache = kv_caches[0];
-  hwy::ThreadPool& pool = activations.env.Pool();
+  hwy::ThreadPool& pool = activations.env->Pool();
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   const size_t model_dim = layer_weights->layer_config.model_dim;

@@ -252,7 +252,7 @@ class GemmaAttention {
     const size_t w1_rows = heads * layer_config_.QStride();
     w_q1.ShrinkRows(w1_rows);
     MatMul(pre_att_rms_out, w_q1,
-           /*add=*/nullptr, activations_.env, RowPtrFromBatch(activations_.q));
+           /*add=*/nullptr, *activations_.env, RowPtrFromBatch(activations_.q));
 
     if (is_mha_) {
       // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.

@@ -275,7 +275,7 @@ class GemmaAttention {
       RowPtrF kv_rows(kv, w_rows_kv_cols);
       kv_rows.SetStride(cache_pos_size_);
       MatMul(pre_att_rms_out, w_q2,
-             /*add=*/nullptr, activations_.env, kv_rows);
+             /*add=*/nullptr, *activations_.env, kv_rows);
     } else {
       // Proceed row by row because there will be wraparound.
       for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;

@@ -464,7 +464,7 @@ class GemmaAttention {
             : nullptr;
     MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
            ConstMatFromWeights(layer_weights_.att_weights), add,
-           activations_.env, RowPtrFromBatch(activations_.att_sums));
+           *activations_.env, RowPtrFromBatch(activations_.att_sums));
   }
 
  public:

@@ -514,7 +514,7 @@ class GemmaAttention {
         layer_weights_(*layer_weights),
         div_seq_len_(div_seq_len),
         kv_caches_(kv_caches),
-        pool_(activations.env.Pool()) {
+        pool_(activations.env->Pool()) {
     HWY_DASSERT(num_queries_ <= kv_caches_.size());
     HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
                   "query heads must be a multiple of key-value heads");

@@ -587,7 +587,7 @@ class VitAttention {
     HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
     MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
            ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
-           layer_weights_.vit.qkv_einsum_b.data_scale1(), activations_.env,
+           layer_weights_.vit.qkv_einsum_b.data_scale1(), *activations_.env,
            RowPtrFromBatch(qkv));
   }
 

@@ -641,7 +641,7 @@ class VitAttention {
     auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
     auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
     auto att_sums = RowPtrFromBatch(activations_.att_sums);
-    MatMul(att_out, att_weights, bias, activations_.env, att_sums);
+    MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
   }
 
  public:

@@ -652,7 +652,7 @@ class VitAttention {
         activations_(activations),
         layer_weights_(*layer_weights),
         layer_config_(layer_weights->layer_config),
-        pool_(activations.env.Pool()) {}
+        pool_(activations.env->Pool()) {}
 
   HWY_INLINE void operator()() {
     ComputeQKV();

@@ -728,8 +728,8 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
   auto w_output = ConstMatFromWeights(layer_weights->linear_w);
 
   // Compute the hidden layer activations.
-  MatMul(x, w1, bias1, activations.env, hidden_activations);
-  MatMul(x, w2, bias2, activations.env, multiplier);
+  MatMul(x, w1, bias1, *activations.env, hidden_activations);
+  MatMul(x, w2, bias2, *activations.env, multiplier);
 
   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
   Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),

@@ -739,7 +739,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
   auto activations_mat = MakeConstMat(
       hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim));
 
-  MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
+  MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
 }
 
 // Same as FFWNoVit, but with different layer_weights members and no second

@@ -769,7 +769,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
   auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
 
   // Compute the hidden layer activations.
-  MatMul(x, w1, bias1, activations.env, hidden_activations);
+  MatMul(x, w1, bias1, *activations.env, hidden_activations);
 
   // Activation (Gelu), store in act.
   RowPtrF multiplier = RowPtrF(nullptr, 0);

@@ -780,7 +780,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
   auto activations_mat = MakeConstMat(
       hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim));
 
-  MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
+  MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
 }
 
 // `batch_idx` indicates which row of `x` to write to.

@@ -1063,7 +1063,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // MatMul(
   //   MatFromBatch(kVitSeqLen, image_patches),
   //   MatFromWeights(weights.vit_img_embedding_kernel),
-  //   weights.vit_img_embedding_bias.data_scale1(), activations.env,
+  //   weights.vit_img_embedding_bias.data_scale1(), *activations.env,
   //   RowPtrF(activations.x.All(), kVitModelDim));
   // However, MatMul currently requires that
   //   A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0

@@ -1073,7 +1073,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
     MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
               image_patches[i].get(),
               weights.vit_img_embedding_bias.data_scale1(),
-              activations.x.Batch(i), activations.env.Pool());
+              activations.x.Batch(i), activations.env->Pool());
   }
   // Add position embeddings.
   AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),

@@ -1108,7 +1108,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
   // Apply head embedding into image_tokens of size of the LLM kModelDim.
   MatMul(ConstMatFromBatch(num_tokens, activations.x),
          ConstMatFromWeights(weights.vit_img_head_kernel),
-         weights.vit_img_head_bias.data_scale1(), activations.env,
+         weights.vit_img_head_bias.data_scale1(), *activations.env,
          RowPtrFromBatch(image_tokens));
 }
 

@@ -1281,7 +1281,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   Activations prefill_activations(weights.weights_config);
   if (use_prefill_activations) {
     prefill_activations.Allocate(runtime_config.prefill_tbatch_size,
-                                 activations.env.Pools());
+                                 activations.env->Pools());
   }
   Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
           query_idx_start, weights,

@@ -1326,7 +1326,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
     // Compute logits from last layer activations.
     MatMul(ConstMatFromBatch(num_queries, activations.x),
            ConstMatFromWeights(weights.embedder_input_embedding),
-           /*add=*/nullptr, activations.env,
+           /*add=*/nullptr, *activations.env,
            RowPtrFromBatch(activations.logits));
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
@@ -16,7 +16,7 @@
 // Benchmark of large MatMul instances for which the MatMulSlow would be too
 // slow. This lacks a reference and is only useful for performance measurement.
 
-#include "hwy/detect_compiler_arch.h"
+#include "hwy/base.h"
 #ifndef HWY_DISABLED_TARGETS
 // Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
 // double-precision support.

@@ -30,7 +30,9 @@
 #include <stddef.h>
 #include <stdio.h>
 
+#include <algorithm>
 #include <memory>
+#include <vector>
 
 #include "compression/compress.h"
 #include "compression/shared.h"

@@ -38,8 +40,8 @@
 #include "util/allocator.h"
 #include "util/basics.h"
 #include "util/threading.h"
-#include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/nanobenchmark.h"
 #include "hwy/timer.h"
 
 // clang-format off

@@ -51,6 +53,7 @@
 // After highway.h
 #include "compression/compress-inl.h"
 #include "ops/matmul-inl.h"
+#include "hwy/profiler.h"  // also uses SIMD
 #include "hwy/tests/test_util-inl.h"
 
 HWY_BEFORE_NAMESPACE();

@@ -74,7 +77,8 @@ MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
       std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
   FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
   HWY_ASSERT(content);
-  const float scale = SfpStream::kMax / (mat->NumElements());
+  const float scale =
+      SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
   pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
     for (size_t c = 0; c < extents.cols; c++) {
       float f = static_cast<float>(r * extents.cols + c) * scale;

@@ -96,7 +100,8 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
   auto mat =
       std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
   FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
-  const float scale = SfpStream::kMax / (mat->NumElements());
+  const float scale =
+      SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
   pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
     for (size_t c = 0; c < extents.cols; c++) {
       float f = static_cast<float>(c * extents.rows + r) * scale;

@@ -111,52 +116,63 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
   return mat;
 }
 
-void PrintSpeed(const char* algo, const Extents2D& A_extents,
-                const Extents2D& B_extents, double elapsed) {
+void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
+                std::vector<double>& times) {
+  std::sort(times.begin(), times.end());
+  // Many measurements are with suboptimal configs, so report the best like
+  // bench_dnn, but also the ratio to the 3rd best.
+  const double elapsed = times[0];
+  const double ratio = times[2] / HWY_MAX(elapsed, 1E-6);
+
   const size_t num_b = B_extents.Area();
   // 2x because of FMA.
-  fprintf(stderr, "  %10s: %f seconds, %.1f GFLOPS.\n", algo,
-          elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
+  fprintf(stderr, "%.1f\t%.2f\n", 2 * 1E-9 * A_extents.rows * num_b / elapsed,
          ratio);
 }
 
 // Generates inputs and prints observed throughput of MatMul.
+// M = A rows, K = A cols, N = C cols.
 template <typename MatTA, typename MatTB = MatTA>
-void BenchMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
-                 MatMulEnv& env) {
+void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   hwy::ThreadPool& pool = env.Pool();
-  fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
-          rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
-          TypeName<MatTB>());
+  fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", M,
+          K, N, add, TypeName<MatTA>(), TypeName<MatTB>());
 
-  const Extents2D A_extents(rows_ac, cols_a_rows_b);
-  const Extents2D B_extents(cols_bc, cols_a_rows_b);  // already transposed
-  const Extents2D C_extents(rows_ac, cols_bc);
+  const Extents2D A_extents(M, K);
+  const Extents2D B_extents(N, K);  // already transposed
+  const Extents2D C_extents(M, N);
 
-  MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
-  MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
   RowVectorBatch<float> c_slow_batch(C_extents);
   RowVectorBatch<float> c_batch(C_extents);
-  HWY_ASSERT(a && b_trans);
 
   std::unique_ptr<MatStorageT<float>> add_storage;
   if (add) {
-    add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
+    add_storage = GenerateMat<float>(Extents2D(1, N), pool);
     HWY_ASSERT(add_storage);
     add_storage->set_scale(1.0f);
   }
 
+  MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
+  MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
+  HWY_ASSERT(a && b_trans);
   const auto A = ConstMatFromWeights(*a);
   const auto B = ConstMatFromWeights(*b_trans);
 
   const float* add_row = add ? add_storage->data_scale1() : nullptr;
   const RowPtrF C = RowPtrFromBatch(c_batch);
 
-  double min_elapsed = hwy::HighestValue<double>();
-  for (int rep = 0; rep < 3; ++rep) {
-    const double start_tiled = hwy::platform::Now();
+  std::vector<double> times;
+  times.reserve(20);
+  double result = 0.0;
+  for (;;) {
+    const double t0 = hwy::platform::Now();
     MatMul(A, B, add_row, env, C);
-    min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
+    times.push_back(hwy::platform::Now() - t0);
+    result += C.Row(0)[hwy::Unpredictable1()];
+    if (times.size() >= 20) break;
   }
-  PrintSpeed("MatMul", A_extents, B_extents, min_elapsed);
+  hwy::PreventElision(result);
+  PrintSpeed(A_extents, B_extents, times);
 }
 
 using F32 = float;

@@ -184,16 +200,15 @@ void BenchAllMatMul() {
     Allocator::Init(pools.Topology());
     MatMulEnv env(pools);
 
-    for (size_t batch_size : {1, /*4, 64,*/ 128}) {
-      BenchMatMul<F32, F32>(batch_size, 24576, 3072, /*add=*/false, env);
-      BenchMatMul<F32, F32>(batch_size, 3072, 24576, /*add=*/false, env);
-      BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, /*add=*/false, env);
-      BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, /*add=*/false, env);
-      BenchMatMul<F32, SFP>(batch_size, 24576, 3072, /*add=*/false, env);
-      BenchMatMul<F32, SFP>(batch_size, 3072, 24576, /*add=*/false, env);
+    for (size_t batch_size : {1, /* 4, 128,*/ 512}) {
+      constexpr bool kAdd = false;
+      BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
+      BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
     }
     pools.MaybeStopSpinning(use_spinning);
   }
 
+  PROFILER_PRINT_RESULTS();
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)
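The benchmark loop above relies on two Highway helpers from `hwy/nanobenchmark.h`: `hwy::Unpredictable1()` returns 1 in a way the optimizer cannot prove, and `hwy::PreventElision()` keeps a computed value "live" so the measured work is not dead-code-eliminated. A minimal sketch of the same pattern, with a placeholder `Work()` standing in for `MatMul`:

#include <vector>

#include "hwy/nanobenchmark.h"  // hwy::Unpredictable1, hwy::PreventElision
#include "hwy/timer.h"          // hwy::platform::Now

double Work() { return 1.0; }  // placeholder for the kernel being measured

void Measure() {
  std::vector<double> times;
  double result = 0.0;
  for (int rep = 0; rep < 20; ++rep) {
    const double t0 = hwy::platform::Now();
    // Scale by Unpredictable1() so the compiler cannot constant-fold the sum.
    result += Work() * hwy::Unpredictable1();
    times.push_back(hwy::platform::Now() - t0);
  }
  hwy::PreventElision(result);  // the loop body now has an observable effect
  // As in PrintSpeed above, sort `times` and report the smallest element.
}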
66 ops/matmul.h
@@ -19,6 +19,8 @@
 #include <stddef.h>
 
 // IWYU pragma: begin_exports
+#include "compression/compress.h"
+#include "util/allocator.h"
 #include "util/basics.h"
 #include "util/threading.h"
 #include "hwy/base.h"

@@ -81,6 +83,70 @@ class MatMulEnv {
   NestedPools* pools_;
 };
 
+// Used for the A and B arguments of `MatMul`, which are always const.
+// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
+// `ofs` required for compressed T.
+template <typename T>
+struct ConstMat {
+  ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0)
+      : ptr(ptr), extents(extents), ofs(ofs) {
+    HWY_DASSERT(ptr != nullptr);
+  }
+  // TODO: support stride for page alignment.
+  size_t Row(size_t r) const {
+    if constexpr (HWY_IS_DEBUG_BUILD) {
+      if (r >= extents.rows) {
+        HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
+      }
+    }
+    return ofs + extents.cols * r;
+  }
+
+  const Extents2D& Extents() const { return extents; }
+  size_t Stride() const { return extents.cols; }
+
+  // Shrinks the row-extent of this matrix view, i.e. reduces the view to a
+  // subrange of the original rows starting at row 0.
+  void ShrinkRows(size_t rows) {
+    HWY_ASSERT(rows <= extents.rows);
+    extents.rows = rows;
+  }
+
+  const T* HWY_RESTRICT ptr;
+  Extents2D extents;
+
+  // `scale` allows expanding the smaller range of `SfpStream` to the original
+  // values. MatFromWeights sets this from `MatPtr`.
+  float scale = 1.0f;
+
+  // Offset to add to `ptr`; separate because T=NuqStream does not support
+  // pointer arithmetic.
+  size_t ofs;
+};
+
+// For deducing T.
+template <typename T>
+ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
+                         size_t ofs = 0) {
+  return ConstMat<T>(ptr, extents, ofs);
+}
+
+// For A argument to MatMul (activations).
+template <typename T>
+ConstMat<T> ConstMatFromBatch(size_t batch_size,
+                              const RowVectorBatch<T>& row_vectors) {
+  HWY_DASSERT(batch_size <= row_vectors.BatchSize());
+  return MakeConstMat(const_cast<T*>(row_vectors.Const()),
+                      Extents2D(batch_size, row_vectors.Cols()));
+}
+
+template <typename T>
+ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
+  ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
+  mat.scale = m.scale();
+  return mat;
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
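To make the indexing convention of `ConstMat` concrete: `Row()` returns an element index rather than a pointer, so packed element types that do not support pointer arithmetic can still be addressed as `ofs + cols * r`. A simplified stand-alone analogue (hypothetical `ConstView`/`Extents2DSketch`, not the real gemma.cpp types):

#include <cstddef>

struct Extents2DSketch {  // hypothetical, mirrors rows/cols of Extents2D
  size_t rows, cols;
};

template <typename T>
struct ConstView {  // hypothetical analogue of ConstMat<T>
  const T* ptr;
  Extents2DSketch extents;
  size_t ofs = 0;  // offset in elements, applied inside Row()
  size_t Row(size_t r) const { return ofs + extents.cols * r; }
};

// For plain (non-packed) T, element (r, c) of the row-major view is simply:
template <typename T>
T At(const ConstView<T>& m, size_t r, size_t c) {
  return m.ptr[m.Row(r) + c];
}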
@@ -39,7 +39,6 @@
 #include "util/threading.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/timer.h"
 
 // clang-format off
 #undef HWY_TARGET_INCLUDE

@@ -55,7 +54,7 @@
 
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
-// For running TestBatchSizes only once. Defined within HWY_ONCE.
+// For running TestTiny only once. Defined within HWY_ONCE.
 extern int64_t first_target;
 
 namespace HWY_NAMESPACE {

@@ -144,10 +143,10 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
   const hn::ScalableTag<float> df;
   const size_t num_a = A.extents.Area();
   const size_t num_b = B.extents.Area();
-  HWY_ASSERT(num_a % hn::Lanes(df) == 0);  // for DecompressAndZeroPad
-  HWY_ASSERT(num_b % hn::Lanes(df) == 0);  // for DecompressAndZeroPad
-  FloatPtr a = hwy::AllocateAligned<float>(num_a);
-  FloatPtr b_trans = hwy::AllocateAligned<float>(num_b);
+  const size_t N = hn::Lanes(df);
+  // Round up for DecompressAndZeroPad.
+  FloatPtr a = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_a, N));
+  FloatPtr b_trans = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_b, N));
   HWY_ASSERT(a && b_trans);
   HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
   DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a);

@@ -164,13 +163,11 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
   double tolerance = 8 * norm * eps_f32;
   // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
   // tolerance there.
-  if (IsF32<MatTA>() && IsF32<MatTB>()) {
+  if (IsF32<MatTA>() && !IsF32<MatTB>()) {
     tolerance += 4 * max_abs * eps_bf16;
   }
-  EXPECT_GE(tolerance, 1E-4);
-  if (tolerance > 4.0) {
-    fprintf(stderr, "WARN: high tolerance %f norm %f maxabs %f\n", tolerance,
-            norm, max_abs);
+  if (tolerance > 8.0) {
+    HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
   }
 
   for (size_t r = 0; r < A.extents.rows; r++) {

@@ -182,11 +179,10 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
 
       if (!(expected_value - tolerance <= actual_value &&
             actual_value <= expected_value + tolerance)) {
-        fprintf(stderr,
+        HWY_ABORT(
             "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
             "tolerance %f\n",
             r, c, expected_value, actual_value, norm, max_abs, tolerance);
-        return;
       }
     }
   }

@@ -217,7 +213,7 @@ HWY_INLINE void MatMulSlow(const ConstMat<MatTA> A, const ConstMat<MatTB> B,
       get_row_c, all_packages,
       [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
         hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
-        const size_t multiple = Allocator::Alignment() / sizeof(MatTB);
+        const size_t multiple = Allocator::QuantumBytes() / sizeof(MatTB);
         const IndexRangePartition get_col_c =
             StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
         ParallelizeOneRange(

@@ -248,7 +244,6 @@ template <typename MatTA, typename MatTB = MatTA>
 void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                 MatMulEnv& env) {
   hwy::ThreadPool& pool = env.Pool();
-  const bool want_bench = cols_bc > 2000;  // avoid spam for small matrices
   fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
           rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
           TypeName<MatTB>());

@@ -276,32 +271,17 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch);
   const RowPtrF C = RowPtrFromBatch(c_batch);
 
-  const double start_slow = hwy::platform::Now();
   MatMulSlow(A, B, add_row, env, C_slow);
-  if (want_bench) {
-    PrintSpeed("MatMulSlow", A_extents, B_extents,
-               hwy::platform::Now() - start_slow);
-  }
-
-  double min_elapsed = hwy::HighestValue<double>();
-  for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) {
-    const double start_tiled = hwy::platform::Now();
-    MatMul(A, B, add_row, env, C);
-    min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
-  }
-  if (want_bench) {
-    PrintSpeed("MatMul", A_extents, B_extents, min_elapsed);
-  }
+  MatMul(A, B, add_row, env, C);
 
   AssertClose(A, B, C_slow, C);
 }
 
 using F32 = float;
 using SFP = SfpStream;
 
-// Sweep batch_size for a single input type and Highway target, to verify the
-// row partitioning.
-void TestBatchSizes() {
+// Sweep all dimensions for a single input type and Highway target, to verify
+// the remainder handling.
+void TestTiny() {
   if (first_target == 0) first_target = HWY_TARGET;
   if (HWY_TARGET != first_target) return;

@@ -315,7 +295,7 @@ void TestBatchSizes() {
     // If less than the limit, we have already tested all num_packages.
     if (pools.Topology().FullTopology().packages.size() < max_packages) break;
 #endif
-    fprintf(stderr, "TestBatchSizes %zu: %s %s\n", max_packages,
+    fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
            pools.TopologyString(), pools.PinString());
 
     Tristate use_spinning = Tristate::kDefault;

@@ -405,7 +385,7 @@ HWY_AFTER_NAMESPACE();
 namespace gcpp {
 int64_t first_target = 0;  // none run yet
 HWY_BEFORE_TEST(MatMulTest);
-HWY_EXPORT_AND_TEST_P(MatMulTest, TestBatchSizes);
+HWY_EXPORT_AND_TEST_P(MatMulTest, TestTiny);
 HWY_EXPORT_AND_TEST_P(MatMulTest, TestAllMatMul);
 HWY_AFTER_TEST();
@@ -17,45 +17,160 @@
 #include <stdio.h>
 
+#include <atomic>
+#include <cstdio>
 #include <vector>
 
 #include "util/basics.h"  // MaybeCheckInitialized
+#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"
+#include "hwy/per_target.h"  // VectorBytes
 
-#if GEMMA_NUMA
-#if HWY_OS_WIN
-#ifndef NOMINMAX
-#define NOMINMAX
+// To avoid a dependency on libnuma, use syscalls directly. We require six
+// arguments, which has been supported by glibc since around 2010.
+#if defined(__GLIBC__) && defined(__GLIBC_PREREQ)
+#if HWY_OS_LINUX && __GLIBC_PREREQ(2, 11)
+#define GEMMA_LINUX_SYSCALL6
 #endif
-#ifndef WIN32_LEAN_AND_MEAN
-#define WIN32_LEAN_AND_MEAN
 #endif
-#include <windows.h>
-#elif HWY_OS_LINUX
+
+#ifndef GEMMA_BIND  // allow override
+#if defined(GEMMA_LINUX_SYSCALL6) && !defined(__ANDROID_API__)
+#define GEMMA_BIND 1
+#else
+#define GEMMA_BIND 0
+#endif
+#endif  // GEMMA_BIND
+
+#if GEMMA_BIND && HWY_OS_LINUX
+// `move_pages` requires anonymous/private mappings, hence mmap.
+#include <sys/mman.h>
 #include <sys/syscall.h>
 
 #include <cerrno>
-#endif  // HWY_OS_*
-#endif  // GEMMA_NUMA
+#endif  // GEMMA_BIND && HWY_OS_LINUX
 
 namespace gcpp {
+namespace {
 
-/*static*/ size_t Allocator::bytes_per_page_;
-/*static*/ bool Allocator::use_numa_;
-/*static*/ size_t Allocator::alignment_;
+size_t DetectLineBytes() {
+  if (const hwy::Cache* caches = hwy::DataCaches()) {
+    // Might not have an L3.
+    return HWY_MAX(caches[2].bytes_per_line, caches[3].bytes_per_line);
+  } else {
+    return HWY_ALIGNMENT;
+  }
+}
 
-/*static*/ size_t Allocator::DetectPageSize() {
-#if HWY_OS_WIN
-  SYSTEM_INFO sys_info;
-  GetSystemInfo(&sys_info);
-  return sys_info.dwPageSize;
-#elif HWY_OS_LINUX
-  return sysconf(_SC_PAGESIZE);
+size_t DetectPageSize() {
+#if HWY_OS_LINUX
+  size_t page_bytes = static_cast<size_t>(sysconf(_SC_PAGESIZE));
+  HWY_ASSERT(page_bytes <= (4 << 20));
+  return page_bytes;
 #else
   return 0;
 #endif
 }
 
-#if GEMMA_NUMA && HWY_OS_LINUX
+}  // namespace
+
+static size_t line_bytes_;
+static size_t vector_bytes_;
+static size_t quantum_bytes_;
+static size_t l1_bytes_;
+static size_t l2_bytes_;
+static bool should_bind_ = false;
+
+size_t Allocator::LineBytes() { return line_bytes_; }
+size_t Allocator::VectorBytes() { return vector_bytes_; }
+size_t Allocator::QuantumBytes() { return quantum_bytes_; }
+size_t Allocator::L1Bytes() { return l1_bytes_; }
+size_t Allocator::L2Bytes() { return l2_bytes_; }
+bool Allocator::ShouldBind() { return should_bind_; }
+
+void Allocator::Init(const BoundedTopology& topology) {
+  line_bytes_ = DetectLineBytes();
+  vector_bytes_ = hwy::VectorBytes();
+  quantum_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);  // may overwrite below
+
+  if (const hwy::Cache* caches = hwy::DataCaches()) {
+    l1_bytes_ = caches[1].size_kib << 10;
+    l2_bytes_ = caches[2].size_kib << 10;
+  } else {  // Unknown, make reasonable assumptions.
+    const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0);
+    l1_bytes_ = 32 << 10;
+    l2_bytes_ = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) << 10;
+  }
+
+  // Prerequisites for binding:
+  // - supported by the OS (currently Linux only),
+  // - the page size is known and 'reasonably small', preferably less than
+  //   a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
+  // - we successfully detected topology and there are multiple nodes;
+  // - there are multiple packages, because we shard by package_idx.
+  if constexpr (GEMMA_BIND) {
+    const size_t page_bytes = DetectPageSize();
+    if ((page_bytes != 0 && page_bytes <= 16 * 1024) &&
+        topology.NumNodes() > 1 && topology.NumPackages() > 1) {
+      // Ensure pages meet the alignment requirements of `AllocBytes`.
+      HWY_ASSERT(page_bytes >= quantum_bytes_);
+      quantum_bytes_ = page_bytes;
+      should_bind_ = true;
+    }
+  }
+}
+
+Allocator::PtrAndDeleter Allocator::AllocBytes(size_t bytes) {
+  // If we are not binding, the Highway allocator is cheaper than `mmap`, and
+  // defends against 2K aliasing.
+  if (!should_bind_) {
+    // Perf warning if Highway's alignment is less than we want.
+    if (HWY_ALIGNMENT < QuantumBytes()) {
+      HWY_WARN(
+          "HWY_ALIGNMENT %d < QuantumBytes %zu: either vector or cache lines "
+          "are huge, enable GEMMA_BIND to avoid this warning.",
+          HWY_ALIGNMENT, QuantumBytes());
+    }
+    auto p = hwy::AllocateAligned<uint8_t>(bytes);
+    // The `hwy::AlignedFreeUniquePtr` deleter is unfortunately specific to the
+    // alignment scheme in aligned_allocator.cc and does not work for
+    // already-aligned pointers as returned by `mmap`, hence we wrap the Highway
+    // pointer in our own deleter.
+    auto call_free = [](void* ptr, size_t /*bytes*/) {
+      hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
+    };
+    return PtrAndDeleter{p.release(), Deleter(call_free, bytes)};
+  }
+
+  // Binding, or large vector/cache line size: use platform-specific allocator.
+
+#if HWY_OS_LINUX && !defined(__ANDROID_API__)
+  // `move_pages` is documented to require an anonymous/private mapping or
+  // `MAP_SHARED`. A normal allocation might not suffice, so we use `mmap`.
+  // `Init` verified that the page size is a multiple of `QuantumBytes()`.
+  const int prot = PROT_READ | PROT_WRITE;
+  const int flags = MAP_ANONYMOUS | MAP_PRIVATE;
+  const int fd = -1;
+  // Encourage transparent hugepages by rounding up to a multiple of 2 MiB.
+  bytes = hwy::RoundUpTo(bytes, 2ull << 20);
+  void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
+  if (p == MAP_FAILED) p = nullptr;
+  const auto call_munmap = [](void* ptr, size_t bytes) {
+    const int ret = munmap(ptr, bytes);
+    HWY_ASSERT(ret == 0);
+  };
+  return PtrAndDeleter{p, Deleter(call_munmap, bytes)};
+#elif HWY_OS_WIN
+  const auto call_free = [](void* ptr, void*) { _aligned_free(ptr); };
+  const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
+  return PtrAndDeleter{_aligned_malloc(bytes, alignment),
+                       Deleter(call_free, bytes)};
+#else
+  return PtrAndDeleter{nullptr, Deleter(nullptr, 0)};
+#endif
+}
+
+#if GEMMA_BIND && HWY_OS_LINUX
 
 using Ret = long;          // NOLINT(runtime/int)
 using UL = unsigned long;  // NOLINT(runtime/int)
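For readers who have not used anonymous mappings before, the Linux branch above boils down to the following self-contained pattern (a sketch, not the gemma.cpp `Deleter`/`PtrAndDeleter` types): map anonymous private memory, round the size up to 2 MiB to encourage transparent hugepages, and release it with `munmap`, which needs the byte count that the deleter must therefore remember.

#include <sys/mman.h>

#include <cstddef>
#include <cstdint>
#include <memory>

// Deleter that remembers the mapped size so munmap gets the right length.
struct Unmapper {
  size_t bytes = 0;
  void operator()(uint8_t* p) const {
    if (p != nullptr) munmap(p, bytes);
  }
};
using MappedPtr = std::unique_ptr<uint8_t[], Unmapper>;

MappedPtr AllocateMapped(size_t bytes) {
  constexpr size_t kHuge = size_t{2} << 20;     // 2 MiB
  bytes = (bytes + kHuge - 1) & ~(kHuge - 1);   // round up
  void* p = mmap(nullptr, bytes, PROT_READ | PROT_WRITE,
                 MAP_ANONYMOUS | MAP_PRIVATE, /*fd=*/-1, /*offset=*/0);
  if (p == MAP_FAILED) return MappedPtr(nullptr, Unmapper{0});
  return MappedPtr(static_cast<uint8_t*>(p), Unmapper{bytes});
}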
@ -76,90 +191,91 @@ struct SyscallWrappers {
|
||||||
MaybeCheckInitialized(status, count * sizeof(int));
|
MaybeCheckInitialized(status, count * sizeof(int));
|
||||||
return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
|
return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr,
|
||||||
|
unsigned flags) {
|
||||||
|
return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns the number of pages that are currently busy (hence not yet moved),
|
||||||
|
// and warns if there are any other reasons for not moving a page. Note that
|
||||||
|
// `move_pages` can return 0 regardless of whether all pages were moved.
|
||||||
size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
|
size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
|
||||||
const int* status) {
|
const int* status) {
|
||||||
// Return value 0 does not actually guarantee all pages were moved.
|
|
||||||
size_t num_busy = 0;
|
size_t num_busy = 0;
|
||||||
for (size_t i = 0; i < num_pages; ++i) {
|
for (size_t i = 0; i < num_pages; ++i) {
|
||||||
if (status[i] == -EBUSY) {
|
if (status[i] == -EBUSY) {
|
||||||
++num_busy;
|
++num_busy;
|
||||||
// Touch
|
|
||||||
hwy::ZeroBytes(pages[i], 8);
|
|
||||||
} else if (status[i] != static_cast<int>(node)) {
|
} else if (status[i] != static_cast<int>(node)) {
|
||||||
fprintf(stderr, "Error %d moving pages[%zu]=%p to node %zu (errno %d)\n",
|
static std::atomic_flag first = ATOMIC_FLAG_INIT;
|
||||||
status[i], i, pages[i], node, errno);
|
if (!first.test_and_set()) {
|
||||||
|
HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).",
|
||||||
|
status[i], i, pages[i], node, errno);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return num_busy;
|
return num_busy;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempts to move(!) memory to the given NUMA node, typically obtained from
|
bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) {
|
||||||
// `BoundedTopology::GetCluster(package_idx, cluster_idx).node`. Using `mbind`
|
HWY_DASSERT(should_bind_);
|
||||||
// directly is easier than calling libnuma's `numa_move_pages`, which requires
|
|
||||||
// an array of pages. Note that `numa_tonode_memory` is insufficient because
|
|
||||||
// it does not specify the `MPOL_MF_MOVE` flag, so it only sets the policy,
|
|
||||||
// which means it would have to be called before pages are faulted in, but
|
|
||||||
// `aligned_allocator.h` modifies the first bytes for its bookkeeping.
|
|
||||||
// May overwrite some of the memory with zeros.
|
|
||||||
void BindMemory(void* ptr, size_t bytes, size_t node) {
|
|
||||||
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
|
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
|
||||||
|
|
||||||
|
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||||
|
// Ensure the requested `node` is allowed.
|
||||||
|
UL nodes[kMaxNodes / 64] = {0};
|
||||||
|
const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED
|
||||||
|
HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes,
|
||||||
|
nullptr, flags) == 0);
|
||||||
|
HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64)));
|
||||||
|
}
|
||||||
|
|
||||||
// Avoid mbind because it does not report why it failed, which is most likely
|
// Avoid mbind because it does not report why it failed, which is most likely
|
||||||
// because pages are busy, in which case we want to know which.
|
// because pages are busy, in which case we want to know which.
|
||||||
#if 0
|
|
||||||
// nodemask with only the given node set.
|
|
||||||
UL nodes[hwy::DivCeil(kMaxNodes, ULBits)] = {};
|
|
||||||
-  nodes[node / ULBits] = 1ULL << (node % ULBits);

-  const int mode = 2;  // MPOL_BIND
-  const unsigned flags = 3;  // MPOL_MF_MOVE | MPOL_MF_STRICT
-  const int ret =
-      SyscallWrappers::mbind(ptr, bytes, mode, nodes, kMaxNodes, flags);
-  if (ret != 0) {
-    fprintf(stderr, "Failed to bind %p %zu to node %zu (errno %d)\n", ptr,
-            bytes, node, errno);
-  }
-#elif 1
+  // `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set.
   const unsigned flags = 2;  // MPOL_MF_MOVE
-  const size_t bytes_per_page = static_cast<size_t>(sysconf(_SC_PAGESIZE));
-  HWY_ASSERT(bytes % bytes_per_page == 0);
-  const size_t num_pages = bytes / bytes_per_page;
+  HWY_ASSERT(bytes % quantum_bytes_ == 0);
+  const size_t num_pages = bytes / quantum_bytes_;
   std::vector<void*> pages;
   pages.reserve(num_pages);
   for (size_t i = 0; i < num_pages; ++i) {
-    pages.push_back(static_cast<uint8_t*>(ptr) + i * bytes_per_page);
+    pages.push_back(static_cast<uint8_t*>(ptr) + i * quantum_bytes_);
+    // Ensure the page is faulted in to prevent `move_pages` from failing,
+    // because freshly allocated pages may be mapped to a shared 'zero page'.
+    hwy::ZeroBytes(pages.back(), 8);
   }
   std::vector<int> nodes(num_pages, node);
   std::vector<int> status(num_pages, static_cast<int>(kMaxNodes));

   Ret ret = SyscallWrappers::move_pages(
       /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
-  size_t num_busy =
-      CountBusyPages(num_pages, node, pages.data(), status.data());
-  if (num_busy != 0) {
-    // Try again
-    ret = SyscallWrappers::move_pages(
-        /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
-    const size_t num_busy_before = num_busy;
-    num_busy = CountBusyPages(num_pages, node, pages.data(), status.data());
-    fprintf(
-        stderr,
-        "second try still %zu busy, was %zu. 2nd ret %d status %d %d %d %d\n",
-        num_busy, num_busy_before, static_cast<int>(ret), status[0], status[1],
-        status[2], status[3]);
+  if (ret < 0) {
+    HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr,
+             bytes, node, errno, status[0]);
+    return false;
   }

-  if (ret < 0) {
-    fprintf(stderr,
-            "Failed to bind %p %zu to node %zu (errno %d) status %d %d\n", ptr,
-            bytes, node, errno, status[0], status[1]);
+  const size_t num_busy =
+      CountBusyPages(num_pages, node, pages.data(), status.data());
+  if (HWY_UNLIKELY(num_busy != 0)) {
+    // Trying again is usually enough to succeed.
+    usleep(5);  // NOLINT(runtime/sleep)
+    (void)SyscallWrappers::move_pages(
+        /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
+    const size_t still_busy =
+        CountBusyPages(num_pages, node, pages.data(), status.data());
+    if (HWY_UNLIKELY(still_busy != 0)) {
+      HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.",
+               still_busy, num_busy);
+    }
   }
-#endif
+  return true;
 }

 #else
-// TODO: support other OSes.
-void BindMemory(void*, size_t, size_t) {}
-#endif  // GEMMA_NUMA && HWY_OS_LINUX
+bool Allocator::BindMemory(void*, size_t, size_t) { return false; }
+#endif  // GEMMA_BIND && HWY_OS_LINUX

 }  // namespace gcpp
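For readers unfamiliar with the binding path above, the following is a minimal standalone sketch of the same strategy (fault pages in, then ask the kernel to migrate them), written against the raw Linux syscall rather than the repository's SyscallWrappers; the function name MovePagesToNode and the direct syscall(SYS_move_pages, ...) call are illustrative assumptions, not code from this commit.

// Hedged, standalone sketch; Linux-only, not the repository's API.
#include <sys/syscall.h>
#include <unistd.h>

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// MPOL_MF_MOVE only; MPOL_MF_MOVE_ALL needs CAP_SYS_NICE, as noted above.
static constexpr unsigned kMoveFlags = 2;

// Returns the number of pages whose status does not report `node`.
inline size_t MovePagesToNode(void* ptr, size_t bytes, int node) {
  const size_t page_bytes = static_cast<size_t>(sysconf(_SC_PAGESIZE));
  const size_t num_pages = bytes / page_bytes;
  std::vector<void*> pages(num_pages);
  std::vector<int> nodes(num_pages, node);
  std::vector<int> status(num_pages, -1);
  for (size_t i = 0; i < num_pages; ++i) {
    pages[i] = static_cast<uint8_t*>(ptr) + i * page_bytes;
    std::memset(pages[i], 0, 1);  // fault in; avoids the shared zero page
  }
  (void)syscall(SYS_move_pages, /*pid=*/0, num_pages, pages.data(),
                nodes.data(), status.data(), kMoveFlags);
  size_t busy = 0;
  for (size_t i = 0; i < num_pages; ++i) {
    busy += (status[i] != node);
  }
  return busy;
}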
280  util/allocator.h

@@ -19,114 +19,232 @@
 #include <stddef.h>
 #include <stdint.h>

-#include <cstdlib>  // std::aligned_alloc / _aligned_malloc
-
 // IWYU pragma: begin_exports
+#include <memory>
+
 #include "util/basics.h"
 #include "util/threading.h"
-#include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 // IWYU pragma: end_exports

-#ifndef GEMMA_NUMA
-// The check below requires two #if, hence start with 0 and redefine to 1.
-#define GEMMA_NUMA 0
-
-// To avoid a dependency on libnuma, use syscalls directly. We require six
-// arguments, which has been supported by glibc since around 2010.
-#if defined(__GLIBC__) && defined(__GLIBC_PREREQ)
-#if HWY_OS_LINUX && __GLIBC_PREREQ(2, 11)
-#undef GEMMA_NUMA
-#define GEMMA_NUMA 1
-#endif
-#endif
-
-#endif  // GEMMA_NUMA
+#include "hwy/aligned_allocator.h"

 namespace gcpp {

-using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
-
-template <typename T>
-ByteStorageT AllocateSizeof() {
-  return hwy::AllocateAligned<uint8_t>(sizeof(T));
-}
-
-// Stateful in order to know whether to bind to NUMA nodes. `Monostate` for
-// convenience - avoids passing around a reference.
-class Allocator {
+// Points to an adapter lambda that calls `FreeAlignedBytes` or `munmap`. The
+// `bytes` argument is required for the latter.
+using FreeFunc = void (*)(void* mem, size_t bytes);
+
+// Custom deleter for std::unique_ptr that calls `FreeFunc`.
+class Deleter {
  public:
-  static void Init(const BoundedTopology& topology) {
-    bytes_per_page_ = DetectPageSize();
-    HWY_ASSERT(bytes_per_page_ <= (4 << 20));
-
-    // NUMA only makes sense if:
-    // - the page size is known and 'reasonably small', preferably less than
-    //   a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
-    // - we successfully detected topology and there are multiple nodes;
-    // - there are multiple packages, because we shard by package_idx.
-    use_numa_ = (bytes_per_page_ != 0 && bytes_per_page_ <= 16 * 1024) &&
-                topology.NumNodes() > 1 && topology.NumPackages() > 1;
-    // TODO: remove once tensors are page-aligned.
-    use_numa_ = false;
-    fprintf(stderr, "Warning: disabling use_numa_\n");
-
-    alignment_ = use_numa_ ? bytes_per_page_ : HWY_ALIGNMENT;
-  }
-
-  static bool UseNUMA() { return use_numa_; }
-
-  // BindTensor requires row pointers and lengths be a multiple of this.
-  static size_t Alignment() { return alignment_; }
+  // `MatStorageT` requires this to be default-constructible.
+  Deleter() : free_func_(nullptr), bytes_(0) {}
+  Deleter(FreeFunc free_func, size_t bytes)
+      : free_func_(free_func), bytes_(bytes) {}

   template <typename T>
-  static hwy::AlignedFreeUniquePtr<T[]> Alloc(size_t num) {
-    // For non-NUMA, use the Highway allocator because it defends against 2k
-    // aliasing.
-    if (!use_numa_) return hwy::AllocateAligned<T>(num);
-
+  void operator()(T* p) const {
+    free_func_(p, bytes_);
+  }
+
+ private:
+  FreeFunc free_func_;
+  size_t bytes_;
+};
+
+// Unique (move-only) pointer to an aligned array of POD T.
+template <typename T>
+using AlignedPtr = std::unique_ptr<T[], Deleter>;
+
+// Both allocation, binding, and row accessors depend on the sizes of memory
+// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
+// use `Monostate` (static members).
+class Allocator {
+ public:
+  // Must be called at least once before any other function. Not thread-safe,
+  // hence only call this from the main thread.
+  static void Init(const BoundedTopology& topology);
+
+  // Bytes per cache line, or a reasonable guess if unknown. Used to choose
+  // ranges such that there will be no false sharing.
+  static size_t LineBytes();
+  // Bytes per full vector. Used to compute loop steps.
+  static size_t VectorBytes();
+  // Granularity of regions processed by different threads. Their start and
+  // length of regions should be divisible by this, which is at least
+  // `HWY_MAX(LineBytes(), VectorBytes())`.
+  static size_t QuantumBytes();
+  static size_t L1Bytes();
+  static size_t L2Bytes();
+
+  // Returns pointer aligned to `QuantumBytes()`.
+  template <typename T>
+  static AlignedPtr<T> Alloc(size_t num) {
     constexpr size_t kSize = sizeof(T);
-    // Ensure the `bytes = num * kSize` computation did not overflow.
     constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
     constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
     static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
     const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
+    // Fail if the `bytes = num * kSize` computation overflowed.
     const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
-    if (check != num) {
-      return hwy::AlignedFreeUniquePtr<T[]>();  // overflowed
-    }
-
-    // AlignedFreeUniquePtr has a deleter that can call an arbitrary `free`, but
-    // with an extra opaque pointer, which we discard via `call_free`.
-#if defined(__ANDROID_API__) && __ANDROID_API__ < 28
-    const auto call_free = [](void* ptr, void*) { std::free(ptr); };
-    void* mem = nullptr;
-    int err = posix_memalign(&mem, Alignment(), bytes);
-    HWY_ASSERT(err == 0);
-    T* p = static_cast<T*>(mem);
-#elif HWY_OS_WIN
-    const auto call_free = [](void* ptr, void*) { _aligned_free(ptr); };
-    T* p = static_cast<T*>(_aligned_malloc(bytes, Alignment()));
-#else
-    const auto call_free = [](void* ptr, void*) { std::free(ptr); };
-    T* p = static_cast<T*>(std::aligned_alloc(Alignment(), bytes));
-#endif
-    return hwy::AlignedFreeUniquePtr<T[]>(
-        p, hwy::AlignedFreer(call_free, nullptr));
+    if (check != num) return AlignedPtr<T>();
+
+    PtrAndDeleter pd = AllocBytes(bytes);
+    return AlignedPtr<T>(static_cast<T*>(pd.p), pd.deleter);
+  }
+
+  // Returns whether `BindMemory` can/should be called, i.e. we have page-level
+  // control over memory placement and multiple packages and NUMA nodes.
+  static bool ShouldBind();
+
+  // Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is
+  // typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`.
+  // Writes zeros to SOME of the memory. Only call if `ShouldBind()`.
+  // `p` and `bytes` must be multiples of `QuantumBytes()`.
+  static bool BindMemory(void* p, size_t bytes, size_t node);
+
+ private:
+  // Type-erased so this can be implemented in allocator.cc.
+  struct PtrAndDeleter {
+    void* p;
+    Deleter deleter;
+  };
+  static PtrAndDeleter AllocBytes(size_t bytes);
+};
+
+// Owns dynamically-allocated aligned memory for a batch of row vectors.
+// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
+// the memory.
+template <typename T>
+class RowVectorBatch {
+ public:
+  // Default ctor for Activations ctor.
+  RowVectorBatch() = default;
+  // Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
+  // we default to tightly packed rows (`stride = cols`).
+  // WARNING: not all call sites support `stride` != cols.
+  RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
+    if (stride == 0) {
+      stride_ = extents_.cols;
+    } else {
+      HWY_ASSERT(stride >= extents_.cols);
+      stride_ = stride;
+    }
+    mem_ = Allocator::Alloc<T>(extents_.rows * stride_);
+  }
+
+  // Move-only
+  RowVectorBatch(RowVectorBatch&) noexcept = delete;
+  RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
+  RowVectorBatch(RowVectorBatch&&) noexcept = default;
+  RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
+
+  size_t BatchSize() const { return extents_.rows; }
+  size_t Cols() const { return extents_.cols; }
+  size_t Stride() const { return stride_; }
+  Extents2D Extents() const { return extents_; }
+
+  // Returns the given row vector of length `Cols()`.
+  T* Batch(size_t batch_idx) {
+    HWY_DASSERT(batch_idx < BatchSize());
+    return mem_.get() + batch_idx * stride_;
+  }
+  const T* Batch(size_t batch_idx) const {
+    HWY_DASSERT(batch_idx < BatchSize());
+    return mem_.get() + batch_idx * stride_;
+  }
+
+  // For MatMul or other operations that process the entire batch at once.
+  // TODO: remove once we only use Mat.
+  T* All() { return mem_.get(); }
+  const T* Const() const { return mem_.get(); }
+  size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
+
+ private:
+  AlignedPtr<T> mem_;
+  Extents2D extents_;
+  size_t stride_;
+};
+
+// Returns `num` rounded up to an odd number of cache lines. This is used to
+// compute strides. An odd number of cache lines prevents 2K aliasing and is
+// coprime with the cache associativity, which reduces conflict misses.
+template <typename T>
+static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
+  HWY_DASSERT(line_bytes >= 32);
+  HWY_DASSERT(line_bytes % sizeof(T) == 0);
+  const size_t lines = hwy::DivCeil(num * sizeof(T), line_bytes);
+  const size_t padded_num = (lines | 1) * line_bytes / sizeof(T);
+  HWY_DASSERT(padded_num >= num);
+  return padded_num;
+}
+
+// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
+// it is always float and does not support compressed T, but does support an
+// arbitrary stride >= cols.
+#pragma pack(push, 1)  // power of two size
+template <typename T>
+class RowPtr {
+ public:
+  RowPtr() = default;  // for `MMPtrs`.
+  RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
+      : row0_(row0),
+        stride_(stride),
+        step_(static_cast<uint32_t>(
+            HWY_MAX(Allocator::LineBytes(), Allocator::VectorBytes()))),
+        cols_(static_cast<uint32_t>(cols)),
+        row_mask_(Allocator::QuantumBytes() / step_ - 1) {
+    HWY_DASSERT(stride >= cols);
+    HWY_DASSERT(row_mask_ != ~size_t{0});
+    row_mask_ = 0;  // TODO: remove
+  }
+  RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
+
+  T* HWY_RESTRICT Row(size_t r) const {
+    // How much of the previous row's padding to consume.
+    const size_t pad_bytes = (r & row_mask_) * step_;
+    HWY_DASSERT(pad_bytes < Allocator::QuantumBytes());
+    return row0_ + stride_ * r - pad_bytes;
+  }
+  size_t Cols() const { return cols_; }
+
+  size_t Stride() const { return stride_; }
+  void SetStride(size_t stride) {
+    HWY_DASSERT(stride >= Cols());
+    stride_ = stride;
+    // The caller might not have padded enough, so disable the padding in Row().
+    // Rows will now be exactly `stride` elements apart. This is used when
+    // writing to the KV cache via MatMul.
+    row_mask_ = 0;
+  }
+
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  RowPtr<T> View(size_t r, size_t c, size_t cols) const {
+    HWY_DASSERT(c < cols_);
+    HWY_DASSERT(cols <= cols_ - c);
+    return RowPtr<T>(Row(r) + c, cols, stride_);
   }

  private:
-  static size_t DetectPageSize();
-
-  // Required for BindMemory. Usually 4K, but can differ on Arm.
-  static size_t bytes_per_page_;
-  static bool use_numa_;
-  static size_t alignment_;
+  T* HWY_RESTRICT row0_;
+  size_t stride_;
+  uint32_t step_;  // Copy from Allocator::LineBytes() to improve locality.
+  uint32_t cols_;
+  size_t row_mask_;
 };
+#pragma pack(pop)

-// For future NUMA support. TODO: use.
-void BindMemory(void* ptr, size_t bytes, size_t node);
+using RowPtrBF = RowPtr<BF16>;
+using RowPtrF = RowPtr<float>;
+using RowPtrD = RowPtr<double>;
+
+// For C argument to MatMul.
+template <typename T>
+RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
+  return RowPtr<T>(row_vectors.All(), row_vectors.Cols(), row_vectors.Stride());
+}

 }  // namespace gcpp
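A usage sketch of the new header, based only on the declarations shown above; the function name, the sizes, and passing the topology as a parameter are assumptions for illustration. Allocator::Init must run before the first Alloc, and RowPtrFromBatch now forwards the stride.

// Hedged sketch; not code from this commit.
#include "util/allocator.h"
#include "util/threading.h"

// `topology` is assumed to be constructed elsewhere (see util/threading.h).
void ExampleAllocate(const gcpp::BoundedTopology& topology) {
  gcpp::Allocator::Init(topology);  // must precede any Alloc()

  const size_t cols = 2048;
  // Pad the stride to an odd number of cache lines to avoid 2K aliasing.
  const size_t stride =
      gcpp::RoundUpToOddLines<float>(cols, gcpp::Allocator::LineBytes());
  gcpp::RowVectorBatch<float> batch(gcpp::Extents2D(4, cols), stride);
  gcpp::RowPtrF c = gcpp::RowPtrFromBatch(batch);  // C argument for MatMul
  c.Row(0)[0] = 1.0f;  // rows are `stride` floats apart
}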
143  util/basics.h

@@ -64,8 +64,8 @@ struct TokenAndProb {
 // Entire size of a 2D array.
 struct Extents2D {
-  Extents2D() : rows(0), cols(0) {}
-  Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
+  constexpr Extents2D() : rows(0), cols(0) {}
+  constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
     HWY_DASSERT(rows != 0);
     HWY_DASSERT(cols != 0);
   }

@@ -77,6 +77,7 @@ struct Extents2D {
 };

 struct IndexRange {
+  IndexRange() = default;
   IndexRange(size_t begin, size_t end) : begin_(begin), end_(end) {
     HWY_DASSERT(begin < end);
   }

@@ -113,144 +114,6 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end,
                                         size_t max_size) {
   return IndexRange(begin, HWY_MIN(begin + max_size, end));
 }

-// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
-// it is always float and does not support compressed T, but does support an
-// arbitrary stride >= cols.
-template <typename T>
-class RowPtr {
- public:
-  RowPtr(T* HWY_RESTRICT row0, size_t cols)
-      : row0_(row0), cols_(cols), stride_(cols) {}
-  RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
-      : row0_(row0), cols_(cols), stride_(stride) {
-    HWY_DASSERT(stride >= cols);
-  }
-
-  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
-  size_t Cols() const { return cols_; }
-
-  size_t Stride() const { return stride_; }
-  void SetStride(size_t stride) {
-    HWY_DASSERT(stride >= Cols());
-    stride_ = stride;
-  }
-
- private:
-  T* HWY_RESTRICT row0_;
-  size_t stride_;
-  size_t cols_;
-};
-
-using RowPtrF = RowPtr<float>;
-
-// Owns dynamically-allocated aligned memory for a batch of row vectors.
-// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
-// the memory.
-template <typename T>
-class RowVectorBatch {
- public:
-  // Default ctor for Activations ctor.
-  RowVectorBatch() = default;
-  // Main ctor, called from Activations::Allocate.
-  RowVectorBatch(Extents2D extents) : extents_(extents) {
-    mem_ = hwy::AllocateAligned<T>(extents_.rows * extents_.cols);
-  }
-
-  // Move-only
-  RowVectorBatch(RowVectorBatch&) noexcept = delete;
-  RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
-  RowVectorBatch(RowVectorBatch&&) noexcept = default;
-  RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
-
-  size_t BatchSize() const { return extents_.rows; }
-  size_t Cols() const { return extents_.cols; }
-  Extents2D Extents() const { return extents_; }
-
-  // Returns the given row vector of length `Cols()`.
-  T* Batch(size_t batch_idx) {
-    HWY_DASSERT(batch_idx < BatchSize());
-    return mem_.get() + batch_idx * Cols();
-  }
-  const T* Batch(size_t batch_idx) const {
-    HWY_DASSERT(batch_idx < BatchSize());
-    return mem_.get() + batch_idx * Cols();
-  }
-
-  // For MatMul or other operations that process the entire batch at once.
-  // TODO: remove once we only use Mat.
-  T* All() { return mem_.get(); }
-  const T* Const() const { return mem_.get(); }
-  size_t NumBytes() const { return BatchSize() * Cols() * sizeof(T); }
-
- private:
-  hwy::AlignedFreeUniquePtr<T[]> mem_;
-  Extents2D extents_;
-};
-
-// Used for the A and B arguments of `MatMul`, which are always const.
-// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
-// `ofs` required for compressed T.
-template <typename T>
-struct ConstMat {
-  ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0)
-      : ptr(ptr), extents(extents), ofs(ofs) {
-    HWY_DASSERT(ptr != nullptr);
-  }
-  // TODO: support stride for page alignment.
-  size_t Row(size_t r) const {
-    if constexpr (HWY_IS_DEBUG_BUILD) {
-      if (r >= extents.rows) {
-        HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
-      }
-    }
-    return ofs + extents.cols * r;
-  }
-
-  const Extents2D& Extents() const { return extents; }
-  size_t Stride() const { return extents.cols; }
-
-  // Shrinks the row-extent of this matrix view, i.e. reduces the view to a
-  // subrange of the original rows starting at row 0.
-  void ShrinkRows(size_t rows) {
-    HWY_ASSERT(rows <= extents.rows);
-    extents.rows = rows;
-  }
-
-  const T* HWY_RESTRICT ptr;
-  Extents2D extents;
-
-  // `scale` allows expanding the smaller range of `SfpStream` to the original
-  // values. MatFromWeights sets this from `MatPtr`.
-  float scale = 1.0f;
-
-  // Offset to add to `ptr`; separate because T=NuqStream does not support
-  // pointer arithmetic.
-  size_t ofs;
-};
-
-// For deducing T.
-template <typename T>
-ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
-                         size_t ofs = 0) {
-  return ConstMat<T>(ptr, extents, ofs);
-}
-
-// For A argument to MatMul (activations).
-template <typename T>
-ConstMat<T> ConstMatFromBatch(size_t batch_size,
-                              const RowVectorBatch<T>& row_vectors) {
-  HWY_DASSERT(batch_size <= row_vectors.BatchSize());
-  return MakeConstMat(const_cast<T*>(row_vectors.Const()),
-                      Extents2D(batch_size, row_vectors.Cols()));
-}
-
-// For C argument to MatMul.
-template <typename T>
-RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
-  return RowPtr<T>(row_vectors.All(), row_vectors.Cols());
-}
-
 }  // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
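A small illustration of the two basics.h changes above (constexpr Extents2D and a default-constructible IndexRange), assuming `rows` and `cols` remain public members; the names and values are arbitrary.

// Hedged sketch; not code from this commit.
#include "util/basics.h"

// Compile-time extents are now possible thanks to the constexpr constructors.
constexpr gcpp::Extents2D kTile(4, 256);
static_assert(kTile.rows == 4 && kTile.cols == 256, "constexpr Extents2D");

void ExampleRange() {
  gcpp::IndexRange pending;  // default constructor added above
  // Clamp [10, 25) to at most 8 elements: yields [10, 18).
  pending = gcpp::MakeIndexRange(/*begin=*/10, /*end=*/25, /*max_size=*/8);
  (void)pending;
}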
util/threading.cc

@@ -55,10 +55,8 @@ class Pinning {
     LPS enabled_lps;
     if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
       const size_t num_lps = hwy::TotalLogicalProcessors();
-      fprintf(
-          stderr,
-          "Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
-          num_lps);
+      HWY_WARN("unknown OS affinity, considering all %zu LPs enabled.",
+               num_lps);
       for (size_t lp = 0; lp < num_lps; ++lp) {
         enabled_lps.Set(lp);
       }

@@ -71,8 +69,7 @@ class Pinning {
       const size_t lp = enabled_lps.First();
       enabled_lps = LPS();
       enabled_lps.Set(lp);
-      fprintf(stderr,
-              "Warning, threads not supported, using only the main thread\n.");
+      HWY_WARN("Warning, threads not supported, using only the main thread.");
     }

     original_affinity_ = enabled_lps;

@@ -155,23 +152,10 @@ BoundedTopology::BoundedTopology(BoundedSlice package_slice,
   HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
 }

-// Topology is unknown, rely on OS affinity and user-specified slice.
-BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
-                                  BoundedSlice lp_slice) {
-  // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
-  // we honor both the OS affinity and the user-specified slice. Note that
-  // this can be used to exclude hyperthreads because Linux groups LPs by
-  // sibling index. For example, the first `num_cores` are not siblings.
-  const size_t detected = enabled_lps.Count();
-  size_t enabled_idx = 0;
-  enabled_lps.Foreach([&](size_t lp) {
-    if (lp_slice.Contains(detected, enabled_idx++)) {
-      AddLP(lp);
-    }
-  });
-
-  // lp_slice can only reduce the number of `enabled_lps`, and not below 1.
-  HWY_ASSERT(num_workers_ != 0);
+// Topology is unknown, take the given set of LPs.
+BoundedTopology::Cluster::Cluster(const LPS& lps) {
+  lps_ = lps;
+  num_workers_ = lps.Count();
 }

 BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,

@@ -183,7 +167,9 @@ BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
     // Skip if not first-hyperthread or disabled.
     if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;

-    AddLP(lp);
+    HWY_ASSERT(!lps_.Get(lp));  // Foreach ensures uniqueness
+    lps_.Set(lp);
+    ++num_workers_;

     // Set fields once, and ensure subsequent LPs match - we assume there
     // is only one NUMA node per cluster, with the same L2/L3 size.

@@ -198,30 +184,63 @@ BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
       if (HWY_LIKELY(!warned)) {
         if (HWY_UNLIKELY(lp_node != node_)) {
           warned = true;
-          fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n",
-                  lp, lp_node, node_);
+          HWY_WARN("lp %zu on node %zu != cluster node %zu.", lp, lp_node,
+                   node_);
         }
         if (HWY_UNLIKELY(private_kib_ != tcluster.private_kib)) {
           warned = true;
-          fprintf(stderr, "WARNING: lp %zu private_kib %zu != cluster %zu.\n",
-                  lp, private_kib_, tcluster.private_kib);
+          HWY_WARN("lp %zu private_kib %zu != cluster %zu.", lp, private_kib_,
+                   tcluster.private_kib);
         }
         if (HWY_UNLIKELY(shared_kib_ != tcluster.shared_kib)) {
           warned = true;
-          fprintf(stderr, "WARNING: lp %zu shared_kib %zu != cluster %zu.\n",
-                  lp, shared_kib_, tcluster.shared_kib);
+          HWY_WARN("lp %zu shared_kib %zu != cluster %zu.", lp, shared_kib_,
+                   tcluster.shared_kib);
         }
       }  // !warned
     }
   });
 }

+// CPUs without clusters are rarely more than dozens of cores, and 6 is a
+// decent number of threads in a per-cluster pool.
+constexpr bool kSplitLargeClusters = false;
+constexpr size_t kMaxClusters = 8;
+constexpr size_t kMaxLPsPerCluster = 6;
+
+// Topology is unknown, rely on OS affinity and user-specified slice.
+BoundedTopology::Package::Package(const LPS& enabled_lps,
+                                  BoundedSlice lp_slice) {
+  LPS clusters_lps[kMaxClusters];
+  const size_t num_clusters =
+      kSplitLargeClusters
+          ? HWY_MIN(kMaxClusters,
+                    hwy::DivCeil(enabled_lps.Count(), kMaxLPsPerCluster))
+          : 1;
+
+  // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
+  // we honor both the OS affinity and the user-specified slice. Note that
+  // this can be used to exclude hyperthreads because Linux groups LPs by
+  // sibling index. For example, the first `num_cores` are not siblings.
+  const size_t detected = enabled_lps.Count();
+  size_t enabled_idx = 0;
+  enabled_lps.Foreach([&](size_t lp) {
+    if (lp_slice.Contains(detected, enabled_idx)) {
+      clusters_lps[enabled_idx % num_clusters].Set(lp);
+    }
+    ++enabled_idx;
+  });
+
+  for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
+    clusters.push_back(Cluster(clusters_lps[cluster_idx]));
+  }
+}
+
 // NOTE: caller is responsible for checking whether `clusters` is empty.
 BoundedTopology::Package::Package(const LPS& enabled_lps,
-                                  const hwy::Topology& topology,
-                                  size_t package_idx,
+                                  const hwy::Topology& topology, size_t pkg_idx,
                                   BoundedSlice cluster_slice) {
-  const hwy::Topology::Package& tpackage = topology.packages[package_idx];
+  const hwy::Topology::Package& tpackage = topology.packages[pkg_idx];
   // Populate `clusters` with the subset of clusters in `cluster_slice` that
   // have any enabled LPs. If `clusters` remains empty, the caller will
   // skip this `Package`.

@@ -233,10 +252,34 @@ BoundedTopology::Package::Package(const LPS& enabled_lps,
     // Skip if empty, i.e. too few `enabled_lps`.
     if (HWY_LIKELY(cluster.Size() != 0)) {
-      clusters.push_back(std::move(cluster));
+      clusters.push_back(cluster);
     }
   });
   SortByDescendingSize(clusters);
+
+  // If there is only one large cluster, split it into smaller ones.
+  if (kSplitLargeClusters && clusters.size() == 1 &&
+      enabled_lps.Count() >= 16) {
+    const LPS lps = clusters[0].LPSet();  // copy so we can clear
+    clusters.clear();
+
+    // Split `lps` into several clusters.
+    LPS clusters_lps[kMaxClusters];
+    const size_t num_clusters =
+        HWY_MIN(kMaxClusters, hwy::DivCeil(lps.Count(), kMaxLPsPerCluster));
+    size_t num_lps = 0;
+    lps.Foreach(
+        [&](size_t lp) { clusters_lps[num_lps++ % num_clusters].Set(lp); });
+    HWY_DASSERT(num_lps == lps.Count());
+
+    // Create new clusters, just inserting the new LPS.
+    hwy::Topology::Cluster tcluster = tpackage.clusters[0];  // modifiable copy
+    for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
+      tcluster.lps = clusters_lps[cluster_idx];
+      // Keep same `private_kib` and `shared_kib`.
+      clusters.push_back(Cluster(enabled_lps, topology.lps, tcluster));
+    }
+  }
 }

 #if !GEMMA_DISABLE_TOPOLOGY

@@ -256,10 +299,9 @@ static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters,
   max_tclusters = 0;
   max_tcluster_cores = 0;
   max_tcluster_lps = 0;
-  for (size_t package_idx = 0; package_idx < topology_.packages.size();
-       ++package_idx) {
+  for (size_t pkg_idx = 0; pkg_idx < topology_.packages.size(); ++pkg_idx) {
     const std::vector<hwy::Topology::Cluster>& tclusters =
-        topology_.packages[package_idx].clusters;
+        topology_.packages[pkg_idx].clusters;
     max_tclusters = HWY_MAX(max_tclusters, tclusters.size());
     size_t tcluster_cores = 0;
     size_t tcluster_lps = 0;

@@ -272,10 +314,10 @@ static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters,
     }

     if (tclusters.size() > 1 && tcluster_cores > 8) {
-      fprintf(stderr,
+      HWY_WARN(
           "Package %zu: multiple clusters with max size %zu, whereas CCX "
-          "only have 8, may indicate a bug in hwy::Topology.\n",
-          package_idx, tcluster_cores);
+          "only have 8, may indicate a bug in hwy::Topology.",
+          pkg_idx, tcluster_cores);
     }
     max_tcluster_cores = HWY_MAX(max_tcluster_cores, tcluster_cores);
     max_tcluster_lps = HWY_MAX(max_tcluster_lps, tcluster_lps);

@@ -294,8 +336,8 @@ void BoundedTopology::InitFromTopology(const LPS& enabled_lps,
   // (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
   package_slice.Foreach(
-      "package", topology_.packages.size(), [&](size_t package_idx) {
-        Package package(enabled_lps, topology_, package_idx, cluster_slice);
+      "package", topology_.packages.size(), [&](size_t pkg_idx) {
+        Package package(enabled_lps, topology_, pkg_idx, cluster_slice);
         // Skip if empty, i.e. too few `enabled_lps`.
         if (HWY_LIKELY(!package.clusters.empty())) {
           packages_.push_back(std::move(package));

@@ -313,18 +355,18 @@ void BoundedTopology::InitFromTopology(const LPS& enabled_lps,
   // Scan for max BoundedTopology clusters and their size, for topology_string_.
   size_t all_max_cluster_size = 0;
-  for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
+  for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) {
     size_t max_cluster_size = 0;
-    for (size_t cluster_idx = 0; cluster_idx < NumClusters(package_idx);
+    for (size_t cluster_idx = 0; cluster_idx < NumClusters(pkg_idx);
          ++cluster_idx) {
-      max_cluster_size = HWY_MAX(max_cluster_size,
-                                 GetCluster(package_idx, cluster_idx).Size());
+      max_cluster_size =
+          HWY_MAX(max_cluster_size, GetCluster(pkg_idx, cluster_idx).Size());
     }
-    if (NumClusters(package_idx) > 1 && max_cluster_size > 8) {
-      fprintf(stderr,
+    if (NumClusters(pkg_idx) > 1 && max_cluster_size > 8) {
+      HWY_WARN(
           "Package %zu: multiple clusters with max size %zu, whereas CCX "
-          "only have 8, may indicate a bug in BoundedTopology.\n",
-          package_idx, max_cluster_size);
+          "only have 8, may indicate a bug in BoundedTopology.",
+          pkg_idx, max_cluster_size);
     }
     all_max_cluster_size = HWY_MAX(all_max_cluster_size, max_cluster_size);
   }

@@ -382,10 +424,10 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
   // calling thread of an all_clusters->Run, and hence pinned to one of the
   // `cluster.lps` if `pin`.
   all_packages_->Run(
-      0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) {
-        HWY_ASSERT(package_idx == thread);  // each thread has one task
-        packages_[package_idx] =
-            Package(topology_, package_idx, max_workers_per_package, lp_slice);
+      0, all_packages_->NumWorkers(), [&](uint64_t pkg_idx, size_t thread) {
+        HWY_ASSERT(pkg_idx == thread);  // each thread has one task
+        packages_[pkg_idx] =
+            Package(topology_, pkg_idx, max_workers_per_package, lp_slice);
       });

   all_pinned_ = GetPinning().AllPinned(&pin_string_);

@@ -405,12 +447,11 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
   HWY_ASSERT(max_workers_per_cluster_ <= 256);
 }

-NestedPools::Package::Package(const BoundedTopology& topology,
-                              size_t package_idx,
+NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx,
                               size_t max_workers_per_package,
                               BoundedSlice lp_slice) {
   // Pre-allocate because elements are set concurrently.
-  clusters_.resize(topology.NumClusters(package_idx));
+  clusters_.resize(topology.NumClusters(pkg_idx));
   const size_t max_workers_per_cluster =
       DivideMaxAcross(max_workers_per_package, clusters_.size());

@@ -421,7 +462,7 @@ NestedPools::Package::Package(const BoundedTopology& topology,
       0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) {
         HWY_ASSERT(cluster_idx == thread);  // each thread has one task
         const BoundedTopology::Cluster& cluster =
-            topology.GetCluster(package_idx, cluster_idx);
+            topology.GetCluster(pkg_idx, cluster_idx);
         clusters_[cluster_idx] =
             MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
         // Pin workers AND the calling thread from `all_clusters`.
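The round-robin assignment used by both splitting paths above (Package from enabled LPs, and the single-large-cluster split) can be illustrated in isolation. This sketch is an assumption-laden simplification: LPS is replaced by a plain vector of LP indices and kMaxClusters is hard-coded to 8, matching the constants above.

// Hedged, standalone sketch of the round-robin split; not code from this commit.
#include <algorithm>
#include <cstddef>
#include <vector>

// Distribute `lps` over up to 8 pseudo-clusters of roughly
// `max_lps_per_cluster` each; LP i goes to cluster i % num_clusters.
// `max_lps_per_cluster` must be nonzero.
std::vector<std::vector<size_t>> SplitRoundRobin(
    const std::vector<size_t>& lps, size_t max_lps_per_cluster) {
  const size_t kMaxClusters = 8;
  const size_t num_clusters = std::min(
      kMaxClusters,
      (lps.size() + max_lps_per_cluster - 1) / max_lps_per_cluster);
  std::vector<std::vector<size_t>> clusters(num_clusters);
  for (size_t i = 0; i < lps.size(); ++i) {
    clusters[i % num_clusters].push_back(lps[i]);
  }
  return clusters;
}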
140  util/threading.h

@@ -108,7 +108,7 @@ class BoundedTopology {
   class Cluster {
    public:
-    Cluster(const LPS& enabled_lps, BoundedSlice lp_slice);
+    Cluster(const LPS& lps);
     Cluster(const LPS& enabled_lps,
             const std::vector<hwy::Topology::LP>& all_lps,
             const hwy::Topology::Cluster& tcluster);

@@ -124,17 +124,12 @@ class BoundedTopology {
       return lps;
     }

+    const LPS& LPSet() const { return lps_; }
     size_t Node() const { return node_; }
     size_t PrivateKiB() const { return private_kib_; }
     size_t SharedKiB() const { return shared_kib_; }

    private:
-    void AddLP(size_t lp) {
-      HWY_ASSERT(!lps_.Get(lp));  // Foreach ensures uniqueness
-      lps_.Set(lp);
-      ++num_workers_;
-    }
-
     // Enabled LPs; if topology is known, only the ones in this cluster.
     LPS lps_;
     // How many workers in the per-cluster pool. If 0, this Cluster is removed.

@@ -147,19 +142,19 @@ class BoundedTopology {
     size_t shared_kib_ = 0;
   };  // Cluster

-  size_t NumClusters(size_t package_idx) const {
-    HWY_ASSERT(package_idx < NumPackages());
-    return packages_[package_idx].clusters.size();
+  size_t NumClusters(size_t pkg_idx) const {
+    HWY_ASSERT(pkg_idx < NumPackages());
+    return packages_[pkg_idx].clusters.size();
   }
-  const Cluster& GetCluster(size_t package_idx, size_t cluster_idx) const {
-    HWY_ASSERT(package_idx < NumPackages());
-    const Package& package = packages_[package_idx];
+  const Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) const {
+    HWY_ASSERT(pkg_idx < NumPackages());
+    const Package& package = packages_[pkg_idx];
     HWY_ASSERT(cluster_idx < package.clusters.size());
     return package.clusters[cluster_idx];
   }
-  Cluster& GetCluster(size_t package_idx, size_t cluster_idx) {
-    HWY_ASSERT(package_idx < NumPackages());
-    Package& package = packages_[package_idx];
+  Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) {
+    HWY_ASSERT(pkg_idx < NumPackages());
+    Package& package = packages_[pkg_idx];
     HWY_ASSERT(cluster_idx < package.clusters.size());
     return package.clusters[cluster_idx];
   }

@@ -170,13 +165,9 @@ class BoundedTopology {
  private:
   struct Package {
-    // Topology is unknown, rely on OS affinity and user-specified slice.
-    Package(const LPS& enabled_lps, BoundedSlice lp_slice) {
-      clusters.push_back(Cluster(enabled_lps, lp_slice));
-    }
+    Package(const LPS& enabled_lps, BoundedSlice lp_slice);

     Package(const LPS& enabled_lps, const hwy::Topology& topology,
-            size_t package_idx, BoundedSlice cluster_slice);
+            size_t pkg_idx, BoundedSlice cluster_slice);

     // For SortByDescendingSize.
     size_t Size() const { return clusters.size(); }

@@ -257,33 +248,36 @@ class NestedPools {
     }
   }

+  size_t NumPackages() const { return packages_.size(); }
   hwy::ThreadPool& AllPackages() { return *all_packages_; }
-  hwy::ThreadPool& AllClusters(size_t package_idx) {
-    HWY_DASSERT(package_idx < packages_.size());
-    return packages_[package_idx].AllClusters();
+  hwy::ThreadPool& AllClusters(size_t pkg_idx) {
+    HWY_DASSERT(pkg_idx < NumPackages());
+    return packages_[pkg_idx].AllClusters();
   }
-  hwy::ThreadPool& Cluster(size_t package_idx, size_t cluster_idx) {
-    HWY_DASSERT(package_idx < packages_.size());
-    return packages_[package_idx].Cluster(cluster_idx);
+  hwy::ThreadPool& Cluster(size_t pkg_idx, size_t cluster_idx) {
+    HWY_DASSERT(pkg_idx < NumPackages());
+    return packages_[pkg_idx].Cluster(cluster_idx);
   }

   // For binding to NUMA nodes.
-  size_t Node(size_t package_idx, size_t cluster_idx) const {
-    return topology_.GetCluster(package_idx, cluster_idx).Node();
+  size_t Node(size_t pkg_idx, size_t cluster_idx) const {
+    return topology_.GetCluster(pkg_idx, cluster_idx).Node();
   }

-  // Reasonably tight upper bound for allocating thread-local storage (TLS).
-  size_t MaxWorkers() const {
-    return packages_.size() * max_clusters_per_package_ *
-           max_workers_per_cluster_;
+  // Reasonably tight upper bounds for allocating thread-local storage (TLS).
+  size_t MaxWorkersPerCluster() const { return max_workers_per_cluster_; }
+  size_t MaxWorkersPerPackage() const {
+    return max_clusters_per_package_ * MaxWorkersPerCluster();
   }
-  // Returns the first of `cluster.NumWorkers()` TLS indices, to which callers
-  // add the worker index given by `cluster.Run`.
-  size_t WorkerOffset(size_t package_idx, size_t cluster_idx) const {
-    HWY_DASSERT(package_idx < packages_.size());
-    HWY_DASSERT(cluster_idx < packages_[package_idx].NumClusters());
-    return (package_idx * max_clusters_per_package_ + cluster_idx) *
-           max_workers_per_cluster_;
+  size_t MaxWorkers() const { return NumPackages() * MaxWorkersPerPackage(); }
+
+  // Actual number of workers.
+  size_t TotalWorkers() const {
+    size_t total_workers = 0;
+    for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) {
+      total_workers += packages_[pkg_idx].TotalWorkers();
+    }
+    return total_workers;
   }

   // For Allocator

@@ -296,20 +290,20 @@ class NestedPools {
   // if there is more than one, which maximizes available memory bandwidth, or
   // the first cluster, which is typically the whole package. For use by callers
   // that only have a single parallel-for.
-  hwy::ThreadPool& Pool(size_t package_idx = 0) {
+  hwy::ThreadPool& Pool(size_t pkg_idx = 0) {
     // Only one cluster: use its pool, typically a whole socket.
-    if (AllClusters(package_idx).NumWorkers() == 1) {
-      return Cluster(package_idx, 0);
+    if (AllClusters(pkg_idx).NumWorkers() == 1) {
+      return Cluster(pkg_idx, 0);
     }
     // One worker per cluster to maximize bandwidth availability.
-    return AllClusters(package_idx);
+    return AllClusters(pkg_idx);
   }

  private:
   class Package {
    public:
     Package() = default;  // for vector
-    Package(const BoundedTopology& topology, size_t package_idx,
+    Package(const BoundedTopology& topology, size_t pkg_idx,
             size_t max_workers_per_package, BoundedSlice lp_slice);

     size_t NumClusters() const { return clusters_.size(); }

@@ -321,6 +315,13 @@ class NestedPools {
       }
       return max_workers_per_cluster;
     }
+    size_t TotalWorkers() const {
+      size_t total_workers = 0;
+      for (const PoolPtr& cluster : clusters_) {
+        total_workers += cluster->NumWorkers();
+      }
+      return total_workers;
+    }

     hwy::ThreadPool& AllClusters() { return *all_clusters_; }
     hwy::ThreadPool& Cluster(size_t cluster_idx) {

@@ -365,32 +366,34 @@ class NestedPools {
 // functions below.
 class IndexRangePartition {
  public:
+  IndexRangePartition() = default;  // for MMPartitions
   IndexRangePartition(const IndexRange& range, const size_t task_size)
-      : range_(range), task_size_(task_size) {
-    const size_t num = range.Num();
+      : range_(range), task_size_(static_cast<uint32_t>(task_size)) {
+    const uint32_t num = static_cast<uint32_t>(range.Num());
     HWY_DASSERT(task_size_ != 0);
     num_tasks_ = hwy::DivCeil(num, task_size_);
     HWY_DASSERT(num_tasks_ != 0);
     if constexpr (HWY_IS_DEBUG_BUILD) {
-      const size_t handled = num_tasks_ * task_size_;
+      const uint32_t handled = num_tasks_ * task_size_;
       // The last task may extend beyond items, but at most by (task_size_ - 1).
       HWY_DASSERT(num <= handled && handled < num + task_size_);
+      (void)handled;
     }
   }

-  size_t TaskSize() const { return task_size_; }
-  size_t NumTasks() const { return num_tasks_; }
+  size_t TaskSize() const { return static_cast<size_t>(task_size_); }
+  size_t NumTasks() const { return static_cast<size_t>(num_tasks_); }

   IndexRange Range(size_t task_idx) const {
     HWY_DASSERT(task_idx < NumTasks());
-    return MakeIndexRange(range_.begin() + task_idx * task_size_, range_.end(),
-                          task_size_);
+    return MakeIndexRange(range_.begin() + task_idx * TaskSize(), range_.end(),
+                          TaskSize());
   }

  private:
   IndexRange range_;
-  size_t task_size_;
-  size_t num_tasks_;
+  uint32_t task_size_;
+  uint32_t num_tasks_;
 };

 // Starts with `max_size` and rounds DOWN to a multiple of `size_multiple`

@@ -455,33 +458,6 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
   });
 }

-// As above, for three ranges.
-template <class Func>
-void ParallelizeThreeRanges(const IndexRangePartition& get1,
-                            const IndexRangePartition& get2,
-                            const IndexRangePartition& get3,
-                            hwy::ThreadPool& pool, const Func& func) {
-  const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
-  const size_t num12 = get1.NumTasks() * get2.NumTasks();
-  const hwy::Divisor div12(static_cast<uint32_t>(num12));
-
-  const size_t num_tasks = num12 * get3.NumTasks();
-  pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
-    HWY_DASSERT(task < (uint64_t{1} << 32));
-    const size_t idx3 = div12.Divide(static_cast<uint32_t>(task));
-    const size_t task12 = div12.Remainder(static_cast<uint32_t>(task));
-    const size_t idx2 = div1.Divide(static_cast<uint32_t>(task12));
-    const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task12));
-    HWY_DASSERT(idx1 < get1.NumTasks());
-    HWY_DASSERT(idx2 < get2.NumTasks());
-    HWY_DASSERT(idx3 < get3.NumTasks());
-    const IndexRange range1 = get1.Range(idx1);
-    const IndexRange range2 = get2.Range(idx2);
-    const IndexRange range3 = get3.Range(idx3);
-    func(range1, range2, range3, thread);
-  });
-}
-
 }  // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
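A brief usage sketch of IndexRangePartition as declared above; the range bounds and task size are arbitrary example values, and the loop simply visits each task's sub-range.

// Hedged sketch; not code from this commit.
#include "util/threading.h"

void ExamplePartition() {
  const gcpp::IndexRange rows(0, 1000);
  const gcpp::IndexRangePartition partition(rows, /*task_size=*/128);
  // ceil(1000 / 128) = 8 tasks; the last one covers the 104-element remainder.
  for (size_t t = 0; t < partition.NumTasks(); ++t) {
    const gcpp::IndexRange r = partition.Range(t);
    (void)r;  // process [r.begin(), r.end())
  }
}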
util/threading_test.cc

@@ -138,6 +138,13 @@ TEST(ThreadingTest, TestMaxSizePartition) {
     HWY_ASSERT(partition.TaskSize() == 55);
     HWY_ASSERT(partition.NumTasks() == 2);
   }
+  // `size_multiple` almost as large as range: imbalanced
+  {
+    const IndexRangePartition partition =
+        MaxSizePartition(IndexRange(0, 6), 6, 4);
+    HWY_ASSERT(partition.TaskSize() == 4);
+    HWY_ASSERT(partition.NumTasks() == 2);
+  }
   // Small `max_size`: small tasks
   {
     const IndexRangePartition partition = MaxSizePartition(range, 2, 1);

@@ -244,97 +251,5 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) {
   }
 }

-TEST(ThreadingTest, TestParallelizeThreeRanges) {
-  // Named according to number of tasks.
-  const IndexRangePartition partition3 =
-      StaticPartition(IndexRange(0, 8), 3, 1);  // [0, 3) [3, 6) [6, 8)
-  HWY_ASSERT(partition3.NumTasks() == 3);
-  const IndexRangePartition partition2 =
-      MaxSizePartition(IndexRange(10, 30), 10, 10);  // [10, 20), [20, 30)
-  HWY_ASSERT(partition2.NumTasks() == 2);
-  const IndexRangePartition partition4 =
-      MaxSizePartition(IndexRange(100, 500), 100, 100);  // 100, 200, 300, 400
-  HWY_ASSERT(partition4.NumTasks() == 4);
-
-  const auto check_ranges = [&](const IndexRange& range3,
-                                const IndexRange& range2,
-                                const IndexRange& range4) {
-    HWY_ASSERT(range3.begin() == 0 || range3.begin() == 3 ||
-               range3.begin() == 6);
-    HWY_ASSERT(range2.begin() == 10 || range2.begin() == 20);
-    HWY_ASSERT(range4.begin() % 100 == 0);
-  };
-
-  hwy::ThreadPool null_pool(0);
-  // All 6 permutations of the three ranges to test the Remainder() logic:
-  // 3, 2, 4
-  {
-    size_t calls = 0;
-    ParallelizeThreeRanges(
-        partition3, partition2, partition4, null_pool,
-        [&](IndexRange range3, IndexRange range2, IndexRange range4, size_t) {
-          ++calls;
-          check_ranges(range3, range2, range4);
-        });
-    HWY_ASSERT(calls == 3 * 2 * 4);
-  }
-  // 3, 4, 2
-  {
-    size_t calls = 0;
-    ParallelizeThreeRanges(
-        partition3, partition4, partition2, null_pool,
-        [&](IndexRange range3, IndexRange range4, IndexRange range2, size_t) {
-          ++calls;
-          check_ranges(range3, range2, range4);
-        });
-    HWY_ASSERT(calls == 3 * 2 * 4);
-  }
-
-  // 4, 2, 3
-  {
-    size_t calls = 0;
-    ParallelizeThreeRanges(
-        partition4, partition2, partition3, null_pool,
-        [&](IndexRange range4, IndexRange range2, IndexRange range3, size_t) {
-          ++calls;
-          check_ranges(range3, range2, range4);
-        });
-    HWY_ASSERT(calls == 3 * 2 * 4);
-  }
-  // 4, 3, 2
-  {
-    size_t calls = 0;
-    ParallelizeThreeRanges(
-        partition4, partition3, partition2, null_pool,
-        [&](IndexRange range4, IndexRange range3, IndexRange range2, size_t) {
-          ++calls;
-          check_ranges(range3, range2, range4);
-        });
-    HWY_ASSERT(calls == 3 * 2 * 4);
-  }
-
-  // 2, 3, 4
-  {
-    size_t calls = 0;
-    ParallelizeThreeRanges(
-        partition2, partition3, partition4, null_pool,
-        [&](IndexRange range2, IndexRange range3, IndexRange range4, size_t) {
-          ++calls;
-          check_ranges(range3, range2, range4);
-        });
-    HWY_ASSERT(calls == 3 * 2 * 4);
-  }
-  // 2, 4, 3
-  {
-    size_t calls = 0;
-    ParallelizeThreeRanges(
-        partition2, partition4, partition3, null_pool,
-        [&](IndexRange range2, IndexRange range4, IndexRange range3, size_t) {
-          ++calls;
-          check_ranges(range3, range2, range4);
-        });
-    HWY_ASSERT(calls == 3 * 2 * 4);
-  }
-}
-
 }  // namespace
 }  // namespace gcpp
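ParallelizeTwoRanges survives the removal above; a hedged usage sketch follows. The callback signature (two IndexRange arguments plus the thread index) is inferred from the removed three-range variant and from TestParallelizeTwoRanges, so treat it as an assumption rather than a documented contract.

// Hedged sketch; not code from this commit.
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/threading.h"

void ExampleTwoRanges(hwy::ThreadPool& pool) {
  const gcpp::IndexRangePartition rows(gcpp::IndexRange(0, 64),
                                       /*task_size=*/16);
  const gcpp::IndexRangePartition cols(gcpp::IndexRange(0, 256),
                                       /*task_size=*/128);
  gcpp::ParallelizeTwoRanges(
      rows, cols, pool,
      [](const gcpp::IndexRange& row_range, const gcpp::IndexRange& col_range,
         size_t /*thread*/) {
        // Visits each (row task, col task) pair exactly once: 4 * 2 = 8 calls.
      });
}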